提交 5e9fdfd3 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed desc/strdesc of DimShuffle

上级 1c9f715b
......@@ -64,7 +64,7 @@ class _test_dimshuffle_lift(unittest.TestCase):
x, y, z = inputs()
e = ds(ds(x, (1, 0)), (1, 0))
g = Env([x], [e])
assert str(g) == "[DimShuffle{10}(DimShuffle{10}(x))]"
assert str(g) == "[InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x))]"
lift_dimshuffle.optimize(g)
assert str(g) == "[x]"
......@@ -72,15 +72,15 @@ class _test_dimshuffle_lift(unittest.TestCase):
x, y, z = inputs()
e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{20x1}(DimShuffle{1x0}(x))]", str(g))
self.failUnless(str(g) == "[InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x))]", str(g))
lift_dimshuffle.optimize(g)
self.failUnless(str(g) == "[DimShuffle{01xx}(x)]", str(g))
self.failUnless(str(g) == "[InplaceDimShuffle{0,1,x,x}(x)]", str(g))
def test_elim3(self):
x, y, z = inputs()
e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{10}(DimShuffle{20x1}(DimShuffle{0x1}(x)))]", str(g))
self.failUnless(str(g) == "[InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{0,x,1}(x)))]", str(g))
lift_dimshuffle.optimize(g)
self.failUnless(str(g) == "[x]", str(g))
......@@ -88,9 +88,9 @@ class _test_dimshuffle_lift(unittest.TestCase):
x, y, z = inputs([0]*1, [0]*2, [0]*3)
e = x + y + z
g = Env([x, y, z], [e])
self.failUnless(str(g) == "[Broadcast{Add}(DimShuffle{x01}(Broadcast{Add}(DimShuffle{x0}(x), y)), z)]", str(g))
self.failUnless(str(g) == "[Broadcast{Add}(InplaceDimShuffle{x,0,1}(Broadcast{Add}(InplaceDimShuffle{x,0}(x), y)), z)]", str(g))
lift_dimshuffle.optimize(g)
self.failUnless(str(g) == "[Broadcast{Add}(Broadcast{Add}(DimShuffle{xx0}(x), DimShuffle{x01}(y)), z)]", str(g))
self.failUnless(str(g) == "[Broadcast{Add}(Broadcast{Add}(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g))
class _test_cliques(unittest.TestCase):
......
......@@ -105,10 +105,13 @@ class DimShuffle(Op, Viewer):
return {}
def desc(self):
return (self.__class__, tuple(self.new_order, self.inplace))
return (self.__class__, tuple(self.new_order), self.inplace)
def strdesc(self):
return "DimShuffle{%s}" % "".join(str(x) for x in (self.new_order, self.inplace))
if self.inplace:
return "InplaceDimShuffle{%s}" % ",".join(str(x) for x in self.new_order)
else:
return "DimShuffle{%s}" % ",".join(str(x) for x in self.new_order)
def perform(self):
# drop
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论