提交 f0fdc95d authored 作者: Frederic's avatar Frederic

Re-enable the grad on the switch condition as it was working before the grad change.

上级 15725e30
......@@ -1051,7 +1051,12 @@ class Switch(ScalarOp):
else:
second_part = None
return (None, first_part, second_part)
# cond does affect the elements of the output so it is connected.
# For the sake of making the gradient convenient we assume that
# condition + epsilon always triggers the same branch as condition
condition_grad = cond.zeros_like().astype(theano.config.floatX)
return (condition_grad, first_part, second_part)
# Output type of the switch: the upcast of the two branch types
# (ift_t, iff_t). The condition type cond_t is unpacked but not used.
# NOTE(review): the tuple parameter in the signature is Python 2-only
# syntax (removed in Python 3 by PEP 3113).
def output_types(self, (cond_t, ift_t, iff_t)):
return upcast_out(ift_t, iff_t)
......
......@@ -39,7 +39,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements,
ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc,
dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1,
itensor3, Tile, AdvancedIncSubtensor)
itensor3, Tile, AdvancedIncSubtensor, switch)
from theano.tests import unittest_tools as utt
from theano.printing import debugprint
......@@ -618,6 +618,36 @@ SubInplaceTester = makeBroadcastTester(op=inplace.sub_inplace,
grad = _grad_broadcast_binary_normal,
inplace = True)
# Tester for the `switch` op, built with makeBroadcastTester.
# numpy.where is passed as the function producing the expected result.
# Every case is a tuple (condition, value_if_true, value_if_false).
SwitchTester = makeBroadcastTester(
op=switch,
expected=numpy.where,
# Valid inputs: a scalar condition of 1, a scalar condition of 0, and an
# elementwise (4, 5) integer condition from randint_ranged(0, 1, ...)
# (presumably a mix of 0s and 1s -- confirm the helper's range).
good=dict(all_true=(numpy.asarray(1, dtype=config.floatX),
rand(4, 5), rand(4, 5)),
false_true=(numpy.asarray(0, dtype=config.floatX),
rand(4, 5), rand(4, 5)),
mixed=(randint_ranged(0, 1, (4, 5)),
rand(4, 5), rand(4, 5))
),
# Only two inputs are supplied here (condition and a single branch).
bad_build=dict(all_true=(numpy.asarray(1, dtype=config.floatX),
rand(4, 5))),
# The two branches have mismatched shapes: (3, 5) vs (4, 5) and
# (4, 6) vs (4, 5).
bad_runtime=dict(all_true=(numpy.asarray(1, dtype=config.floatX),
rand(3, 5), rand(4, 5)),
false_true=(numpy.asarray(0, dtype=config.floatX),
rand(4, 6), rand(4, 5)),
),
# Switch.grad() assumes that cond + epsilon selects the same branch as
# cond, so verify_grad cannot be called with a condition equal to 0.
# Only the all-true case is checked; the cases containing a 0 condition
# are left disabled below.
grad=dict(all_true=(numpy.asarray(1, dtype=config.floatX),
rand(4, 5), rand(4, 5)),
# false_true=(numpy.asarray(0, dtype=config.floatX),
# rand(4, 5), rand(4, 5)),
# mixed=(randint_ranged(0, 1, (4, 5)).astype(config.floatX),
# rand(4, 5), rand(4, 5))
),
)
MaximumTester = makeBroadcastTester(op=maximum,
expected = lambda *inputs: check_floatX(inputs, numpy.maximum(*inputs)),
good = _good_broadcast_binary_normal,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论