提交 b55b4f4c authored 作者: Sigurd Spieckermann's avatar Sigurd Spieckermann

added support for AdvancedIncSubtensor[1] to local_incsubtensor_of_zeros optimization

上级 348671b1
...@@ -2393,12 +2393,17 @@ compile.optdb.register('local_inplace_incsubtensor1', ...@@ -2393,12 +2393,17 @@ compile.optdb.register('local_inplace_incsubtensor1',
# Register old name # Register old name
@register_canonicalize("local_incsubtensor_of_allocs") @register_canonicalize("local_incsubtensor_of_allocs")
@register_stabilize("local_incsubtensor_of_allocs") @register_stabilize("local_incsubtensor_of_allocs")
@gof.local_optimizer([IncSubtensor]) @gof.local_optimizer([IncSubtensor,
AdvancedIncSubtensor,
AdvancedIncSubtensor1])
def local_incsubtensor_of_zeros(node): def local_incsubtensor_of_zeros(node):
""" """
IncSubtensor(x, zeros, idx) -> x IncSubtensor(x, zeros, idx) -> x
""" """
if isinstance(node.op, IncSubtensor) and not node.op.set_instead_of_inc: if (isinstance(node.op, (IncSubtensor,
AdvancedIncSubtensor,
AdvancedIncSubtensor1)) and
not node.op.set_instead_of_inc):
x = node.inputs[0] x = node.inputs[0]
y = node.inputs[1] y = node.inputs[1]
replace = False replace = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论