提交 49f83bc0 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Avoid canonicalization of slices when merging non-overlapping slices in `local_subtensor_merge`

上级 eba1befd
...@@ -370,74 +370,73 @@ def local_subtensor_merge(fgraph, node): ...@@ -370,74 +370,73 @@ def local_subtensor_merge(fgraph, node):
""" """
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
if isinstance(node.op, Subtensor): u = node.inputs[0]
u = node.inputs[0] if not (u.owner is not None and isinstance(u.owner.op, Subtensor)):
if u.owner and isinstance(u.owner.op, Subtensor): return None
# We can merge :)
# x actual tensor on which we are picking slices
x = u.owner.inputs[0]
# slices of the first applied subtensor
slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
slices2 = get_idx_list(node.inputs, node.op.idx_list)
# Don't try to do the optimization on do-while scan outputs,
# as it will create a dependency on the shape of the outputs
if (
x.owner is not None
and isinstance(x.owner.op, Scan)
and x.owner.op.info.as_while
):
return None
# Get the shapes of the vectors ! # We can merge :)
try: # x actual tensor on which we are picking slices
# try not to introduce new shape into the graph x = u.owner.inputs[0]
xshape = fgraph.shape_feature.shape_of[x] # slices of the first applied subtensor
ushape = fgraph.shape_feature.shape_of[u] slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
except AttributeError: slices2 = get_idx_list(node.inputs, node.op.idx_list)
# Following the suggested use of shape_feature which should
# consider the case when the compilation mode doesn't
# include the ShapeFeature
xshape = x.shape
ushape = u.shape
merged_slices = []
pos_2 = 0
pos_1 = 0
while (pos_1 < len(slices1)) and (pos_2 < len(slices2)):
slice1 = slices1[pos_1]
if isinstance(slice1, slice):
merged_slices.append(
merge_two_slices(
fgraph, slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2]
)
)
pos_2 += 1
else:
merged_slices.append(slice1)
pos_1 += 1
if pos_2 < len(slices2):
merged_slices += slices2[pos_2:]
else:
merged_slices += slices1[pos_1:]
merged_slices = tuple(as_index_constant(s) for s in merged_slices) # Don't try to do the optimization on do-while scan outputs,
subtens = Subtensor(merged_slices) # as it will create a dependency on the shape of the outputs
if (
x.owner is not None
and isinstance(x.owner.op, Scan)
and x.owner.op.info.as_while
):
return None
sl_ins = get_slice_elements( # Get the shapes of the vectors !
merged_slices, lambda x: isinstance(x, Variable) try:
# try not to introduce new shape into the graph
xshape = fgraph.shape_feature.shape_of[x]
ushape = fgraph.shape_feature.shape_of[u]
except AttributeError:
# Following the suggested use of shape_feature which should
# consider the case when the compilation mode doesn't
# include the ShapeFeature
xshape = x.shape
ushape = u.shape
merged_slices = []
pos_2 = 0
pos_1 = 0
while (pos_1 < len(slices1)) and (pos_2 < len(slices2)):
slice1 = slices1[pos_1]
if isinstance(slice1, slice):
merged_slices.append(
merge_two_slices(
fgraph, slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2]
)
) )
# Do not call make_node for test_value pos_2 += 1
out = subtens(x, *sl_ins) else:
merged_slices.append(slice1)
pos_1 += 1
# Copy over previous output stacktrace if pos_2 < len(slices2):
# and stacktrace from previous slicing operation. merged_slices += slices2[pos_2:]
# Why? Because, the merged slicing operation could have failed else:
# because of either of the two original slicing operations merged_slices += slices1[pos_1:]
orig_out = node.outputs[0]
copy_stack_trace([orig_out, node.inputs[0]], out) merged_slices = tuple(as_index_constant(s) for s in merged_slices)
return [out] subtens = Subtensor(merged_slices)
sl_ins = get_slice_elements(merged_slices, lambda x: isinstance(x, Variable))
# Do not call make_node for test_value
out = subtens(x, *sl_ins)
# Copy over previous output stacktrace
# and stacktrace from previous slicing operation.
# Why? Because, the merged slicing operation could have failed
# because of either of the two original slicing operations
orig_out = node.outputs[0]
copy_stack_trace([orig_out, node.inputs[0]], out)
return [out]
@register_specialize @register_specialize
...@@ -826,6 +825,12 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2): ...@@ -826,6 +825,12 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
if not isinstance(slice1, slice): if not isinstance(slice1, slice):
raise ValueError("slice1 should be of type `slice`") raise ValueError("slice1 should be of type `slice`")
# Simple case where one of the slices is useless
if is_full_slice(slice1):
return slice2
elif is_full_slice(slice2):
return slice1
sl1, reverse1 = get_canonical_form_slice(slice1, len1) sl1, reverse1 = get_canonical_form_slice(slice1, len1)
sl2, reverse2 = get_canonical_form_slice(slice2, len2) sl2, reverse2 = get_canonical_form_slice(slice2, len2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论