提交 ef8fc7bb authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Accept not-inplace dimshuffles in test

上级 e33c0f75
...@@ -124,13 +124,27 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -124,13 +124,27 @@ class test_dimshuffle_lift(unittest.TestCase):
x, y, z = inputs([False] * 1, [False] * 2, [False] * 3) x, y, z = inputs([False] * 1, [False] * 2, [False] * 3)
e = x + y + z e = x + y + z
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
self.assertTrue(str(g) == ("[Elemwise{add,no_inplace}("
"InplaceDimShuffle{x,0,1}(Elemwise{add,no_inplace}" # It does not really matter if the DimShuffles are inplace
"(InplaceDimShuffle{x,0}(x), y)), z)]"), str(g)) # or not.
init_str_g_inplace = (
"[Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}"
"(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z)]")
init_str_g_noinplace = (
"[Elemwise{add,no_inplace}(DimShuffle{x,0,1}"
"(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z)]")
self.assertTrue(str(g) in (init_str_g_inplace, init_str_g_noinplace),
str(g))
opt_str_g_inplace = (
"[Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]")
opt_str_g_noinplace = (
"[Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z)]")
dimshuffle_lift.optimize(g) dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == ("[Elemwise{add,no_inplace}(Elemwise" self.assertTrue(str(g) in (opt_str_g_inplace, opt_str_g_noinplace),
"{add,no_inplace}(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle" str(g))
"{x,0,1}(y)), z)]"), str(g))
def test_add_canonizer_problem0(): def test_add_canonizer_problem0():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论