提交 49f83bc0 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Avoid canonicalization of slices when merging non-overlapping slices in `local_subtensor_merge`

上级 eba1befd
......@@ -370,9 +370,10 @@ def local_subtensor_merge(fgraph, node):
"""
from pytensor.scan.op import Scan
if isinstance(node.op, Subtensor):
u = node.inputs[0]
if u.owner and isinstance(u.owner.op, Subtensor):
if not (u.owner is not None and isinstance(u.owner.op, Subtensor)):
return None
# We can merge :)
# x actual tensor on which we are picking slices
x = u.owner.inputs[0]
......@@ -425,9 +426,7 @@ def local_subtensor_merge(fgraph, node):
merged_slices = tuple(as_index_constant(s) for s in merged_slices)
subtens = Subtensor(merged_slices)
sl_ins = get_slice_elements(
merged_slices, lambda x: isinstance(x, Variable)
)
sl_ins = get_slice_elements(merged_slices, lambda x: isinstance(x, Variable))
# Do not call make_node for test_value
out = subtens(x, *sl_ins)
......@@ -826,6 +825,12 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
if not isinstance(slice1, slice):
raise ValueError("slice1 should be of type `slice`")
# Simple case where one of the slices is useless
if is_full_slice(slice1):
return slice2
elif is_full_slice(slice2):
return slice1
sl1, reverse1 = get_canonical_form_slice(slice1, len1)
sl2, reverse2 = get_canonical_form_slice(slice2, len2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论