提交 8ce10c32 authored 作者: Frederic's avatar Frederic

Fix test in DebugMode. Make GpuAdvancedIncSubtensor1 allow to broadcast

上级 2f82eef2
...@@ -2604,10 +2604,10 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp): ...@@ -2604,10 +2604,10 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp):
out, = out_ out, = out_
if not self.inplace: if not self.inplace:
x = x.copy() x = x.copy()
assert y.ndim <= x.ndim # Should be guaranteed by `make_node`
if self.set_instead_of_inc: if self.set_instead_of_inc:
# CudaNdarray __setitem__ doesn't do broadcast nor support # CudaNdarray __setitem__ doesn't do broadcast nor support
# list of index. # list of index.
assert y.ndim <= x.ndim # Should be guaranteed by `make_node`
if y.ndim == x.ndim: if y.ndim == x.ndim:
assert len(y) == len(idx) assert len(y) == len(idx)
for (j, i) in enumerate(idx): for (j, i) in enumerate(idx):
...@@ -2619,11 +2619,15 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp): ...@@ -2619,11 +2619,15 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp):
# If `y` has as many dimensions as `x`, then we want to iterate # If `y` has as many dimensions as `x`, then we want to iterate
# jointly on `x` and `y`. Otherwise, it means `y` should be # jointly on `x` and `y`. Otherwise, it means `y` should be
# broadcasted to fill all relevant rows of `x`. # broadcasted to fill all relevant rows of `x`.
assert y.ndim <= x.ndim # Should be guaranteed by `make_node`
if y.ndim == x.ndim: if y.ndim == x.ndim:
assert len(y) == len(idx) if len(y) == 1:
for (j, i) in enumerate(idx): # Allow broadcasting of y[0]
x[i] += y[j] for i in idx:
x[i] += y[0]
else:
assert len(y) == len(idx)
for (j, i) in enumerate(idx):
x[i] += y[j]
else: else:
for i in idx: for i in idx:
x[i] += y x[i] += y
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论