提交 930603b0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix strict type equality in aesara.tensor.subtensor_opt

上级 7278da49
...@@ -11,6 +11,7 @@ from aesara.graph.opt import TopoOptimizer, copy_stack_trace, in2out, local_opti ...@@ -11,6 +11,7 @@ from aesara.graph.opt import TopoOptimizer, copy_stack_trace, in2out, local_opti
from aesara.tensor.basic import ( from aesara.tensor.basic import (
Alloc, Alloc,
ARange, ARange,
MakeVector,
Rebroadcast, Rebroadcast,
ScalarFromTensor, ScalarFromTensor,
TensorFromScalar, TensorFromScalar,
...@@ -67,7 +68,7 @@ from aesara.tensor.var import TensorConstant, TensorVariable ...@@ -67,7 +68,7 @@ from aesara.tensor.var import TensorConstant, TensorVariable
def register_useless(lopt, *tags, **kwargs): def register_useless(lopt, *tags, **kwargs):
if type(lopt) == str: if isinstance(lopt, str):
def register(inner_lopt): def register(inner_lopt):
return register_useless(inner_lopt, lopt, *tags, **kwargs) return register_useless(inner_lopt, lopt, *tags, **kwargs)
...@@ -556,7 +557,7 @@ def local_subtensor_remove_broadcastable_index(fgraph, node): ...@@ -556,7 +557,7 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
# The idx is a Scalar, ie a Type. This means the actual index # The idx is a Scalar, ie a Type. This means the actual index
# is contained in node.inputs[1] # is contained in node.inputs[1]
dim_index = node.inputs[node_inputs_idx] dim_index = node.inputs[node_inputs_idx]
if type(dim_index) == aes.ScalarConstant: if isinstance(dim_index, aes.ScalarConstant):
dim_index = dim_index.value dim_index = dim_index.value
if dim_index in [0, -1] and node.inputs[0].broadcastable[dim]: if dim_index in [0, -1] and node.inputs[0].broadcastable[dim]:
remove_dim.append(dim) remove_dim.append(dim)
...@@ -707,8 +708,8 @@ def local_subtensor_make_vector(fgraph, node): ...@@ -707,8 +708,8 @@ def local_subtensor_make_vector(fgraph, node):
""" """
x = node.inputs[0] x = node.inputs[0]
if not x.owner or x.owner.op != make_vector: if not x.owner or not isinstance(x.owner.op, MakeVector):
return return False
if isinstance(node.op, Subtensor): if isinstance(node.op, Subtensor):
# This optimization needs ShapeOpt and fgraph.shape_feature # This optimization needs ShapeOpt and fgraph.shape_feature
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论