Commit d308c69a, authored by Frederic Bastien

Fix gh-6470. Fix the grad of maximum and minimum when both input have the same value.

Parent a5b4bb3a
@@ -1650,9 +1650,11 @@ class Maximum(BinaryScalarOp):
        if outputs[0].type in discrete_types:
            return [x.zeros_like().astype(theano.config.floatX),
                    y.zeros_like().astype(theano.config.floatX)]
        # This form handles the case when both values are the same.
        # In that case, gx will be gz and gy will be 0.
        e = eq(outputs[0], x)
        gx = e * gz
        gy = (constant(1, dtype=gz.dtype) - e) * gz
        return (gx, gy)
maximum = Maximum(upcast_out, name='maximum')
@@ -1686,8 +1688,11 @@ class Minimum(BinaryScalarOp):
        if outputs[0].type in discrete_types:
            return [x.zeros_like().astype(theano.config.floatX),
                    y.zeros_like().astype(theano.config.floatX)]
        # This form handles the case when both values are the same.
        # In that case, gx will be gz and gy will be 0.
        e = eq(outputs[0], x)
        gx = e * gz
        gy = (constant(1, dtype=gz.dtype) - e) * gz
        return (gx, gy)
minimum = Minimum(upcast_out, name='minimum')
......
@@ -812,6 +812,20 @@ MaximumInplaceTester = makeBroadcastTester(
    bad_runtime=_bad_runtime_broadcast_binary_normal,
    inplace=True)
def test_maximum_minimum_grad():
    """Test the gradient of maximum/minimum at the discontinuity point.

    When both inputs are equal, the gradient of the output w.r.t. the
    inputs is ambiguous.  The implementation deliberately routes the
    full gradient to the first input and gives zero to the second one;
    this test pins that contract at x == y.
    """
    x, y = tensor.vectors('xy')
    # Both ops share the same tie-breaking rule, so exercise them together.
    for binary_op in [tensor.maximum, tensor.minimum]:
        o = binary_op(x, y)
        g = theano.grad(o.sum(), [x, y])
        f = theano.function([x, y], g)
        # At the tie, d/dx == 1 and d/dy == 0.
        assert np.allclose(f([1], [1]), [[1], [0]])
MinimumTester = makeBroadcastTester(
    op=minimum,
    expected=lambda *inputs: check_floatX(inputs, np.minimum(*inputs)),
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment