提交 11f88f7c authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Optimizations for set_subtensor and inc_subtensor when applied on alloc or

zeros.
上级 08c2b2e4
...@@ -1620,6 +1620,73 @@ def local_inplace_incsubtensor1(node): ...@@ -1620,6 +1620,73 @@ def local_inplace_incsubtensor1(node):
compile.optdb.register('local_inplace_incsubtensor1', TopoOptimizer(local_inplace_incsubtensor1, compile.optdb.register('local_inplace_incsubtensor1', TopoOptimizer(local_inplace_incsubtensor1,
failure_callback=TopoOptimizer.warn_inplace), 60, 'fast_run', 'inplace') #DEBUG failure_callback=TopoOptimizer.warn_inplace), 60, 'fast_run', 'inplace') #DEBUG
@register_canonicalize
@register_stabilize
@gof.local_optimizer([None])
def local_incsubtensor_of_allocs(node):
if isinstance(node.op, T.IncSubtensor) and not node.op.set_instead_of_inc:
x = node.inputs[0]
y = node.inputs[1]
replace = False
if y.owner and isinstance(y.owner.op, T.Alloc):
try:
val = get_constant_value(y.owner.inputs[0])
if numpy.all(val == 0):
replace = True
except TypeError:
pass
if isinstance(y, T.TensorConstant) and (y.tag.unique_value == 0):
replace = True
if replace:
return [x]
else:
return False
@register_canonicalize
@register_stabilize
@gof.local_optimizer([None])
def local_setsubtensor_of_allocs(node):
if isinstance(node.op, T.IncSubtensor) and node.op.set_instead_of_inc:
x = node.inputs[0]
y = node.inputs[1]
replace_x = None
replace_y = None
if x.owner and isinstance(x.owner.op, T.Alloc):
try:
val = get_constant_value(x.owner.inputs[0])
assert val.size == 1
replace_x = val
except (TypeError, AssertionError):
replace_x = x.owner.inputs[0]
if isinstance(x, T.TensorConstant) and (x.tag.unique_value is not
None):
replace_x = x.tag.unique_value
if y.owner and isinstance(y.owner.op, T.Alloc):
try:
val = get_constant_value(y.owner.inputs[0])
assert val.size == 1
replace_y = val
except (TypeError, AssertionError):
replace_y = y.owner.inputs[0]
if isinstance(y, T.TensorConstant) and (y.tag.unique_value is not
None):
replace_y = y.tag.unique_value
if (replace_x == replace_y and
replace_x is not None and
replace_y is not None):
return [x]
else:
return False
#################### ####################
# Rebroadcast opts # # Rebroadcast opts #
#################### ####################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论