Commit 2ed28f39 authored by Brandon T. Willard, committed by Brandon T. Willard

Split local_useless_AdvancedSubtensor1 from local_useless_subtensor

parent 8e3c356f
@@ -879,135 +879,142 @@ def local_set_to_inc_subtensor(fgraph, node):
 @register_canonicalize
 @register_specialize
-@local_optimizer([Subtensor, AdvancedSubtensor1])
+@local_optimizer([Subtensor])
 def local_useless_subtensor(fgraph, node):
-    """
-    Remove Subtensor/AdvancedSubtensor1 if it takes the full input. In the
-    AdvancedSubtensor1 case, the full input is taken when the indices are
-    equivalent to `arange(0, input.shape[0], 1)` using either an explicit
-    list/vector or the ARange op.
-
-    """
+    """Remove `Subtensor` if it takes the full input."""
     # This optimization needs ShapeOpt and fgraph.shape_feature
     if not hasattr(fgraph, "shape_feature"):
         return
 
     shape_of = fgraph.shape_feature.shape_of
 
-    if isinstance(node.op, Subtensor):
-        cdata = get_constant_idx(
-            node.op.idx_list,
-            node.inputs,
-            allow_partial=True,
-            only_process_constants=True,
-        )
-        for pos, idx in enumerate(cdata):
-            if not isinstance(idx, slice):
-                # If idx is not a slice, this means we remove this dimension
-                # from the output, so the subtensor is not useless
-                return False
-            if idx.start is not None and idx.start != 0:
-                # If the start of the slice is different from 0, or is a
-                # variable, then we assume the subtensor is not useless
-                return False
-            if idx.step is not None and idx.step != 1:
-                # If we are going backwards, or skipping elements, then this
-                # is not a useless subtensor
-                return False
-        for pos, idx in enumerate(cdata):
-            length_pos = shape_of[node.inputs[0]][pos]
-            if isinstance(idx.stop, (int, np.integer)):
-                length_pos_data = sys.maxsize
-                try:
-                    length_pos_data = get_scalar_constant_value(
-                        length_pos, only_process_constants=True
-                    )
-                except NotScalarConstantError:
-                    pass
-                if idx.stop < length_pos_data:
-                    return False
-            elif isinstance(idx.stop, Variable):
-                length_pos_shape_i = idx.stop
-                # length_pos is a tensor variable, but length_pos_shape_i
-                # is a scalar variable. We try to see if they represent
-                # the same underlying variable.
-                if length_pos_shape_i.owner and isinstance(
-                    length_pos_shape_i.owner.op, ScalarFromTensor
-                ):
-                    length_pos_shape_i = length_pos_shape_i.owner.inputs[0]
-                elif length_pos.owner and isinstance(
-                    length_pos.owner.op, TensorFromScalar
-                ):
-                    length_pos = length_pos.owner.inputs[0]
-                else:
-                    # We did not find underlying variables of the same type
-                    return False
-                # The type can be different: int32 vs int64. length_pos
-                # should always be int64 as that is what the shape
-                # tracker keep. Subtensor accept any scalar int{8,16,32,64}
-                # as index type.
-                assert str(length_pos.type.dtype) == "int64"
-                assert str(length_pos_shape_i.type.dtype) in [
-                    "int8",
-                    "int16",
-                    "int32",
-                    "int64",
-                ]
-                # length_pos_shape_i cannot be None
-                if length_pos_shape_i != length_pos:
-                    return False
-            elif idx.stop is None:
-                pass
-            else:
-                return False
-    elif isinstance(node.op, AdvancedSubtensor1):
-        # get length of the indexed tensor along the first axis
-        try:
-            length = get_scalar_constant_value(
-                shape_of[node.inputs[0]][0], only_process_constants=True
-            )
-        except NotScalarConstantError:
-            return False
-        # get index (which must be a vector by definition)
-        idx = node.inputs[1]
-        # `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for
-        # this optimization
-        if isinstance(idx, Constant):
-            idx = idx.value
-            if len(idx) != length:
-                return False
-            if np.any(idx != np.arange(length)):
-                return False
-        elif idx.owner is not None and isinstance(idx.owner.op, ARange):
-            try:
-                start, stop, step = map(
-                    lambda x: get_scalar_constant_value(x, only_process_constants=True),
-                    idx.owner.inputs,
-                )
-            except NotScalarConstantError:
-                return False
-            if start != 0:
-                return False
-            if stop != length:
-                return False
-            if step != 1:
-                return False
-        else:
-            return False
+    cdata = get_constant_idx(
+        node.op.idx_list,
+        node.inputs,
+        allow_partial=True,
+        only_process_constants=True,
+    )
+    for pos, idx in enumerate(cdata):
+        if not isinstance(idx, slice):
+            # If idx is not a slice, this means we remove this dimension
+            # from the output, so the subtensor is not useless
+            return False
+        if idx.start is not None and idx.start != 0:
+            # If the start of the slice is different from 0, or is a
+            # variable, then we assume the subtensor is not useless
+            return False
+        if idx.step is not None and idx.step != 1:
+            # If we are going backwards, or skipping elements, then this
+            # is not a useless subtensor
+            return False
+
+    for pos, idx in enumerate(cdata):
+        length_pos = shape_of[node.inputs[0]][pos]
+
+        if isinstance(idx.stop, (int, np.integer)):
+            length_pos_data = sys.maxsize
+            try:
+                length_pos_data = get_scalar_constant_value(
+                    length_pos, only_process_constants=True
+                )
+            except NotScalarConstantError:
+                pass
+
+            if idx.stop < length_pos_data:
+                return False
+        elif isinstance(idx.stop, Variable):
+            length_pos_shape_i = idx.stop
+            # length_pos is a tensor variable, but length_pos_shape_i
+            # is a scalar variable. We try to see if they represent
+            # the same underlying variable.
+            if length_pos_shape_i.owner and isinstance(
+                length_pos_shape_i.owner.op, ScalarFromTensor
+            ):
+                length_pos_shape_i = length_pos_shape_i.owner.inputs[0]
+            elif length_pos.owner and isinstance(length_pos.owner.op, TensorFromScalar):
+                length_pos = length_pos.owner.inputs[0]
+            else:
+                # We did not find underlying variables of the same type
+                return False
+
+            # The type can be different: int32 vs int64. length_pos
+            # should always be int64 as that is what the shape
+            # tracker keep. Subtensor accept any scalar int{8,16,32,64}
+            # as index type.
+            assert str(length_pos.type.dtype) == "int64"
+            assert str(length_pos_shape_i.type.dtype) in [
+                "int8",
+                "int16",
+                "int32",
+                "int64",
+            ]
+
+            # length_pos_shape_i cannot be None
+            if length_pos_shape_i != length_pos:
+                return False
+        elif idx.stop is None:
+            continue
+        else:
+            return False
+
+    return [node.inputs[0]]
+
+
+@register_canonicalize
+@register_specialize
+@local_optimizer([AdvancedSubtensor1])
+def local_useless_AdvancedSubtensor1(fgraph, node):
+    """Remove `AdvancedSubtensor1` if it takes the full input.
+
+    In the `AdvancedSubtensor1` case, the full input is taken when the indices
+    are equivalent to ``arange(0, input.shape[0], 1)`` using either an explicit
+    list/vector or the `ARange` `Op`.
+
+    """
+    # This optimization needs ShapeOpt and fgraph.shape_feature
+    if not hasattr(fgraph, "shape_feature"):
+        return
+
+    shape_of = fgraph.shape_feature.shape_of
+
+    # get length of the indexed tensor along the first axis
+    try:
+        length = get_scalar_constant_value(
+            shape_of[node.inputs[0]][0], only_process_constants=True
+        )
+    except NotScalarConstantError:
+        return False
+
+    # get index (which must be a vector by definition)
+    idx = node.inputs[1]
+
+    # `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for
+    # this optimization
+    if isinstance(idx, Constant):
+        idx = idx.value
+        if len(idx) != length:
+            return False
+        if np.any(idx != np.arange(length)):
+            return False
+    elif idx.owner is not None and isinstance(idx.owner.op, ARange):
+        try:
+            start, stop, step = map(
+                lambda x: get_scalar_constant_value(x, only_process_constants=True),
+                idx.owner.inputs,
+            )
+        except NotScalarConstantError:
+            return False
+
+        if start != 0:
+            return False
+        if stop != length:
+            return False
+        if step != 1:
+            return False
     else:
         return False
 
     # We don't need to copy over any stacktrace here,
     # because previous stacktrace should suffice.
     return [node.inputs[0]]
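As a minimal end-to-end sketch of the `Subtensor` case (assuming the current `aesara` package layout; under the older Theano-PyMC naming, substitute `theano` for `aesara`; the variable names are illustrative, not from this commit):

import aesara
import aesara.tensor as at

x = at.matrix("x")
y = x[0 : x.shape[0]]  # the slice spans the whole first axis

# With the default optimizations (canonicalize and specialize are included),
# the rewrite should replace the indexing result with `x` itself.
f = aesara.function([x], y)
aesara.dprint(f)  # expected: no Subtensor node in the compiled graph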
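The index checks in the new `local_useless_AdvancedSubtensor1` can also be mirrored in plain NumPy. This sketch restates the two branches with hypothetical helper names (they are not part of the library):

import numpy as np

def full_take_constant(indices, length):
    # Constant branch: indices must equal [0, 1, ..., length - 1] exactly
    indices = np.asarray(indices)
    return len(indices) == length and not np.any(indices != np.arange(length))

def full_take_arange(start, stop, step, length):
    # ARange branch: arange(start, stop, step) covers the whole axis
    # only when start == 0, stop == length, and step == 1
    return start == 0 and stop == length and step == 1

print(full_take_constant([0, 1, 2], 3))  # True: x[[0, 1, 2]] takes all of x
print(full_take_constant([0, 2, 1], 3))  # False: a reordering is not useless
print(full_take_arange(0, 3, 1, 3))      # True: x[arange(0, 3, 1)] takes all of x
print(full_take_arange(0, 3, 2, 3))      # False: step 2 skips elements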