提交 ecf0d14a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Correctly set tracked Op on local_track_shape_i

上级 5967e2cb
...@@ -1887,17 +1887,22 @@ def local_shape_to_shape_i(fgraph, node): ...@@ -1887,17 +1887,22 @@ def local_shape_to_shape_i(fgraph, node):
# TODO: Not sure what type of node we are expecting here # TODO: Not sure what type of node we are expecting here
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@local_optimizer(None) @local_optimizer([Shape_i])
def local_track_shape_i(fgraph, node): def local_track_shape_i(fgraph, node):
if not isinstance(node.op, Shape_i):
return False
try: try:
shape_feature = fgraph.shape_feature shape_feature = fgraph.shape_feature
except AttributeError: except AttributeError:
return return False
if node in shape_feature.scheduled:
if node not in shape_feature.scheduled:
return False
# Don't unschedule node as it could be reinserted in the # Don't unschedule node as it could be reinserted in the
# fgraph as we don't change it in the shapefeature internal # fgraph as we don't change it in the shapefeature internal
# structure. # structure.
assert isinstance(node.op, Shape_i)
replacement = shape_feature.scheduled[node] replacement = shape_feature.scheduled[node]
return [shape_feature.shape_of[replacement][node.op.i]] return [shape_feature.shape_of[replacement][node.op.i]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论