提交 ca45a0bc authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2597 from lamblin/fix_bart

Fix in local_reduce_join with axis=None
...@@ -13,6 +13,7 @@ import operator ...@@ -13,6 +13,7 @@ import operator
import sys import sys
import time import time
import traceback import traceback
import warnings
import numpy import numpy
import numpy as N # guys... please don't do this in the library :( import numpy as N # guys... please don't do this in the library :(
...@@ -3901,10 +3902,14 @@ def local_reduce_join(node): ...@@ -3901,10 +3902,14 @@ def local_reduce_join(node):
# The reduction do something about the dtype. # The reduction do something about the dtype.
return return
reduce_axis = node.op.axis
if reduce_axis is None:
reduce_axis = tuple(xrange(node.inputs[0].ndim))
# I put this warning late to don't add extra warning. # I put this warning late to don't add extra warning.
if len(node.op.axis) != 1 or 0 not in node.op.axis: if len(reduce_axis) != 1 or 0 not in reduce_axis:
if theano.config.warn.reduce_join: if theano.config.warn.reduce_join:
_logger.warn(( warnings.warn((
'Your current code is fine, but Theano versions ' 'Your current code is fine, but Theano versions '
'prior to 0.7 (or this development version Sept 2014) ' 'prior to 0.7 (or this development version Sept 2014) '
'might have given an incorrect result for this code. ' 'might have given an incorrect result for this code. '
...@@ -3920,7 +3925,7 @@ def local_reduce_join(node): ...@@ -3920,7 +3925,7 @@ def local_reduce_join(node):
# We add the new check late to don't add extra warning. # We add the new check late to don't add extra warning.
try: try:
join_axis = get_scalar_constant_value(join.inputs[0]) join_axis = get_scalar_constant_value(join.inputs[0])
if join_axis != node.op.axis[0]: if join_axis != reduce_axis[0]:
return return
except NotScalarConstantError: except NotScalarConstantError:
return return
......
...@@ -4447,7 +4447,8 @@ class T_local_reduce(unittest.TestCase): ...@@ -4447,7 +4447,8 @@ class T_local_reduce(unittest.TestCase):
assert isinstance(topo[-1].op, T.Elemwise), out assert isinstance(topo[-1].op, T.Elemwise), out
# Test different axis for the join and the reduction # Test different axis for the join and the reduction
# We must force the dtype, of otherwise, this tests will fail in 32 bit system # We must force the dtype, of otherwise, this tests will fail
# on 32 bit systems
A = theano.shared(numpy.array([1, 2, 3, 4, 5], dtype='int64')) A = theano.shared(numpy.array([1, 2, 3, 4, 5], dtype='int64'))
f = theano.function([], T.sum(T.stack(A, A), axis=0), mode=self.mode) f = theano.function([], T.sum(T.stack(A, A), axis=0), mode=self.mode)
...@@ -4482,6 +4483,17 @@ class T_local_reduce(unittest.TestCase): ...@@ -4482,6 +4483,17 @@ class T_local_reduce(unittest.TestCase):
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, T.Elemwise) assert not isinstance(topo[-1].op, T.Elemwise)
# Test that the optimization does not crash in one case where it
# is not applied. Reported at
# https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
old = theano.config.warn.reduce_join
try:
theano.config.warn.reduce_join = False
out = tensor.sum([vx, vy, vz], axis=None)
f = theano.function([vx, vy, vz], out)
finally:
theano.config.warn.reduce_join = old
class T_local_sum_dimshuffle(unittest.TestCase): class T_local_sum_dimshuffle(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论