提交 930603b0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix strict type equality in aesara.tensor.subtensor_opt

上级 7278da49
......@@ -11,6 +11,7 @@ from aesara.graph.opt import TopoOptimizer, copy_stack_trace, in2out, local_opti
from aesara.tensor.basic import (
Alloc,
ARange,
MakeVector,
Rebroadcast,
ScalarFromTensor,
TensorFromScalar,
......@@ -67,7 +68,7 @@ from aesara.tensor.var import TensorConstant, TensorVariable
def register_useless(lopt, *tags, **kwargs):
if type(lopt) == str:
if isinstance(lopt, str):
def register(inner_lopt):
return register_useless(inner_lopt, lopt, *tags, **kwargs)
......@@ -556,7 +557,7 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
# The idx is a Scalar, ie a Type. This means the actual index
# is contained in node.inputs[1]
dim_index = node.inputs[node_inputs_idx]
if type(dim_index) == aes.ScalarConstant:
if isinstance(dim_index, aes.ScalarConstant):
dim_index = dim_index.value
if dim_index in [0, -1] and node.inputs[0].broadcastable[dim]:
remove_dim.append(dim)
......@@ -707,8 +708,8 @@ def local_subtensor_make_vector(fgraph, node):
"""
x = node.inputs[0]
if not x.owner or x.owner.op != make_vector:
return
if not x.owner or not isinstance(x.owner.op, MakeVector):
return False
if isinstance(node.op, Subtensor):
# This optimization needs ShapeOpt and fgraph.shape_feature
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论