提交 8f344da3 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix python implementation of IncSubtensor

上级 f309c22e
......@@ -1740,41 +1740,30 @@ class IncSubtensor(COp):
def decl_view(self):
return "PyArrayObject * zview = NULL;"
def perform(self, node, inputs, out_):
(out,) = out_
x, y = inputs[:2]
indices = list(reversed(inputs[2:]))
def _convert(entry):
if isinstance(entry, Type):
return indices.pop()
elif isinstance(entry, slice):
return slice(
_convert(entry.start), _convert(entry.stop), _convert(entry.step)
def perform(self, node, inputs, output_storage):
x, y, *flat_indices = inputs
flat_indices_iterator = iter(flat_indices)
indices = tuple(
(
next(flat_indices_iterator)
if isinstance(entry, Type)
else slice(
None if entry.start is None else next(flat_indices_iterator),
None if entry.stop is None else next(flat_indices_iterator),
None if entry.step is None else next(flat_indices_iterator),
)
else:
return entry
)
for entry in self.idx_list
)
cdata = tuple(map(_convert, self.idx_list))
if len(cdata) == 1:
cdata = cdata[0]
if not self.inplace:
x = x.copy()
sub_x = x.__getitem__(cdata)
if sub_x.shape:
# we've sliced out an N-D tensor with N > 0
if not self.set_instead_of_inc:
sub_x += y
else:
# sub_x += -sub_x + y
x.__setitem__(cdata, y)
if self.set_instead_of_inc:
x[indices] = y
else:
# scalar case
if not self.set_instead_of_inc:
x.__setitem__(cdata, sub_x + y)
else:
x.__setitem__(cdata, y)
out[0] = x
x[indices] += y
output_storage[0][0] = x
def c_code(self, node, name, inputs, outputs, sub):
# This method delegates much of the work to helper
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论