提交 0dbe46da authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6477 from nouiz/minimum_maximum_grad

Minimum maximum gradient fix
...@@ -1650,9 +1650,11 @@ class Maximum(BinaryScalarOp): ...@@ -1650,9 +1650,11 @@ class Maximum(BinaryScalarOp):
if outputs[0].type in discrete_types: if outputs[0].type in discrete_types:
return [x.zeros_like().astype(theano.config.floatX), return [x.zeros_like().astype(theano.config.floatX),
y.zeros_like().astype(theano.config.floatX)] y.zeros_like().astype(theano.config.floatX)]
# This form handles the case when both values are the same.
gx = eq(outputs[0], x) * gz # In that case, gx will be gz, gy will be 0.
gy = eq(outputs[0], y) * gz e = eq(outputs[0], x)
gx = e * gz
gy = (constant(1, dtype=gz.dtype) - e) * gz
return (gx, gy) return (gx, gy)
maximum = Maximum(upcast_out, name='maximum') maximum = Maximum(upcast_out, name='maximum')
...@@ -1686,8 +1688,11 @@ class Minimum(BinaryScalarOp): ...@@ -1686,8 +1688,11 @@ class Minimum(BinaryScalarOp):
if outputs[0].type in discrete_types: if outputs[0].type in discrete_types:
return [x.zeros_like().astype(theano.config.floatX), return [x.zeros_like().astype(theano.config.floatX),
y.zeros_like().astype(theano.config.floatX)] y.zeros_like().astype(theano.config.floatX)]
gx = eq(outputs[0], x) * gz # This form handles the case when both values are the same.
gy = eq(outputs[0], y) * gz # In that case, gx will be gz, gy will be 0.
e = eq(outputs[0], x)
gx = e * gz
gy = (constant(1, dtype=gz.dtype) - e) * gz
return (gx, gy) return (gx, gy)
minimum = Minimum(upcast_out, name='minimum') minimum = Minimum(upcast_out, name='minimum')
......
...@@ -4998,10 +4998,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -4998,10 +4998,7 @@ class Canonizer(gof.LocalOptimizer):
new = self.merge_num_denum(num, denum) new = self.merge_num_denum(num, denum)
if new.type.dtype != out.type.dtype: if new.type.dtype != out.type.dtype:
# new = T.fill(out, new) new = T.cast(new, out.type.dtype)
elem_op = T.Elemwise(scalar.Identity(scalar.specific_out(
getattr(scalar, out.type.dtype))))
new = elem_op(new)
assert (new.type == out.type) == (not (new.type != out.type)) assert (new.type == out.type) == (not (new.type != out.type))
......
...@@ -812,6 +812,19 @@ MaximumInplaceTester = makeBroadcastTester( ...@@ -812,6 +812,19 @@ MaximumInplaceTester = makeBroadcastTester(
bad_runtime=_bad_runtime_broadcast_binary_normal, bad_runtime=_bad_runtime_broadcast_binary_normal,
inplace=True) inplace=True)
def test_maximum_minimum_grad():
    """Check the gradient of maximum/minimum at the discontinuity point.

    When both inputs are equal, the convention is to route the whole
    gradient to the first input (gx == gz) and give the second input a
    zero gradient (gy == 0).
    """
    # Test the discontinuity point.
    # We decided that we only pass the gradient to the first input in that case.
    x, y = tensor.vectors('xy')
    for op in [tensor.maximum, tensor.minimum]:
        o = op(x, y)
        g = theano.grad(o.sum(), [x, y])
        f = theano.function([x, y], g)
        # At x == y the gradient w.r.t. x must be 1 and w.r.t. y must be 0.
        assert np.allclose(f([1], [1]), [[1], [0]])
MinimumTester = makeBroadcastTester( MinimumTester = makeBroadcastTester(
op=minimum, op=minimum,
expected=lambda *inputs: check_floatX(inputs, np.minimum(*inputs)), expected=lambda *inputs: check_floatX(inputs, np.minimum(*inputs)),
......
...@@ -898,6 +898,25 @@ def test_const_type_in_mul_canonizer(): ...@@ -898,6 +898,25 @@ def test_const_type_in_mul_canonizer():
f1(ival, wval, visbval, hidbval, betaval, aval)) f1(ival, wval, visbval, hidbval, betaval, aval))
def test_cast_in_mul_canonizer():
    """Regression test for the Canonizer dtype mismatch path.

    When the merged num/denum dtype differs from the output dtype, the
    Canonizer should now insert a scalar ``Cast`` instead of an
    ``Identity`` Elemwise wrapper.
    """
    x, y = tensor.vectors('xy')
    m = tensor.minimum(x, y)
    o = m.sum()
    go = tensor.fill(o, 1)
    e = tensor.eq(go, x)
    o1 = (1 - e) * go
    o2 = e * go
    mode = theano.compile.get_default_mode().excluding(
        'fusion').including('fast_run')
    f = theano.function([x, y], [o1, o2], mode=mode)
    nodes = f.maker.fgraph.apply_nodes
    # No Identity Elemwise should survive optimization...
    # NOTE: getattr needs the None default — not every op in the graph
    # has a `scalar_op` attribute, and without the default the lookup
    # would raise AttributeError instead of filtering the node out.
    assert len([n for n in nodes
                if isinstance(getattr(n.op, 'scalar_op', None),
                              theano.scalar.Identity)]) == 0
    # ...and exactly one Cast should have been inserted in its place.
    assert len([n for n in nodes
                if isinstance(getattr(n.op, 'scalar_op', None),
                              theano.scalar.Cast)]) == 1
    f([1], [1])
class test_fusion(unittest.TestCase): class test_fusion(unittest.TestCase):
mode = copy.copy(compile.mode.get_default_mode()) mode = copy.copy(compile.mode.get_default_mode())
_shared = staticmethod(shared) _shared = staticmethod(shared)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论