提交 c78d8db1 作者: abergeron

Merge pull request #2110 from nouiz/bug_reduction_join

[BUG] Important bug fix in an optimization.
...@@ -422,6 +422,19 @@ AddConfigVar('warn.signal_conv2d_interface', ...@@ -422,6 +422,19 @@ AddConfigVar('warn.signal_conv2d_interface',
BoolParam(warn_default('0.7')), BoolParam(warn_default('0.7')),
in_c_key=False) in_c_key=False)
# Register the boolean config flag `warn.reduce_join`.  It gates a warning
# about an old optimization bug: the rewrite of
# Reduce{scalar.op}(Join(axis=0, a, b), axis=0) did not check the reduction
# axis, so pre-0.7 versions could silently compute a wrong result.
# BoolParam(warn_default('0.7')) makes the default depend on the user's
# declared "warn.ignore_bug_before" baseline; in_c_key=False keeps this
# flag out of the C-code cache key since it does not affect generated code.
AddConfigVar('warn.reduce_join',
('Your current code is fine, but Theano versions '
'prior to 0.7 (or this development version) '
'might have given an incorrect result. '
'To disable this warning, set the Theano flag '
'warn.reduce_join to False. The problem was an '
'optimization that modify the pattern '
'"Reduce{scalar.op}(Join(axis=0, a, b), axis=0)", '
'did not checked the reduction axis. So if the '
'reduction axis is not 0, you got wrong answer.'),
BoolParam(warn_default('0.7')),
in_c_key=False)
AddConfigVar('compute_test_value', AddConfigVar('compute_test_value',
("If 'True', Theano will run each op at graph build time, using " ("If 'True', Theano will run each op at graph build time, using "
"Constants, SharedVariables and the tag 'test_value' as inputs " "Constants, SharedVariables and the tag 'test_value' as inputs "
......
...@@ -3483,10 +3483,14 @@ ALL_REDUCE = [T.elemwise.CAReduce, T.elemwise.All, T.elemwise.Any, ...@@ -3483,10 +3483,14 @@ ALL_REDUCE = [T.elemwise.CAReduce, T.elemwise.All, T.elemwise.Any,
@register_uncanonicalize # Needed for MaxAndArgmax -> CAReduce @register_uncanonicalize # Needed for MaxAndArgmax -> CAReduce
@gof.local_optimizer(ALL_REDUCE) @gof.local_optimizer(ALL_REDUCE)
def local_reduce_join(node): def local_reduce_join(node):
"""Reduce{scalar.op}(Join(a, b), axis=0) -> Elemwise{scalar.op}(a, b) """Reduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
:note: supported scalar.op are Maximum, Mimimum in some cases and :note: supported scalar.op are Maximum, Mimimum in some cases and
Add and Mul in all cases. Add and Mul in all cases.
:note: Currently we must reduce on axis 0. It is probably
extensible to the case where we join and reduce on the same
set of axis.
""" """
if (isinstance(node.op, T.CAReduce) and if (isinstance(node.op, T.CAReduce) and
...@@ -3498,7 +3502,7 @@ def local_reduce_join(node): ...@@ -3498,7 +3502,7 @@ def local_reduce_join(node):
return return
if isinstance(node.op.scalar_op, (scalar.Maximum, scalar.Minimum)): if isinstance(node.op.scalar_op, (scalar.Maximum, scalar.Minimum)):
#Support only 2 inputs for now # Support only 2 inputs for now
if len(join.inputs) != 3: if len(join.inputs) != 3:
return return
elif not isinstance(node.op.scalar_op, (scalar.Add, scalar.Mul)): elif not isinstance(node.op.scalar_op, (scalar.Add, scalar.Mul)):
...@@ -3517,9 +3521,36 @@ def local_reduce_join(node): ...@@ -3517,9 +3521,36 @@ def local_reduce_join(node):
return return
new_inp.append(inp.inputs[0]) new_inp.append(inp.inputs[0])
ret = Elemwise(node.op.scalar_op)(*new_inp) ret = Elemwise(node.op.scalar_op)(*new_inp)
if ret.dtype == node.outputs[0].dtype:
return [ret] if ret.dtype != node.outputs[0].dtype:
#else the reduction do something about the dtype. # The reduction do something about the dtype.
return
# I put this warning late to don't add extra warning.
if len(node.op.axis) != 1 or 0 not in node.op.axis:
if theano.config.warn.reduce_join:
_logger.warn((
'Your current code is fine, but Theano versions '
'prior to 0.7 (or this development version Sept 2014) '
'might have given an incorrect result for this code. '
'To disable this warning, set the Theano flag '
'warn.reduce_join to False. The problem was an '
'optimization that modify the pattern '
'"Reduce{scalar.op}(Join(axis=0, a, b), axis=0)", '
'did not checked the reduction axis. So if the '
'reduction axis is not 0, you got wrong answer.'
))
return
# We add the new check late to don't add extra warning.
try:
join_axis = get_scalar_constant_value(join.inputs[0])
if join_axis != node.op.axis[0]:
return
except NotScalarConstantError:
return
return [ret]
@register_canonicalize('fast_compile') @register_canonicalize('fast_compile')
......
...@@ -75,6 +75,11 @@ class RandomStateType(gof.Type): ...@@ -75,6 +75,11 @@ class RandomStateType(gof.Type):
else: else:
raise NotImplementedError() raise NotImplementedError()
return size return size
@staticmethod
def may_share_memory(a, b):
    """Return True iff *a* and *b* are the very same RandomState object.

    A RandomState wrapper never aliases another one's storage, so two
    values can share memory only through object identity.
    """
    return id(a) == id(b)
# Register RandomStateType's C code for ViewOp. # Register RandomStateType's C code for ViewOp.
theano.compile.register_view_op_c_code( theano.compile.register_view_op_c_code(
RandomStateType, RandomStateType,
......
...@@ -3858,7 +3858,7 @@ class T_local_reduce(unittest.TestCase): ...@@ -3858,7 +3858,7 @@ class T_local_reduce(unittest.TestCase):
x = numpy.asarray([[1, 0], [3, 4]], dtype=config.floatX) x = numpy.asarray([[1, 0], [3, 4]], dtype=config.floatX)
y = numpy.asarray([[4, 0], [2, 1]], dtype=config.floatX) y = numpy.asarray([[4, 0], [2, 1]], dtype=config.floatX)
z = numpy.asarray([[5, 0], [1, 2]], dtype=config.floatX) z = numpy.asarray([[5, 0], [1, 2]], dtype=config.floatX)
# Test different reduction scalar operation
for out, res in [ for out, res in [
(T.max((vx, vy), 0), numpy.max((x, y), 0)), (T.max((vx, vy), 0), numpy.max((x, y), 0)),
(T.min((vx, vy), 0), numpy.min((x, y), 0)), (T.min((vx, vy), 0), numpy.min((x, y), 0)),
...@@ -3873,6 +3873,41 @@ class T_local_reduce(unittest.TestCase): ...@@ -3873,6 +3873,41 @@ class T_local_reduce(unittest.TestCase):
assert len(topo) <= 2, out assert len(topo) <= 2, out
assert isinstance(topo[-1].op, T.Elemwise), out assert isinstance(topo[-1].op, T.Elemwise), out
# Test different axis for the join and the reduction
A = theano.shared(numpy.array([1, 2, 3, 4, 5]))
f = theano.function([], T.sum(T.stack(A, A), axis=0), mode=self.mode)
assert numpy.allclose(f(), [2, 4, 6, 8, 10])
topo = f.maker.fgraph.toposort()
assert isinstance(topo[-1].op, T.Elemwise)
# Test a case that was bugged in a old Theano bug
try:
old = theano.config.warn.reduce_join
theano.config.warn.reduce_join = False
f = theano.function([], T.sum(T.stack(A, A), axis=1),
mode=self.mode)
finally:
theano.config.warn.reduce_join = old
assert numpy.allclose(f(), [15, 15])
topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, T.Elemwise)
# This case could be optimized
A = theano.shared(numpy.array([1, 2, 3, 4, 5]).reshape(5, 1))
f = theano.function([], T.sum(T.concatenate((A, A), axis=1), axis=1),
mode=self.mode)
assert numpy.allclose(f(), [2, 4, 6, 8, 10])
topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, T.Elemwise)
A = theano.shared(numpy.array([1, 2, 3, 4, 5]).reshape(5, 1))
f = theano.function([], T.sum(T.concatenate((A, A), axis=1), axis=0),
mode=self.mode)
assert numpy.allclose(f(), [15, 15])
topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, T.Elemwise)
class T_local_sum_dimshuffle(unittest.TestCase): class T_local_sum_dimshuffle(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论