提交 1d54925f authored 作者: Shawn Tan's avatar Shawn Tan

Added check in case flattened array is 1d.

上级 e1b2dc23
...@@ -679,6 +679,20 @@ class GpuAdvancedIncSubtensor(HideC, tensor.AdvancedIncSubtensor): ...@@ -679,6 +679,20 @@ class GpuAdvancedIncSubtensor(HideC, tensor.AdvancedIncSubtensor):
# build the indices and use it # build the indices and use it
index = idx_[p:] + [slice(None)] * (len(x_flat.shape) - len(idx_[p:]) - 1) index = idx_[p:] + [slice(None)] * (len(x_flat.shape) - len(idx_[p:]) - 1)
take_idx = sum(i * s for i, s in zip(nidx, strides)) take_idx = sum(i * s for i, s in zip(nidx, strides))
if index == []:
for j, i in enumerate(take_idx.flatten()):
if y_flat.shape == ():
val = y_flat
else:
val = y_flat[j]
tmp = pygpu.elemwise.elemwise2(
x_flat[i], '+', val, x_flat[i],
broadcast=True,
convert_f16=True
)
x_flat.__setitem__(i, tmp)
else:
k = get_iadd(node.inputs[0], node.inputs[1]) k = get_iadd(node.inputs[0], node.inputs[1])
if x_flat.shape[-len(y_flat.shape):] == y_flat.shape or y_flat.shape == (): if x_flat.shape[-len(y_flat.shape):] == y_flat.shape or y_flat.shape == ():
# y_flat has to be broadcast over axes of x_flat[i] # y_flat has to be broadcast over axes of x_flat[i]
...@@ -694,6 +708,7 @@ class GpuAdvancedIncSubtensor(HideC, tensor.AdvancedIncSubtensor): ...@@ -694,6 +708,7 @@ class GpuAdvancedIncSubtensor(HideC, tensor.AdvancedIncSubtensor):
convert_f16=True convert_f16=True
) )
x_flat[i].__setitem__(index, tmp) x_flat[i].__setitem__(index, tmp)
else: else:
# y_flat's first axis corresponds to first exist of x_flat # y_flat's first axis corresponds to first exist of x_flat
for j, i in enumerate(take_idx.flatten()): for j, i in enumerate(take_idx.flatten()):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论