提交 6d1d5fa5 authored 作者: Frederic's avatar Frederic

[CRASH] Fix crash in the grad of IncSubtensor when the new value is implicitly broadcast.

上级 14dad2b2
...@@ -1460,7 +1460,8 @@ class IncSubtensor(Op): ...@@ -1460,7 +1460,8 @@ class IncSubtensor(Op):
gx = g_output gx = g_output
gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list) gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
if gy.broadcastable != y.broadcastable: if gy.broadcastable != y.broadcastable:
y_broad = (True,) * (gy.ndim - y.ndim) + y.broadcastable y_dim_added = gy.ndim - y.ndim
y_broad = (True,) * y_dim_added + y.broadcastable
assert sum(gy.broadcastable) < sum(y_broad) assert sum(gy.broadcastable) < sum(y_broad)
axis_to_sum = [] axis_to_sum = []
for i in range(gy.ndim): for i in range(gy.ndim):
...@@ -1476,7 +1477,10 @@ class IncSubtensor(Op): ...@@ -1476,7 +1477,10 @@ class IncSubtensor(Op):
assert gy.broadcastable[i] == y_broad[i] assert gy.broadcastable[i] == y_broad[i]
gy = gy.sum(axis=axis_to_sum, keepdims=True) gy = gy.sum(axis=axis_to_sum, keepdims=True)
if gy.ndim != y.ndim: if gy.ndim != y.ndim:
gy = gy.dimshuffle(*range(y.ndim, gy.ndim)) assert gy.ndim > y.ndim
for i in range(y_dim_added):
assert gy.broadcastable[i]
gy = gy.dimshuffle(*range(y_dim_added, gy.ndim))
assert gy.broadcastable == y.broadcastable assert gy.broadcastable == y.broadcastable
return [gx, gy] + [DisconnectedType()()] * len(idx_list) return [gx, gy] + [DisconnectedType()()] * len(idx_list)
......
...@@ -154,3 +154,8 @@ class Test_inc_subtensor(unittest.TestCase): ...@@ -154,3 +154,8 @@ class Test_inc_subtensor(unittest.TestCase):
(numpy.asarray([[0, 1], [2, 3], [4, 5.]]), (numpy.asarray([[0, 1], [2, 3], [4, 5.]]),
numpy.asarray(9.),)) numpy.asarray(9.),))
# broadcast
utt.verify_grad(
f_slice(2),
(numpy.asarray([[0, 1], [2, 3], [4, 5.]]),
numpy.asarray(9.),))
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论