提交 d30b8fca authored 作者: Saizheng Zhang's avatar Saizheng Zhang

solve confilcts in test_opt.py, test_subtensor_inc_subtensor

上级 c26a0201
...@@ -1894,7 +1894,8 @@ def local_track_shape_i(node): ...@@ -1894,7 +1894,8 @@ def local_track_shape_i(node):
replacement = shape_feature.scheduled[node] replacement = shape_feature.scheduled[node]
return [shape_feature.shape_of[replacement][node.op.i]] return [shape_feature.shape_of[replacement][node.op.i]]
#TH-2801 opt: subtensor(incsubtensor)
# TH-2801 opt: subtensor(incsubtensor)
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@gof.local_optimizer([Subtensor]) @gof.local_optimizer([Subtensor])
...@@ -1905,11 +1906,13 @@ def local_subtensor_inc_subtensor(node): ...@@ -1905,11 +1906,13 @@ def local_subtensor_inc_subtensor(node):
return return
if not x.owner.op.set_instead_of_inc: if not x.owner.op.set_instead_of_inc:
return return
if x.owner.inputs[2] == node.inputs[1] and tuple(x.owner.op.idx_list) == tuple(node.op.idx_list):
if x.owner.inputs[2:] == node.inputs[1:] and tuple(x.owner.op.idx_list) == tuple(node.op.idx_list):
return [x.owner.inputs[1]] return [x.owner.inputs[1]]
else: else:
return return
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@gof.local_optimizer([Subtensor]) @gof.local_optimizer([Subtensor])
......
...@@ -6768,7 +6768,6 @@ if __name__ == '__main__': ...@@ -6768,7 +6768,6 @@ if __name__ == '__main__':
t.setUp() t.setUp()
# t.test_perform() # t.test_perform()
t.test_infer_shape() t.test_infer_shape()
""" """
# unittest.main() # unittest.main()
test_fusion().tes_memory_leak() test_fusion().tes_memory_leak()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论