提交 86c0ad90 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Test for local_sum_div_dimshuffle optimization

上级 db18e58c
...@@ -1416,6 +1416,61 @@ def test_local_useless_neq(): ...@@ -1416,6 +1416,61 @@ def test_local_useless_neq():
assert len(topo2)==3 assert len(topo2)==3
assert isinstance(topo2[-1].op,T.Alloc) assert isinstance(topo2[-1].op,T.Alloc)
class T_local_sum_dimshuffle(unittest.TestCase):
def setUp(self):
self.mode = theano.compile.get_default_mode().including('canonicalize')
def test_local_sum_div_dimshuffle(self):
a = T.matrix()
b = T.vector()
c = T.tensor3()
sums = [
sum(a / b, axis=0),
sum(b / a, axis=0),
sum(a / b.dimshuffle(0,'x'), axis=1),
sum(b.dimshuffle(0,'x') / a, axis=1),
sum(c / a, axis=0),
sum(a / c, axis=0),
sum(c / a.dimshuffle(0,'x',1), axis=1),
sum(a.dimshuffle(0,'x',1) / c, axis=1),
sum(c / a.dimshuffle(0, 1, 'x'), axis=2),
sum(a.dimshuffle(0, 1, 'x') / c, axis=2),
sum(c / b, axis=0),
sum(b / c, axis=0),
sum(c / b, axis=1),
sum(b / c, axis=1),
sum(c / b, axis=(0,1)),
sum(b / c, axis=(0,1)),
sum(c / b.dimshuffle(0,'x'), axis=0),
sum(b.dimshuffle(0,'x') / c, axis=0),
sum(c / b.dimshuffle(0,'x'), axis=2),
sum(b.dimshuffle(0,'x') / c, axis=2),
sum(c / b.dimshuffle(0,'x'), axis=(0,2)),
sum(b.dimshuffle(0,'x') / c, axis=(0,2)),
sum(c / b.dimshuffle(0,'x','x'), axis=1),
sum(b.dimshuffle(0,'x','x') / c, axis=1),
sum(c / b.dimshuffle(0,'x','x'), axis=2),
sum(b.dimshuffle(0,'x','x') / c, axis=2),
sum(c / b.dimshuffle(0,'x','x'), axis=(1,2)),
sum(b.dimshuffle(0,'x','x') / c, axis=(1,2)),
sum(sum(c, axis=0) / b, axis=0),
sum(b / sum(c, axis=0), axis=0),
sum(sum(c, axis=1) / b, axis=0),
sum(b / sum(c, axis=1), axis=0),
]
for i,s in enumerate(sums):
print i
f = theano.function([a,b,c], s, mode=self.mode)
theano.printing.debugprint(f)
g = f.maker.env.toposort()
#print 'g =', g
assert g[-1].op == T.true_div
# TODO:
# test_local_sum_prod_dimshuffle (a * b * c)
# test_local_sum_divprod_dimshuffle ((a * b) / (c * d))
if __name__ == '__main__': if __name__ == '__main__':
# unittest.main() # unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论