提交 7f306a06 作者: Pascal Lamblin

Fix optimization and test

上级 9927d38b
......@@ -1644,6 +1644,10 @@ def local_sum_div_dimshuffle(node):
if dimension l of the DimShuffle is 'x'.'''
# TODO: extend it to product, and quotient of products
# It does not make much sense now to extend it to the case where the
# dimshuffle is in the numerator, since elemwise inversion of the
# denominator would still be needed before the summation.
if isinstance(node.op, T.Sum):
axis = node.op.axis
if axis is None:
......@@ -1653,25 +1657,9 @@ def local_sum_div_dimshuffle(node):
dimshuffled = None
if thing_summed.owner and thing_summed.owner.op == T.true_div:
numerator, denominator = thing_summed.owner.inputs
#This if have bad logic. See its test in tensor/tests/test_opt.py:T_local_sum_dimshuffle
#that fail when we enable this if.
if False and numerator.owner and isinstance(numerator.owner.op, T.DimShuffle):
new_order = numerator.owner.op.new_order
#print 'new_order =', new_order
# check compatibility
compatible_dims = True
for ax in axis:
if len(new_order) <= ax or new_order[ax] != 'x':
compatible_dims = False
break
if compatible_dims:
#print 'getting num out'
new_num = numerator.owner.inputs[0]
return [T.true_div(new_num, node.op(denominator))]
#else:
# print 'incompatible dims:', axis, new_order
if denominator.owner and isinstance(denominator.owner.op, T.DimShuffle):
thing_dimshuffled = denominator.owner.inputs[0]
new_order = denominator.owner.op.new_order
#print 'new_order =', new_order
# check compatibility
......@@ -1683,9 +1671,24 @@ def local_sum_div_dimshuffle(node):
if len(new_order) <= ax or new_order[ax] != 'x':
compatible_dims = False
break
if compatible_dims:
#print 'getting denom out'
new_denom = denominator.owner.inputs[0]
# Keep needed dimensions for new dimshuffle
new_new_order = list(ax for i,ax in enumerate(new_order) if i not in axis or ax != 'x')
#print 'new_new_order =', new_new_order
# Remove useless rebroadcast axes
while new_new_order[0] == 'x':
del new_new_order[0]
#print 'new_new_order =', new_new_order
if all(i == e for i, e in enumerate(new_new_order)):
new_denom = thing_dimshuffled
else:
new_denom = T.DimShuffle(
thing_dimshuffled.type.broadcastable,
new_new_order
)(thing_dimshuffled)
return [T.true_div(node.op(numerator), new_denom)]
#else:
# print 'incompatible dims:', axis, new_order
......
......@@ -1578,56 +1578,45 @@ class T_local_sum_dimshuffle(unittest.TestCase):
self.mode = theano.compile.get_default_mode().including('canonicalize')
def test_local_sum_div_dimshuffle(self):
a = T.matrix()
b = T.vector()
c = T.tensor3()
a = T.matrix('a')
b = T.vector('b')
c = T.tensor3('c')
sums = [
sum(a / b, axis=0),
sum(b / a, axis=0),
sum(a / b.dimshuffle(0,'x'), axis=1),
sum(b.dimshuffle(0,'x') / a, axis=1),
sum(c / a, axis=0),
sum(a / c, axis=0),
sum(c / a.dimshuffle(1, 0), axis=0),
sum(c / a.dimshuffle(0,'x',1), axis=1),
sum(a.dimshuffle(0,'x',1) / c, axis=1),
sum(c / a.dimshuffle(1,'x',0), axis=1),
sum(c / a.dimshuffle(0, 1, 'x'), axis=2),
sum(a.dimshuffle(0, 1, 'x') / c, axis=2),
sum(c / a.dimshuffle(1, 0, 'x'), axis=2),
sum(c / b, axis=0),
sum(b / c, axis=0),
sum(c / b, axis=1),
sum(b / c, axis=1),
sum(c / b, axis=(0,1)),
sum(b / c, axis=(0,1)),
sum(c / b.dimshuffle(0,'x'), axis=0),
sum(b.dimshuffle(0,'x') / c, axis=0),
sum(c / b.dimshuffle(0,'x'), axis=2),
sum(b.dimshuffle(0,'x') / c, axis=2),
sum(c / b.dimshuffle(0,'x'), axis=(0,2)),
sum(b.dimshuffle(0,'x') / c, axis=(0,2)),
sum(c / b.dimshuffle(0,'x','x'), axis=1),
sum(b.dimshuffle(0,'x','x') / c, axis=1),
sum(c / b.dimshuffle(0,'x','x'), axis=2),
sum(b.dimshuffle(0,'x','x') / c, axis=2),
sum(c / b.dimshuffle(0,'x','x'), axis=(1,2)),
sum(b.dimshuffle(0,'x','x') / c, axis=(1,2)),
sum(sum(c, axis=0) / b, axis=0),
sum(b / sum(c, axis=0), axis=0),
sum(sum(c, axis=1) / b, axis=0),
sum(b / sum(c, axis=1), axis=0),
]
rng = numpy.random.RandomState(utt.fetch_seed())
a_val = rng.randn(2,2)
b_val = rng.randn(2)
c_val = rng.randn(2,2,2)
for i,s in enumerate(sums):
print i
f = theano.function([a,b,c], s, mode=self.mode)
theano.printing.debugprint(f)
g = f.maker.env.toposort()
#print 'g =', g
f([[1,1],[1,1]],[1,1],[[[1,1],[1,1]],[[1,1],[1,1]]])
num, denum = s.owner.inputs[0].owner.inputs
if denum.owner and isinstance(denum.owner.op, T.DimShuffle):
assert g[-1].op == T.true_div
assert isinstance(g[-1].op.scalar_op, theano.scalar.basic.TrueDiv)
# TODO:
# test_local_sum_prod_dimshuffle (a * b * c)
# test_local_sum_divprod_dimshuffle ((a * b) / (c * d))
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论