提交 70b0ce75 作者: ricardoV94 提交者: Ricardo Vieira

Avoid copy of zeros in AdvancedIncSubtensor1

上级 b2696360
@@ -1295,12 +1295,26 @@ compile.optdb.register(
 @node_rewriter([AdvancedIncSubtensor1], inplace=True)
 def local_inplace_AdvancedIncSubtensor1(fgraph, node):
-    if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
-        new_op = node.op.clone_inplace()
-        new_node = new_op(*node.inputs)
-        copy_stack_trace(node.outputs, new_node)
-        return [new_node]
-    return False
+    if node.op.inplace:
+        return
+
+    x, y, idx = node.inputs
+    if fgraph.has_destroyers([x]):
+        # In this case we can't operate inplace, but if x is just an alloc of zeros
+        # We're better off duplicating it and then acting on it inplace.
+        if (
+            x.owner is not None
+            and isinstance(x.owner.op, Alloc)
+            and x.owner.op.value_is_scalar_zero(x.owner.inputs[0])
+        ):
+            x = x.owner.clone().outputs[0]
+        else:
+            return None  # Inplace isn't valid
+
+    new_op = node.op.clone_inplace()
+    new_node = new_op(x, y, idx)
+    copy_stack_trace(node.outputs, new_node)
+    return [new_node]
 compile.optdb.register(
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论