提交 d3fc88fe authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed tensor.tests.test_opt

上级 950634de
...@@ -14,13 +14,55 @@ import numpy ...@@ -14,13 +14,55 @@ import numpy
#import scalar_opt #import scalar_opt
def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)): def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
x = Tensor(broadcastable = xbc, dtype = 'float64')('x') x = Tensor(broadcastable = xbc, dtype = 'float64')('x')
y = Tensor(broadcastable = ybc, dtype = 'float64')('y') y = Tensor(broadcastable = ybc, dtype = 'float64')('y')
z = Tensor(broadcastable = zbc, dtype = 'float64')('z') z = Tensor(broadcastable = zbc, dtype = 'float64')('z')
return x, y, z return x, y, z
ds = lambda x, y: DimShuffle(x.type.broadcastable, y)(x)
dimshuffle_lift = out2in(local_dimshuffle_lift)
class test_dimshuffle_lift(unittest.TestCase):
def test_double_transpose(self):
x, y, z = inputs()
e = ds(ds(x, (1, 0)), (1, 0))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{1,0}(x))]")
dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[x]")
def test_merge2(self):
x, y, z = inputs()
e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]", str(g))
dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g))
def test_elim3(self):
x, y, z = inputs()
e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}(DimShuffle{0,x,1}(x)))]", str(g))
dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[x]", str(g))
def test_lift(self):
x, y, z = inputs([False]*1, [False]*2, [False]*3)
e = x + y + z
g = Env([x, y, z], [e])
self.failUnless(str(g) == "[add(InplaceDimShuffle{x,0,1}(add(InplaceDimShuffle{x,0}(x), y)), z)]", str(g))
dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g))
# class _test_inplace_opt(unittest.TestCase): # class _test_inplace_opt(unittest.TestCase):
# def test_straightforward(self): # def test_straightforward(self):
...@@ -60,42 +102,6 @@ def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)): ...@@ -60,42 +102,6 @@ def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
# self.failUnless(str(g) == "[Broadcast{Add}{0: 1}(x, y), Broadcast{Mul}{0: 0}(x, z)]") # self.failUnless(str(g) == "[Broadcast{Add}{0: 1}(x, y), Broadcast{Mul}{0: 0}(x, z)]")
ds = lambda x, y: DimShuffle(x.type.broadcastable, y)(x)
class test_dimshuffle_lift(unittest.TestCase):
def test_double_transpose(self):
x, y, z = inputs()
e = ds(ds(x, (1, 0)), (1, 0))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{1,0}(x))]")
dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[x]")
def test_merge2(self):
x, y, z = inputs()
e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]", str(g))
dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g))
def test_elim3(self):
x, y, z = inputs()
e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}(DimShuffle{0,x,1}(x)))]", str(g))
dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[x]", str(g))
def test_lift(self):
x, y, z = inputs([False]*1, [False]*2, [False]*3)
e = x + y + z
g = Env([x, y, z], [e])
self.failUnless(str(g) == "[add(InplaceDimShuffle{x,0,1}(add(InplaceDimShuffle{x,0}(x), y)), z)]", str(g))
dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论