Commit 8c02d5a6 authored by Pascal Lamblin

Re-enable sum optimizations when dtype!=None.

Also add a test for one specific case. The usual cases should be covered by the regular tests, since the sum nodes' dtype is no longer allowed to be None.
Parent 53a2b374
...@@ -2862,8 +2862,7 @@ def local_sum_mul_by_scalar(node): ...@@ -2862,8 +2862,7 @@ def local_sum_mul_by_scalar(node):
""" """
# TODO: if the the thing inside the Sum is a division, # TODO: if the the thing inside the Sum is a division,
# we should get at the numerator.... # we should get at the numerator....
# TODO Implement for sum.dtype != None. if isinstance(node.op, T.Sum):
if isinstance(node.op, T.Sum) and node.op.dtype is None:
thing_summed, = node.inputs thing_summed, = node.inputs
if thing_summed.owner and thing_summed.owner.op == T.mul: if thing_summed.owner and thing_summed.owner.op == T.mul:
terms = thing_summed.owner.inputs terms = thing_summed.owner.inputs
...@@ -2917,9 +2916,7 @@ def local_sum_div_dimshuffle(node): ...@@ -2917,9 +2916,7 @@ def local_sum_div_dimshuffle(node):
# dimshuffle is in the numerator, since elemwise inversion of the # dimshuffle is in the numerator, since elemwise inversion of the
# denominator would still be needed before the summation. # denominator would still be needed before the summation.
# TODO Implement for sum.dtype != None. if isinstance(node.op, T.Sum):
if isinstance(node.op, T.Sum) and node.op.dtype is None:
axis = node.op.axis axis = node.op.axis
if axis is None: if axis is None:
axis = range(node.inputs[0].ndim) axis = range(node.inputs[0].ndim)
...@@ -3016,24 +3013,21 @@ def local_sum_all_to_none(node): ...@@ -3016,24 +3013,21 @@ def local_sum_all_to_none(node):
def local_sum_sum(node): def local_sum_sum(node):
""" """
Sum(Sum()) -> Sum Sum(Sum()) -> Sum
Note that currently we only replace sums with default dtypes, to avoid
potential dtype conflict issues.
""" """
if isinstance(node.op, T.Sum) and node.op.dtype is None: if isinstance(node.op, T.Sum):
summed, = node.inputs summed, = node.inputs
out_dtype = node.op.dtype
if len(summed.clients) == 1: if len(summed.clients) == 1:
if (summed.owner and if (summed.owner and
isinstance(summed.owner.op, T.Sum) isinstance(summed.owner.op, T.Sum)):
and summed.owner.op.dtype is None):
if summed.owner.op.axis is None: if summed.owner.op.axis is None:
# special case of local_cut_useless_reduce # special case of local_cut_useless_reduce
return [T.Sum(None)(summed.owner.inputs[0])] return [T.Sum(None, dtype=out_dtype)(summed.owner.inputs[0])]
if node.op.axis is None: if node.op.axis is None:
# we're summing up everything anyway so lets # we're summing up everything anyway so lets
# do it all at once # do it all at once
return [T.Sum(None)(summed.owner.inputs[0])] return [T.Sum(None, dtype=out_dtype)(summed.owner.inputs[0])]
newaxis = list(tuple(summed.owner.op.axis)) newaxis = list(tuple(summed.owner.op.axis))
# figure out which dimensions of the original input # figure out which dimensions of the original input
...@@ -3076,7 +3070,7 @@ def local_sum_sum(node): ...@@ -3076,7 +3070,7 @@ def local_sum_sum(node):
"been fixed) set the theano flag " "been fixed) set the theano flag "
"`warn.sum_sum_bug` to False.") "`warn.sum_sum_bug` to False.")
combined_sum = T.Sum(newaxis) combined_sum = T.Sum(newaxis, dtype=out_dtype)
return [combined_sum(summed.owner.inputs[0])] return [combined_sum(summed.owner.inputs[0])]
...@@ -3095,8 +3089,7 @@ def local_cut_useless_reduce(node): ...@@ -3095,8 +3089,7 @@ def local_cut_useless_reduce(node):
@gof.local_optimizer([]) @gof.local_optimizer([])
def local_sum_alloc(node): def local_sum_alloc(node):
""" sum(alloc(constant,shapes...)) => constant*prod(shapes)""" """ sum(alloc(constant,shapes...)) => constant*prod(shapes)"""
# TODO Implement for sum.dtype != None if isinstance(node.op, T.Sum):
if isinstance(node.op, T.Sum) and node.op.dtype is None:
summed, = node.inputs summed, = node.inputs
if summed.owner and isinstance(summed.owner.op, T.Alloc): if summed.owner and isinstance(summed.owner.op, T.Alloc):
input = summed.owner.inputs[0] input = summed.owner.inputs[0]
......
...@@ -3150,6 +3150,20 @@ class T_local_sum(unittest.TestCase): ...@@ -3150,6 +3150,20 @@ class T_local_sum(unittest.TestCase):
finally: finally:
config.on_opt_error = backup config.on_opt_error = backup
def test_local_sum_sum_dtype(self):
    """
    Test that local_sum_sum works when specifying dtypes manually.

    Builds a nested sum whose inner and outer reductions request
    explicit, different accumulator dtypes, then checks that graph
    optimization (the local_sum_sum rewrite) compiles it without
    raising.
    """
    x = tensor.tensor3(dtype='int8')
    y = x.sum(axis=0, dtype='int32').sum(axis=1, dtype='int64')
    # Make optimization errors fatal so a failure inside the rewrite
    # surfaces as an exception rather than a warning.
    backup = config.on_opt_error
    config.on_opt_error = 'raise'
    try:
        # This compilation would fail prior to fix.
        # The compiled function itself is not needed, only the fact
        # that compilation (and thus optimization) succeeds.
        theano.function([x], y)
    finally:
        # Restore the flag even if compilation raised.
        config.on_opt_error = backup
class T_local_sum_dimshuffle(unittest.TestCase): class T_local_sum_dimshuffle(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论