提交 ae81700c authored 作者: Frederic Bastien's avatar Frederic Bastien

fix an error that make an optimization fail. Add test for that case.

上级 7f306a06
......@@ -1678,7 +1678,7 @@ def local_sum_div_dimshuffle(node):
new_new_order = list(ax for i,ax in enumerate(new_order) if i not in axis or ax != 'x')
#print 'new_new_order =', new_new_order
# Remove useless rebroadcast axes
while new_new_order[0] == 'x':
while len(new_new_order)>0 and new_new_order[0] == 'x':
del new_new_order[0]
#print 'new_new_order =', new_new_order
......
......@@ -1581,10 +1581,24 @@ class T_local_sum_dimshuffle(unittest.TestCase):
a = T.matrix('a')
b = T.vector('b')
c = T.tensor3('c')
d = T.scalar('d')
sums = [
sum(a/d),
sum(a/d.dimshuffle('x','x')),
sum(a/d.dimshuffle('x','x'), axis=0),
sum(a/d.dimshuffle('x','x'), axis=1),
sum(b/d),
sum(b/d.dimshuffle('x')),
sum(c/d),
sum(c/d.dimshuffle('x','x','x')),
sum(c/d.dimshuffle('x','x','x'),axis=0),
sum(c/d.dimshuffle('x','x','x'),axis=1),
sum(c/d.dimshuffle('x','x','x'),axis=2),
sum(a / b, axis=0),
sum(a / b.dimshuffle(0,'x'), axis=1),
sum(a.dimshuffle(0,1)/ b.dimshuffle(0,'x'), axis=1),
sum(a.dimshuffle(1,0)/ b.dimshuffle(0,'x'), axis=1),
sum(c / a, axis=0),
sum(c / a.dimshuffle(1, 0), axis=0),
sum(c / a.dimshuffle(0,'x',1), axis=1),
......@@ -1611,11 +1625,12 @@ class T_local_sum_dimshuffle(unittest.TestCase):
for i,s in enumerate(sums):
print i
f = theano.function([a,b,c], s, mode=self.mode)
f = theano.function([a,b,c,d], s, mode=self.mode)
theano.printing.debugprint(f)
g = f.maker.env.toposort()
#print 'g =', g
assert isinstance(g[-1].op.scalar_op, theano.scalar.basic.TrueDiv)
f([[1,2],[3,4]],[5,6],[[[7,8],[9,10]],[[11,12],[13,14]]],15)
# TODO:
# test_local_sum_prod_dimshuffle (a * b * c)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论