Commit cb348444, authored by David Warde-Farley

Merge pull request #957 from goodfeli/fix_elemwise_grad

fix grad for comparisons
......@@ -843,7 +843,11 @@ class LogicalComparison(BinaryScalarOp):
return [int8]
def grad(self, inputs, output_gradients):
    """Gradient of a binary comparison op.

    Comparisons are piecewise constant in their inputs, so the
    gradient with respect to both operands is symbolically zero.
    The zeros are cast to floatX because the op's integer output
    type has no floating-point gradient of its own.
    """
    # NOTE: the old `return [None, None]` (disconnected gradient) was
    # removed by this change; leaving it in place would make the code
    # below unreachable dead code.
    x, y = inputs
    out = self(x, y)
    # Sanity check: comparison ops are expected to produce an int dtype.
    assert str(out.type.dtype).find('int') != -1
    return [x.zeros_like().astype(theano.config.floatX),
            y.zeros_like().astype(theano.config.floatX)]
class FixedLogicalComparison(UnaryScalarOp):
......@@ -854,7 +858,10 @@ class FixedLogicalComparison(UnaryScalarOp):
return [int8]
def grad(self, inputs, output_gradients):
    """Gradient of a unary comparison op.

    Comparisons are piecewise constant, so the gradient with respect
    to the single input is symbolically zero, cast to floatX because
    the op's integer output type carries no float gradient.
    """
    # NOTE: the old `return [None]` (disconnected gradient) was removed
    # by this change; keeping it would make the fix below unreachable.
    x, = inputs
    out = self(x)
    # Sanity check: comparison ops are expected to produce an int dtype.
    assert str(out.type.dtype).find('int') != -1
    return [x.zeros_like().astype(theano.config.floatX)]
class LT(LogicalComparison):
......
......@@ -207,6 +207,13 @@ class test_div(unittest.TestCase):
assert isinstance((f/c).owner.op, TrueDiv)
assert isinstance((a/c).owner.op, TrueDiv)
def test_grad_gt():
    """A greater-than comparison must have a (zero) gradient, not None."""
    lhs = float32(name='x')
    rhs = float32(name='y')
    # Differentiating the comparison w.r.t. one operand should yield a
    # symbolic zero that evaluates cleanly rather than raising.
    gradient = theano.gradient.grad(lhs > rhs, rhs)
    assert gradient.eval({rhs: 1.}) == 0.
# Testing of Composite is done in tensor/tests/test_opt.py
# in test_fusion, TestCompositeCodegen
......
......@@ -690,6 +690,8 @@ class Elemwise(Op):
scalar_ograds = [Scalar(dtype=ograd.type.dtype)()
for ograd in ograds]
scalar_igrads = self.scalar_op.grad(scalar_inputs, scalar_ograds)
for igrad in scalar_igrads:
assert igrad is not None
finally:
......
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment