提交 ae81700c authored 作者: Frederic Bastien's avatar Frederic Bastien

fix an error that make an optimization fail. Add test for that case.

上级 7f306a06
...@@ -1678,7 +1678,7 @@ def local_sum_div_dimshuffle(node): ...@@ -1678,7 +1678,7 @@ def local_sum_div_dimshuffle(node):
new_new_order = list(ax for i,ax in enumerate(new_order) if i not in axis or ax != 'x') new_new_order = list(ax for i,ax in enumerate(new_order) if i not in axis or ax != 'x')
#print 'new_new_order =', new_new_order #print 'new_new_order =', new_new_order
# Remove useless rebroadcast axes # Remove useless rebroadcast axes
while new_new_order[0] == 'x': while len(new_new_order)>0 and new_new_order[0] == 'x':
del new_new_order[0] del new_new_order[0]
#print 'new_new_order =', new_new_order #print 'new_new_order =', new_new_order
......
...@@ -1581,10 +1581,24 @@ class T_local_sum_dimshuffle(unittest.TestCase): ...@@ -1581,10 +1581,24 @@ class T_local_sum_dimshuffle(unittest.TestCase):
a = T.matrix('a') a = T.matrix('a')
b = T.vector('b') b = T.vector('b')
c = T.tensor3('c') c = T.tensor3('c')
d = T.scalar('d')
sums = [ sums = [
sum(a/d),
sum(a/d.dimshuffle('x','x')),
sum(a/d.dimshuffle('x','x'), axis=0),
sum(a/d.dimshuffle('x','x'), axis=1),
sum(b/d),
sum(b/d.dimshuffle('x')),
sum(c/d),
sum(c/d.dimshuffle('x','x','x')),
sum(c/d.dimshuffle('x','x','x'),axis=0),
sum(c/d.dimshuffle('x','x','x'),axis=1),
sum(c/d.dimshuffle('x','x','x'),axis=2),
sum(a / b, axis=0), sum(a / b, axis=0),
sum(a / b.dimshuffle(0,'x'), axis=1), sum(a / b.dimshuffle(0,'x'), axis=1),
sum(a.dimshuffle(0,1)/ b.dimshuffle(0,'x'), axis=1),
sum(a.dimshuffle(1,0)/ b.dimshuffle(0,'x'), axis=1),
sum(c / a, axis=0), sum(c / a, axis=0),
sum(c / a.dimshuffle(1, 0), axis=0), sum(c / a.dimshuffle(1, 0), axis=0),
sum(c / a.dimshuffle(0,'x',1), axis=1), sum(c / a.dimshuffle(0,'x',1), axis=1),
...@@ -1611,11 +1625,12 @@ class T_local_sum_dimshuffle(unittest.TestCase): ...@@ -1611,11 +1625,12 @@ class T_local_sum_dimshuffle(unittest.TestCase):
for i,s in enumerate(sums): for i,s in enumerate(sums):
print i print i
f = theano.function([a,b,c], s, mode=self.mode) f = theano.function([a,b,c,d], s, mode=self.mode)
theano.printing.debugprint(f) theano.printing.debugprint(f)
g = f.maker.env.toposort() g = f.maker.env.toposort()
#print 'g =', g #print 'g =', g
assert isinstance(g[-1].op.scalar_op, theano.scalar.basic.TrueDiv) assert isinstance(g[-1].op.scalar_op, theano.scalar.basic.TrueDiv)
f([[1,2],[3,4]],[5,6],[[[7,8],[9,10]],[[11,12],[13,14]]],15)
# TODO: # TODO:
# test_local_sum_prod_dimshuffle (a * b * c) # test_local_sum_prod_dimshuffle (a * b * c)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论