提交 3e1e0779 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Use clone_inplace() to wrap the operation of making a Inc/SetSubtensor inplace.

上级 81ec8a02
...@@ -2549,8 +2549,7 @@ compile.optdb.register('local_inplace_setsubtensor', ...@@ -2549,8 +2549,7 @@ compile.optdb.register('local_inplace_setsubtensor',
def local_inplace_incsubtensor1(node): def local_inplace_incsubtensor1(node):
""" also work for GpuAdvancedIncSubtensor1 """ """ also work for GpuAdvancedIncSubtensor1 """
if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace: if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
new_op = node.op.__class__( new_op = node.op.clone_inplace()
inplace=True, set_instead_of_inc=node.op.set_instead_of_inc)
new_node = new_op(*node.inputs) new_node = new_op(*node.inputs)
return [new_node] return [new_node]
return False return False
......
...@@ -1808,6 +1808,11 @@ class AdvancedIncSubtensor1(Op): ...@@ -1808,6 +1808,11 @@ class AdvancedIncSubtensor1(Op):
and self.inplace == other.inplace and self.inplace == other.inplace
and self.set_instead_of_inc == other.set_instead_of_inc) and self.set_instead_of_inc == other.set_instead_of_inc)
def clone_inplace(self):
return self.__class__(
inplace=True,
set_instead_of_inc=self.set_instead_of_inc)
def __str__(self): def __str__(self):
if self.inplace: if self.inplace:
msg = "inplace" msg = "inplace"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论