提交 8c0ea6ec authored 作者: Sigurd Spieckermann's avatar Sigurd Spieckermann

make local_useless_subtensor optimization support AdvancedSubtensor1

上级 8e85dbab
...@@ -1917,16 +1917,18 @@ def local_set_to_inc_subtensor(node): ...@@ -1917,16 +1917,18 @@ def local_set_to_inc_subtensor(node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([Subtensor]) @gof.local_optimizer([Subtensor, AdvancedSubtensor1])
def local_useless_subtensor(node): def local_useless_subtensor(node):
""" """
Remove Subtensor if it takes the full input Remove Subtensor if it takes the full input
""" """
# This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(node.fgraph, 'shape_feature'):
return
shape_of = node.fgraph.shape_feature.shape_of
if isinstance(node.op, Subtensor): if isinstance(node.op, Subtensor):
# This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(node.fgraph, 'shape_feature'):
return
shape_of = node.fgraph.shape_feature.shape_of
cdata = node.op.get_constant_idx(node.inputs, allow_partial=True) cdata = node.op.get_constant_idx(node.inputs, allow_partial=True)
for pos, idx in enumerate(cdata): for pos, idx in enumerate(cdata):
if not isinstance(idx, slice): if not isinstance(idx, slice):
...@@ -1985,8 +1987,41 @@ def local_useless_subtensor(node): ...@@ -1985,8 +1987,41 @@ def local_useless_subtensor(node):
pass pass
else: else:
return False return False
elif isinstance(node.op, AdvancedSubtensor1):
# get length of the indexed tensor along the first axis
try:
length = get_scalar_constant_value(shape_of[node.inputs[0]][0])
except NotScalarConstantError:
pass
# get index (which must be a vector by definition)
idx = node.inputs[1]
# `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for
# this optimization
if isinstance(idx, T.Constant):
idx = idx.value
if len(idx) != length:
return False
if numpy.any(idx != numpy.arange(length)):
return False
elif idx.owner is not None and isinstance(idx.owner.op, T.ARange):
try:
start, stop, step = map(get_scalar_constant_value,
idx.owner.inputs)
except NotScalarConstantError:
return False
if start != 0:
return False
if stop != length:
return False
if step != 1:
return False
else:
return False
return [node.inputs[0]] return [node.inputs[0]]
@register_canonicalize @register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论