提交 4be4a46c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Update inplace_increment1d_slow

上级 68d72ec3
...@@ -1830,9 +1830,13 @@ class AdvancedIncSubtensor1(Op): ...@@ -1830,9 +1830,13 @@ class AdvancedIncSubtensor1(Op):
# broadcasted to fill all relevant rows of `x`. # broadcasted to fill all relevant rows of `x`.
assert y.ndim <= x.ndim # Should be guaranteed by `make_node` assert y.ndim <= x.ndim # Should be guaranteed by `make_node`
if y.ndim == x.ndim: if y.ndim == x.ndim:
assert len(y) == len(idx) if len(y) == 1:
for (j, i) in enumerate(idx): for i in idx:
x[i] += y[j] x[i] += y[0]
else:
assert len(y) == len(idx)
for (j, i) in enumerate(idx):
x[i] += y[j]
else: else:
for i in idx: for i in idx:
x[i] += y x[i] += y
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论