提交 f0522bff authored 作者: Ziye Fan's avatar Ziye Fan

add test for new dimshuffle_lift

上级 2ed8808a
...@@ -153,6 +153,28 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -153,6 +153,28 @@ class test_dimshuffle_lift(unittest.TestCase):
self.assertTrue(str(g) in (opt_str_g_inplace, opt_str_g_noinplace), self.assertTrue(str(g) in (opt_str_g_inplace, opt_str_g_noinplace),
str(g)) str(g))
def test_recursive_lift(self):
v = T.vector(dtype="float64")
m = T.matrix(dtype="float64")
out = ((v + 42) * (m + 84)).T
g = FunctionGraph([v, m], [out])
init_str_g = ("[DimShuffle{1,0}(Elemwise{mul,no_inplace}"
"(DimShuffle{x,0}(Elemwise{add,no_inplace}"
"(<TensorType(float64, vector)>, "
"DimShuffle{x}(TensorConstant{42}))), "
"Elemwise{add,no_inplace}"
"(<TensorType(float64, matrix)>, "
"DimShuffle{x,x}(TensorConstant{84}))))]")
self.assertTrue(str(g) == init_str_g)
dimshuffle_lift.optimize(g)
opt_str_g = ("[Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
"(DimShuffle{0,x}(<TensorType(float64, vector)>), "
"DimShuffle{x,x}(TensorConstant{42})), "
"Elemwise{add,no_inplace}(DimShuffle{1,0}"
"(<TensorType(float64, matrix)>), "
"DimShuffle{x,x}(TensorConstant{84})))]")
self.assertTrue(str(g) == opt_str_g)
def test_add_canonizer_problem0(): def test_add_canonizer_problem0():
n_segments = 10 n_segments = 10
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论