Commit 91966e85 authored by Ricardo Vieira, committed by Ricardo Vieira

Fix bug in grad of discrete Switch

Parent 848ce199
@@ -1598,8 +1598,8 @@ class Switch(ScalarOp):
         second_part = switch(cond, 0.0, gz)
         if outputs[0].type in discrete_types:
-            first_part = 0.0
-            second_part = 0.0
+            first_part = ift.zeros_like(config.floatX)
+            second_part = iff.zeros_like(config.floatX)
         # cond does affect the elements of the output so it is connected.
         # For the sake of making the gradient convenient we assume that
...
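The change above replaces the Python scalar `0.0` with `zeros_like` calls so that, for discrete (integer-typed) outputs, the gradient of each `switch` branch is a zero tensor carrying the branch's shape and a float dtype rather than a bare scalar. A minimal NumPy sketch of that idea (`discrete_switch_branch_grads` is a hypothetical helper for illustration, not PyTensor's actual API):

```python
import numpy as np

def discrete_switch_branch_grads(ift, iff, float_dtype=np.float64):
    """Hypothetical sketch: zero gradients for the branches of a
    discrete-output switch(cond, ift, iff).

    The gradient is zero everywhere, but it must be a zero *tensor*
    shaped like each branch with a floating-point dtype -- returning
    the scalar 0.0 (the pre-fix behavior) loses the shape information.
    """
    first_part = np.zeros_like(ift, dtype=float_dtype)   # grad w.r.t. ift
    second_part = np.zeros_like(iff, dtype=float_dtype)  # grad w.r.t. iff
    return first_part, second_part

# An integer input, as in the regression test: switch(0, x, -x)
x = np.arange(6, dtype=np.int64).reshape(2, 3)
g_ift, g_iff = discrete_switch_branch_grads(x, -x)
print(g_ift.shape)  # (2, 3): shape is preserved, unlike a bare 0.0
```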
@@ -428,6 +428,13 @@ def test_grad_switch():
     pytensor.gradient.grad(l, x)
 
+    # Bug reported in https://github.com/pymc-devs/pytensor/issues/331
+    x = matrix(dtype=int)
+    s = pytensor.tensor.switch(0, x, -x)
+    l = s.sum()
+    pytensor.gradient.grad(l, x)
+
+
 def test_grad_identity():
     # Check that the grad method of Identity correctly handles int dtypes
...