提交 2ed28f39 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Split local_useless_AdvancedSubtensor1 from local_useless_subtensor

上级 8e3c356f
...@@ -879,22 +879,15 @@ def local_set_to_inc_subtensor(fgraph, node): ...@@ -879,22 +879,15 @@ def local_set_to_inc_subtensor(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([Subtensor, AdvancedSubtensor1]) @local_optimizer([Subtensor])
def local_useless_subtensor(fgraph, node): def local_useless_subtensor(fgraph, node):
""" """Remove `Subtensor` if it takes the full input."""
Remove Subtensor/AdvancedSubtensor1 if it takes the full input. In the
AdvancedSubtensor1 case, the full input is taken when the indices are
equivalent to `arange(0, input.shape[0], 1)` using either an explicit
list/vector or the ARange op.
"""
# This optimization needs ShapeOpt and fgraph.shape_feature # This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(fgraph, "shape_feature"): if not hasattr(fgraph, "shape_feature"):
return return
shape_of = fgraph.shape_feature.shape_of shape_of = fgraph.shape_feature.shape_of
if isinstance(node.op, Subtensor):
cdata = get_constant_idx( cdata = get_constant_idx(
node.op.idx_list, node.op.idx_list,
node.inputs, node.inputs,
...@@ -939,9 +932,7 @@ def local_useless_subtensor(fgraph, node): ...@@ -939,9 +932,7 @@ def local_useless_subtensor(fgraph, node):
length_pos_shape_i.owner.op, ScalarFromTensor length_pos_shape_i.owner.op, ScalarFromTensor
): ):
length_pos_shape_i = length_pos_shape_i.owner.inputs[0] length_pos_shape_i = length_pos_shape_i.owner.inputs[0]
elif length_pos.owner and isinstance( elif length_pos.owner and isinstance(length_pos.owner.op, TensorFromScalar):
length_pos.owner.op, TensorFromScalar
):
length_pos = length_pos.owner.inputs[0] length_pos = length_pos.owner.inputs[0]
else: else:
# We did not find underlying variables of the same type # We did not find underlying variables of the same type
...@@ -963,10 +954,30 @@ def local_useless_subtensor(fgraph, node): ...@@ -963,10 +954,30 @@ def local_useless_subtensor(fgraph, node):
if length_pos_shape_i != length_pos: if length_pos_shape_i != length_pos:
return False return False
elif idx.stop is None: elif idx.stop is None:
pass continue
else: else:
return False return False
elif isinstance(node.op, AdvancedSubtensor1):
return [node.inputs[0]]
@register_canonicalize
@register_specialize
@local_optimizer([AdvancedSubtensor1])
def local_useless_AdvancedSubtensor1(fgraph, node):
"""Remove `AdvancedSubtensor1` if it takes the full input.
In the `AdvancedSubtensor1` case, the full input is taken when the indices
are equivalent to ``arange(0, input.shape[0], 1)`` using either an explicit
list/vector or the `ARange` `Op`.
"""
# This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(fgraph, "shape_feature"):
return
shape_of = fgraph.shape_feature.shape_of
# get length of the indexed tensor along the first axis # get length of the indexed tensor along the first axis
try: try:
length = get_scalar_constant_value( length = get_scalar_constant_value(
...@@ -1003,11 +1014,7 @@ def local_useless_subtensor(fgraph, node): ...@@ -1003,11 +1014,7 @@ def local_useless_subtensor(fgraph, node):
return False return False
else: else:
return False return False
else:
return False
# We don't need to copy over any stacktrace here,
# because previous stacktrace should suffice.
return [node.inputs[0]] return [node.inputs[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论