提交 6fc0f59a authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix IfElse.grad when the true or false branch contains ints. Fixes gh-4471

上级 200babca
...@@ -213,17 +213,21 @@ class IfElse(PureOp): ...@@ -213,17 +213,21 @@ class IfElse(PureOp):
gpu=self.gpu, gpu=self.gpu,
name=nw_name_f) name=nw_name_f)
if_true = ([ins[0]] + grads + [theano.tensor.zeros_like(t) # The grads can have a different type then the inputs.
for t in ts]) # As all condition must have the same dtype, we must
if_false = ([ins[0]] + [theano.tensor.zeros_like(f) dtype = grads[0].dtype
for f in fs] + grads) if_true = ([ins[0]] +
grads +
[theano.tensor.zeros_like(t, dtype=dtype) for t in ts])
if_false = ([ins[0]] +
[theano.tensor.zeros_like(f, dtype=dtype) for f in fs] +
grads)
condition = ins[0] condition = ins[0]
# condition does affect the elements of the output so it is connected. # condition does affect the elements of the output so it is connected.
# For the sake of making the gradient convenient we assume that # For the sake of making the gradient convenient we assume that
# condition + epsilon always triggers the same branch as condition # condition + epsilon always triggers the same branch as condition
condition_grad = condition.zeros_like().astype(theano.config.floatX) condition_grad = condition.zeros_like().astype(theano.config.floatX)
return ([condition_grad] + return ([condition_grad] +
if_true_op(*if_true, **dict(return_list=True)) + if_true_op(*if_true, **dict(return_list=True)) +
if_false_op(*if_false, **dict(return_list=True))) if_false_op(*if_false, **dict(return_list=True)))
......
...@@ -482,6 +482,21 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -482,6 +482,21 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin):
finally: finally:
theano.config.compute_test_value = backup theano.config.compute_test_value = backup
def test_grad_int_value(self):
    """Regression test for gh-4471: building the gradient through an
    ifelse whose branch values are plain Python ints used to fail.

    Only graph construction is exercised; the update expressions are
    never compiled or evaluated.
    """
    weights = theano.shared(numpy.random.rand(10))
    bias = theano.shared(numpy.random.rand())
    trainable = [weights, bias]

    features = tensor.vector()
    label = tensor.scalar()

    score = weights.dot(features) + bias
    correct = (score * label > 0)

    # The branch values are deliberately ints (not tensors): this is
    # exactly the case the fix addresses.
    loss = ifelse(correct, 0, 1)
    updates = []
    for param in trainable:
        updates.append(
            (param, param - 0.5 * tensor.grad(cost=loss, wrt=param)))
if __name__ == '__main__': if __name__ == '__main__':
print(' Use nosetests to run these tests ') print(' Use nosetests to run these tests ')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论