提交 9120fd0d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix error in opt when reduction axis is None

上级 c1209105
...@@ -3901,8 +3901,12 @@ def local_reduce_join(node): ...@@ -3901,8 +3901,12 @@ def local_reduce_join(node):
# The reduction do something about the dtype. # The reduction do something about the dtype.
return return
reduce_axis = node.op.axis
if reduce_axis is None:
reduce_axis = tuple(xrange(node.inputs[0].ndim))
# I put this warning late to don't add extra warning. # I put this warning late to don't add extra warning.
if len(node.op.axis) != 1 or 0 not in node.op.axis: if len(reduce_axis) != 1 or 0 not in reduce_axis:
if theano.config.warn.reduce_join: if theano.config.warn.reduce_join:
_logger.warn(( _logger.warn((
'Your current code is fine, but Theano versions ' 'Your current code is fine, but Theano versions '
...@@ -3920,7 +3924,7 @@ def local_reduce_join(node): ...@@ -3920,7 +3924,7 @@ def local_reduce_join(node):
# We add the new check late to don't add extra warning. # We add the new check late to don't add extra warning.
try: try:
join_axis = get_scalar_constant_value(join.inputs[0]) join_axis = get_scalar_constant_value(join.inputs[0])
if join_axis != node.op.axis[0]: if join_axis != reduce_axis[0]:
return return
except NotScalarConstantError: except NotScalarConstantError:
return return
......
...@@ -4482,6 +4482,17 @@ class T_local_reduce(unittest.TestCase): ...@@ -4482,6 +4482,17 @@ class T_local_reduce(unittest.TestCase):
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, T.Elemwise) assert not isinstance(topo[-1].op, T.Elemwise)
# Test that the optimization does not crash in one case where it
# is not applied. Reported at
# https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
old = theano.config.warn.reduce_join
try:
theano.config.warn.reduce_join = False
out = tensor.sum([vx, vy, vz], axis=None)
f = theano.function([vx, vy, vz], out)
finally:
theano.config.warn.reduce_join = old
class T_local_sum_dimshuffle(unittest.TestCase): class T_local_sum_dimshuffle(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论