提交 9a13989b authored 作者: Kelvin Xu's avatar Kelvin Xu 提交者: Kelvin Xu

added initial opt

上级 96a2c0b6
...@@ -1919,6 +1919,34 @@ def local_set_to_inc_subtensor(node): ...@@ -1919,6 +1919,34 @@ def local_set_to_inc_subtensor(node):
return [advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2])] return [advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2])]
@register_canonicalize
@register_specialize
@gof.local_optimizer([Subtensor, AdvancedSubtensor1])
def local_useless_slice(node):
"""
Remove Subtensor/AdvancedSubtensor1 of the from X[0, :] -> X[0]
"""
if isinstance(node.op, Subtensor):
slices = get_idx_list(node.inputs, node.op.idx_list)
last_slice = len(slices)
for s in slices[::-1]:
# check if slice and then check slice indices
if isinstance(s, slice):
if s.start is None and s.stop is None and\
(s.step is None or T.extract_constant(s.step) == 1):
last_slice -= 1
else:
break
# check if we removed something
if last_slice < len(slices):
subtens = Subtensor(make_constant(slices[:last_slice]))
sl_ins = Subtensor.collapse(slices[:last_slice],
lambda x: isinstance(x, T.Variable))
out = subtens(node.inputs[0], *sl_ins)
return [out]
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([Subtensor, AdvancedSubtensor1]) @gof.local_optimizer([Subtensor, AdvancedSubtensor1])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论