提交 82ddda11 authored 作者: Ziye Fan's avatar Ziye Fan

Use local_dimshuffle_lift.transform() instead of global toposort optimizer's optimize()

上级 f0522bff
......@@ -166,14 +166,15 @@ class test_dimshuffle_lift(unittest.TestCase):
"(<TensorType(float64, matrix)>, "
"DimShuffle{x,x}(TensorConstant{84}))))]")
self.assertTrue(str(g) == init_str_g)
dimshuffle_lift.optimize(g)
new_out = local_dimshuffle_lift.transform(g.outputs[0].owner)[0]
new_g = FunctionGraph(g.inputs, [new_out])
opt_str_g = ("[Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
"(DimShuffle{0,x}(<TensorType(float64, vector)>), "
"DimShuffle{x,x}(TensorConstant{42})), "
"Elemwise{add,no_inplace}(DimShuffle{1,0}"
"(<TensorType(float64, matrix)>), "
"DimShuffle{x,x}(TensorConstant{84})))]")
self.assertTrue(str(g) == opt_str_g)
self.assertTrue(str(new_g) == opt_str_g)
def test_add_canonizer_problem0():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论