提交 82ddda11 authored 作者: Ziye Fan's avatar Ziye Fan

Use local_dimshuffle_lift.transform() instead of global toposort optimizer's optimize()

上级 f0522bff
...@@ -166,14 +166,15 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -166,14 +166,15 @@ class test_dimshuffle_lift(unittest.TestCase):
"(<TensorType(float64, matrix)>, " "(<TensorType(float64, matrix)>, "
"DimShuffle{x,x}(TensorConstant{84}))))]") "DimShuffle{x,x}(TensorConstant{84}))))]")
self.assertTrue(str(g) == init_str_g) self.assertTrue(str(g) == init_str_g)
dimshuffle_lift.optimize(g) new_out = local_dimshuffle_lift.transform(g.outputs[0].owner)[0]
new_g = FunctionGraph(g.inputs, [new_out])
opt_str_g = ("[Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}" opt_str_g = ("[Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
"(DimShuffle{0,x}(<TensorType(float64, vector)>), " "(DimShuffle{0,x}(<TensorType(float64, vector)>), "
"DimShuffle{x,x}(TensorConstant{42})), " "DimShuffle{x,x}(TensorConstant{42})), "
"Elemwise{add,no_inplace}(DimShuffle{1,0}" "Elemwise{add,no_inplace}(DimShuffle{1,0}"
"(<TensorType(float64, matrix)>), " "(<TensorType(float64, matrix)>), "
"DimShuffle{x,x}(TensorConstant{84})))]") "DimShuffle{x,x}(TensorConstant{84})))]")
self.assertTrue(str(g) == opt_str_g) self.assertTrue(str(new_g) == opt_str_g)
def test_add_canonizer_problem0(): def test_add_canonizer_problem0():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论