提交 2ed28f39 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Split local_useless_AdvancedSubtensor1 from local_useless_subtensor

上级 8e3c356f
......@@ -879,22 +879,15 @@ def local_set_to_inc_subtensor(fgraph, node):
@register_canonicalize
@register_specialize
@local_optimizer([Subtensor, AdvancedSubtensor1])
@local_optimizer([Subtensor])
def local_useless_subtensor(fgraph, node):
"""
Remove Subtensor/AdvancedSubtensor1 if it takes the full input. In the
AdvancedSubtensor1 case, the full input is taken when the indices are
equivalent to `arange(0, input.shape[0], 1)` using either an explicit
list/vector or the ARange op.
"""
"""Remove `Subtensor` if it takes the full input."""
# This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(fgraph, "shape_feature"):
return
shape_of = fgraph.shape_feature.shape_of
if isinstance(node.op, Subtensor):
cdata = get_constant_idx(
node.op.idx_list,
node.inputs,
......@@ -939,9 +932,7 @@ def local_useless_subtensor(fgraph, node):
length_pos_shape_i.owner.op, ScalarFromTensor
):
length_pos_shape_i = length_pos_shape_i.owner.inputs[0]
elif length_pos.owner and isinstance(
length_pos.owner.op, TensorFromScalar
):
elif length_pos.owner and isinstance(length_pos.owner.op, TensorFromScalar):
length_pos = length_pos.owner.inputs[0]
else:
# We did not find underlying variables of the same type
......@@ -963,10 +954,30 @@ def local_useless_subtensor(fgraph, node):
if length_pos_shape_i != length_pos:
return False
elif idx.stop is None:
pass
continue
else:
return False
elif isinstance(node.op, AdvancedSubtensor1):
return [node.inputs[0]]
@register_canonicalize
@register_specialize
@local_optimizer([AdvancedSubtensor1])
def local_useless_AdvancedSubtensor1(fgraph, node):
"""Remove `AdvancedSubtensor1` if it takes the full input.
In the `AdvancedSubtensor1` case, the full input is taken when the indices
are equivalent to ``arange(0, input.shape[0], 1)`` using either an explicit
list/vector or the `ARange` `Op`.
"""
# This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(fgraph, "shape_feature"):
return
shape_of = fgraph.shape_feature.shape_of
# get length of the indexed tensor along the first axis
try:
length = get_scalar_constant_value(
......@@ -1003,11 +1014,7 @@ def local_useless_subtensor(fgraph, node):
return False
else:
return False
else:
return False
# We don't need to copy over any stacktrace here,
# because previous stacktrace should suffice.
return [node.inputs[0]]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论