提交 49f83bc0 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Avoid canonicalization of slices when merging non-overlapping slices in `local_subtensor_merge`

上级 eba1befd
......@@ -370,74 +370,73 @@ def local_subtensor_merge(fgraph, node):
"""
from pytensor.scan.op import Scan
if isinstance(node.op, Subtensor):
u = node.inputs[0]
if u.owner and isinstance(u.owner.op, Subtensor):
# We can merge :)
# x actual tensor on which we are picking slices
x = u.owner.inputs[0]
# slices of the first applied subtensor
slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
slices2 = get_idx_list(node.inputs, node.op.idx_list)
# Don't try to do the optimization on do-while scan outputs,
# as it will create a dependency on the shape of the outputs
if (
x.owner is not None
and isinstance(x.owner.op, Scan)
and x.owner.op.info.as_while
):
return None
u = node.inputs[0]
if not (u.owner is not None and isinstance(u.owner.op, Subtensor)):
return None
# Get the shapes of the vectors !
try:
# try not to introduce new shape into the graph
xshape = fgraph.shape_feature.shape_of[x]
ushape = fgraph.shape_feature.shape_of[u]
except AttributeError:
# Following the suggested use of shape_feature which should
# consider the case when the compilation mode doesn't
# include the ShapeFeature
xshape = x.shape
ushape = u.shape
merged_slices = []
pos_2 = 0
pos_1 = 0
while (pos_1 < len(slices1)) and (pos_2 < len(slices2)):
slice1 = slices1[pos_1]
if isinstance(slice1, slice):
merged_slices.append(
merge_two_slices(
fgraph, slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2]
)
)
pos_2 += 1
else:
merged_slices.append(slice1)
pos_1 += 1
if pos_2 < len(slices2):
merged_slices += slices2[pos_2:]
else:
merged_slices += slices1[pos_1:]
# We can merge :)
# x actual tensor on which we are picking slices
x = u.owner.inputs[0]
# slices of the first applied subtensor
slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
slices2 = get_idx_list(node.inputs, node.op.idx_list)
merged_slices = tuple(as_index_constant(s) for s in merged_slices)
subtens = Subtensor(merged_slices)
# Don't try to do the optimization on do-while scan outputs,
# as it will create a dependency on the shape of the outputs
if (
x.owner is not None
and isinstance(x.owner.op, Scan)
and x.owner.op.info.as_while
):
return None
sl_ins = get_slice_elements(
merged_slices, lambda x: isinstance(x, Variable)
# Get the shapes of the vectors !
try:
# try not to introduce new shape into the graph
xshape = fgraph.shape_feature.shape_of[x]
ushape = fgraph.shape_feature.shape_of[u]
except AttributeError:
# Following the suggested use of shape_feature which should
# consider the case when the compilation mode doesn't
# include the ShapeFeature
xshape = x.shape
ushape = u.shape
merged_slices = []
pos_2 = 0
pos_1 = 0
while (pos_1 < len(slices1)) and (pos_2 < len(slices2)):
slice1 = slices1[pos_1]
if isinstance(slice1, slice):
merged_slices.append(
merge_two_slices(
fgraph, slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2]
)
)
# Do not call make_node for test_value
out = subtens(x, *sl_ins)
pos_2 += 1
else:
merged_slices.append(slice1)
pos_1 += 1
# Copy over previous output stacktrace
# and stacktrace from previous slicing operation.
# Why? Because, the merged slicing operation could have failed
# because of either of the two original slicing operations
orig_out = node.outputs[0]
copy_stack_trace([orig_out, node.inputs[0]], out)
return [out]
if pos_2 < len(slices2):
merged_slices += slices2[pos_2:]
else:
merged_slices += slices1[pos_1:]
merged_slices = tuple(as_index_constant(s) for s in merged_slices)
subtens = Subtensor(merged_slices)
sl_ins = get_slice_elements(merged_slices, lambda x: isinstance(x, Variable))
# Do not call make_node for test_value
out = subtens(x, *sl_ins)
# Copy over previous output stacktrace
# and stacktrace from previous slicing operation.
# Why? Because, the merged slicing operation could have failed
# because of either of the two original slicing operations
orig_out = node.outputs[0]
copy_stack_trace([orig_out, node.inputs[0]], out)
return [out]
@register_specialize
......@@ -826,6 +825,12 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
if not isinstance(slice1, slice):
raise ValueError("slice1 should be of type `slice`")
# Simple case where one of the slices is useless
if is_full_slice(slice1):
return slice2
elif is_full_slice(slice2):
return slice1
sl1, reverse1 = get_canonical_form_slice(slice1, len1)
sl2, reverse2 = get_canonical_form_slice(slice2, len2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论