Commit d3f405db authored by lamblin

Merge pull request #998 from goodfeli/fix_switch_grad

Fix Switch.grad
...@@ -1041,15 +1041,13 @@ class Switch(ScalarOp): ...@@ -1041,15 +1041,13 @@ class Switch(ScalarOp):
return "%(z)s = %(cond)s ? %(ift)s : %(iff)s;" % locals() return "%(z)s = %(cond)s ? %(ift)s : %(iff)s;" % locals()
    def grad(self, (cond, ift, iff), (gz, )):
        first_part = switch(cond, gz, 0.)
        second_part = switch(cond, 0., gz)

        out = self(cond, ift, iff)
        if out.type.dtype in discrete_types:
            first_part = 0.
            second_part = 0.
        # cond does affect the elements of the output so it is connected.
        # For the sake of making the gradient convenient we assume that
......
...@@ -214,6 +214,20 @@ def test_grad_gt(): ...@@ -214,6 +214,20 @@ def test_grad_gt():
g = theano.gradient.grad(z, y) g = theano.gradient.grad(z, y)
assert g.eval({ y : 1. }) == 0. assert g.eval({ y : 1. }) == 0.
def test_grad_switch():
    # Regression test for a snippet from the mailing list: Switch.grad
    # used to trip an assert when one of the switch branches was an
    # integer constant, because integer inputs were not handled.
    data = theano.tensor.matrix()
    condition = theano.tensor.matrix()
    switched = theano.tensor.switch(condition, data, 0)
    cost = switched.sum()
    # Must not raise.
    theano.gradient.grad(cost, data)
# Testing of Composite is done in tensor/tests/test_opt.py # Testing of Composite is done in tensor/tests/test_opt.py
# in test_fusion, TestCompositeCodegen # in test_fusion, TestCompositeCodegen
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论