提交 0bd8eee5 authored 作者: Frederic's avatar Frederic

Don't generate too much warning (and also fix a failure in travis tests)

上级 ff45b980
...@@ -3487,21 +3487,6 @@ def local_reduce_join(node): ...@@ -3487,21 +3487,6 @@ def local_reduce_join(node):
if T.extract_constant(join.inputs[0]) != 0: if T.extract_constant(join.inputs[0]) != 0:
return return
if len(node.op.axis) != 1 or 0 not in node.op.axis:
if theano.config.warn.reduce_join:
_logger.warn((
'Your current code is fine, but Theano versions '
'prior to 0.7 (or this development version Sept 2014) '
'might have given an incorrect result. '
'To disable this warning, set the Theano flag '
'warn.reduce_join to False.'))
return
try:
join_axis = get_scalar_constant_value(join.inputs[0])
if join_axis != node.op.axis[0]:
return
except NotScalarConstantError:
return
if isinstance(node.op.scalar_op, (scalar.Maximum, scalar.Minimum)): if isinstance(node.op.scalar_op, (scalar.Maximum, scalar.Minimum)):
#Support only 2 inputs for now #Support only 2 inputs for now
if len(join.inputs) != 3: if len(join.inputs) != 3:
...@@ -3522,9 +3507,31 @@ def local_reduce_join(node): ...@@ -3522,9 +3507,31 @@ def local_reduce_join(node):
return return
new_inp.append(inp.inputs[0]) new_inp.append(inp.inputs[0])
ret = Elemwise(node.op.scalar_op)(*new_inp) ret = Elemwise(node.op.scalar_op)(*new_inp)
if ret.dtype == node.outputs[0].dtype:
if ret.dtype != node.outputs[0].dtype:
# The reduction do something about the dtype.
return
# I put this warning late to don't add extra warning.
if len(node.op.axis) != 1 or 0 not in node.op.axis:
if theano.config.warn.reduce_join:
_logger.warn((
'Your current code is fine, but Theano versions '
'prior to 0.7 (or this development version Sept 2014) '
'might have given an incorrect result. '
'To disable this warning, set the Theano flag '
'warn.reduce_join to False.'))
return
# We add the new check late to don't add extra warning.
try:
join_axis = get_scalar_constant_value(join.inputs[0])
if join_axis != node.op.axis[0]:
return
except NotScalarConstantError:
return
return [ret] return [ret]
#else the reduction do something about the dtype.
@register_canonicalize @register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论