Unverified commit fa0ab9de, authored by Dhruvanshu-Joshi, committed via GitHub

Canonicalize Subtensor slices (#761)

Parent commit: 117f80da
...@@ -337,6 +337,7 @@ def local_subtensor_of_dot(fgraph, node): ...@@ -337,6 +337,7 @@ def local_subtensor_of_dot(fgraph, node):
@register_useless
@register_canonicalize
@register_specialize
@register_stabilize
@node_rewriter([Subtensor])
def local_useless_slice(fgraph, node):
    """Remove redundant slices from a ``Subtensor``.

    Simplifies expressions such as:
        1. ``X[0, :]``   -> ``X[0]``   (trailing full slices are dropped)
        2. ``X[:]``      -> ``X``      (a no-op Subtensor is removed entirely)

    Also canonicalizes slice components that are provably redundant:
        ``X[0:7:1]`` -> ``X[None:None:None]`` when ``X`` has static length 7,
    i.e. a constant start of 0, a constant stop equal to the dimension's
    static shape, and a constant step of 1 are each replaced by ``None``.

    Parameters
    ----------
    fgraph : FunctionGraph
        The graph being rewritten (unused here, required by the rewriter API).
    node : Apply
        A ``Subtensor`` application to inspect.

    Returns
    -------
    list of Variable or None
        A single-element list with the simplified output, or ``None`` when
        nothing could be simplified.
    """
    idxs = get_idx_list(node.inputs, node.op.idx_list)
    x = node.inputs[0]

    # A Subtensor with no indices at all is the identity.
    if not idxs:
        return [node.inputs[0]]

    new_idxs = list(idxs)
    change_flag = False
    # Index of the last entry that must be kept; everything after it is a
    # full slice and can be dropped from the new indexing expression.
    last_useful_idx = -1
    for dim, s in enumerate(new_idxs):
        if not isinstance(s, slice):
            # Integer (or other non-slice) indices are always meaningful.
            last_useful_idx = dim
            continue

        if s == slice(None):
            # Already canonical full slice; may still be droppable if trailing.
            continue

        start = s.start
        stop = s.stop
        step = s.step
        # A constant start of 0 is the default.
        if (
            start is not None
            and extract_constant(start, only_process_constants=True) == 0
        ):
            change_flag = True
            start = None

        # A constant stop equal to the dimension's static length is the
        # default; only provable when the static shape is known.
        if (
            stop is not None
            and x.type.shape[dim] is not None
            and extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
        ):
            change_flag = True
            stop = None

        # A constant step of 1 is the default.
        if (
            step is not None
            and extract_constant(step, only_process_constants=True) == 1
        ):
            change_flag = True
            step = None

        if not (start is None and stop is None and step is None):
            last_useful_idx = dim

        new_idxs[dim] = slice(start, stop, step)

    # Rebuild only if a slice component was canonicalized or a trailing
    # run of full slices can be dropped.
    if change_flag or ((last_useful_idx + 1) < len(idxs)):
        out = x[tuple(new_idxs[: last_useful_idx + 1])]
        # Copy over previous output stacktrace
        copy_stack_trace(node.outputs, out)

        return [out]
# fast_compile to allow opt subtensor(cast{float32}(make_vector)) # fast_compile to allow opt subtensor(cast{float32}(make_vector))
......
...@@ -10,7 +10,7 @@ from pytensor.compile.mode import Mode, get_default_mode, get_mode ...@@ -10,7 +10,7 @@ from pytensor.compile.mode import Mode, get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph, vectorize_graph from pytensor.graph import FunctionGraph, vectorize_graph
from pytensor.graph.basic import Constant, Variable, ancestors from pytensor.graph.basic import Constant, Variable, ancestors, equal_computations
from pytensor.graph.rewriting.basic import check_stack_trace from pytensor.graph.rewriting.basic import check_stack_trace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.graph.rewriting.utils import rewrite_graph
...@@ -2402,3 +2402,44 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc): ...@@ -2402,3 +2402,44 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
else: else:
expected_out[:, :, core_idxs] += test_y expected_out[:, :, core_idxs] += test_y
np.testing.assert_allclose(fn(test_x, test_y), expected_out) np.testing.assert_allclose(fn(test_x, test_y), expected_out)
def test_slice_canonicalize():
    rng = np.random.default_rng(43)
    # Dim 2 has unknown static length, so a slice on it cannot be proven
    # full and must survive canonicalization.
    x = tensor(shape=(3, 5, None, 9))
    test_x = rng.normal(size=(3, 5, 8, 9))

    # Case 1: constant start 0, stop equal to the static length, and unit
    # step are all defaults; trailing full slices get dropped entirely.
    y = x[0:None, 0:5, 0:7, 0:9:1]
    f = pytensor.function([x], y, allow_input_downcast=True)

    # The rewritten graph ends in a DeepCopy; inspect what feeds it.
    root = f.maker.fgraph.outputs[0].owner
    assert isinstance(root.op, DeepCopyOp)
    expected_y = x[None:None:None, None:None:None, None:7:None]
    assert equal_computations([root.inputs[0]], [expected_y])

    # Numerical result must match the unoptimized slice.
    np.testing.assert_allclose(f(test_x), test_x[0:None, 0:5, 0:7, 0:9:1])

    # Case 2: negative bounds and steps are not provably redundant and
    # must be preserved.
    y1 = x[0:-1, 0:5, 0:7, 0:-1:-1]
    f1 = pytensor.function([x], y1, allow_input_downcast=True)

    root1 = f1.maker.fgraph.outputs[0].owner
    assert isinstance(root1.op, DeepCopyOp)
    expected_y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]
    assert equal_computations([root1.inputs[0]], [expected_y1])

    np.testing.assert_allclose(f1(test_x), test_x[0:-1, 0:5, 0:7, 0:-1:-1])
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment