提交 d8021fcc authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in gradient of `set_subtensor`

上级 91046b69
...@@ -1902,7 +1902,7 @@ def _sum_grad_over_bcasted_dims(x, gx): ...@@ -1902,7 +1902,7 @@ def _sum_grad_over_bcasted_dims(x, gx):
if gx.broadcastable != x.broadcastable: if gx.broadcastable != x.broadcastable:
x_dim_added = gx.ndim - x.ndim x_dim_added = gx.ndim - x.ndim
x_broad = (True,) * x_dim_added + x.broadcastable x_broad = (True,) * x_dim_added + x.broadcastable
assert sum(gx.broadcastable) < sum(x_broad) assert sum(gx.broadcastable) <= sum(x_broad)
axis_to_sum = [] axis_to_sum = []
for i in range(gx.ndim): for i in range(gx.ndim):
if gx.broadcastable[i] is False and x_broad[i] is True: if gx.broadcastable[i] is False and x_broad[i] is True:
......
...@@ -1593,6 +1593,15 @@ class TestIncSubtensor: ...@@ -1593,6 +1593,15 @@ class TestIncSubtensor:
), ),
) )
# Broadcastable leading dim
utt.verify_grad(
f_slice(slice(None, None), slice(1, 3)),
(
np.asarray([0, 1, 2, 3, 4, 5.0])[None, ...],
np.asarray([9, 9.0]),
),
)
class TestIncSubtensor1: class TestIncSubtensor1:
def setup_method(self): def setup_method(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论