提交 8f344da3 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix python implementation of IncSubtensor

上级 f309c22e
...@@ -1740,41 +1740,30 @@ class IncSubtensor(COp): ...@@ -1740,41 +1740,30 @@ class IncSubtensor(COp):
def decl_view(self): def decl_view(self):
return "PyArrayObject * zview = NULL;" return "PyArrayObject * zview = NULL;"
def perform(self, node, inputs, out_): def perform(self, node, inputs, output_storage):
(out,) = out_ x, y, *flat_indices = inputs
x, y = inputs[:2]
indices = list(reversed(inputs[2:])) flat_indices_iterator = iter(flat_indices)
indices = tuple(
def _convert(entry): (
if isinstance(entry, Type): next(flat_indices_iterator)
return indices.pop() if isinstance(entry, Type)
elif isinstance(entry, slice): else slice(
return slice( None if entry.start is None else next(flat_indices_iterator),
_convert(entry.start), _convert(entry.stop), _convert(entry.step) None if entry.stop is None else next(flat_indices_iterator),
None if entry.step is None else next(flat_indices_iterator),
) )
else: )
return entry for entry in self.idx_list
)
cdata = tuple(map(_convert, self.idx_list))
if len(cdata) == 1:
cdata = cdata[0]
if not self.inplace: if not self.inplace:
x = x.copy() x = x.copy()
sub_x = x.__getitem__(cdata) if self.set_instead_of_inc:
if sub_x.shape: x[indices] = y
# we've sliced out an N-D tensor with N > 0
if not self.set_instead_of_inc:
sub_x += y
else:
# sub_x += -sub_x + y
x.__setitem__(cdata, y)
else: else:
# scalar case x[indices] += y
if not self.set_instead_of_inc: output_storage[0][0] = x
x.__setitem__(cdata, sub_x + y)
else:
x.__setitem__(cdata, y)
out[0] = x
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
# This method delegates much of the work to helper # This method delegates much of the work to helper
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论