提交 7b7ab9e7 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Handle no-op Subtensors in rewrites

上级 0088d03b
......@@ -336,35 +336,46 @@ def local_subtensor_of_dot(fgraph, node):
@node_rewriter([Subtensor])
def local_useless_slice(fgraph, node):
"""
Remove Subtensor of the form X[0, :] -> X[0]
Remove Subtensor of the form:
1. X[0, :] -> X[0]
2. X[:] -> X
"""
if isinstance(node.op, Subtensor):
slices = get_idx_list(node.inputs, node.op.idx_list)
last_slice = len(slices)
for s in slices[::-1]:
# check if slice and then check slice indices
if (
isinstance(s, slice)
and s.start is None
and s.stop is None
and (
s.step is None
or extract_constant(s.step, only_process_constants=True) == 1
)
):
last_slice -= 1
else:
break
# check if we removed something
if last_slice < len(slices):
subtens = Subtensor(slices[:last_slice])
sl_ins = get_slice_elements(
slices[:last_slice], lambda x: isinstance(x, Variable)
idxs = get_idx_list(node.inputs, node.op.idx_list)
if not idxs:
return [node.inputs[0]]
last_useless_slice = len(idxs)
for s in idxs[::-1]:
# check if slice and then check slice indices
if (
isinstance(s, slice)
and s.start is None
and s.stop is None
and (
s.step is None
or extract_constant(s.step, only_process_constants=True) == 1
)
):
last_useless_slice -= 1
else:
break
# check if we removed something
if last_useless_slice < len(idxs):
new_idxs = idxs[:last_useless_slice]
if new_idxs:
new_subtensor = Subtensor(new_idxs)
new_subtensor_inputs = get_slice_elements(
new_idxs, lambda x: isinstance(x, Variable)
)
out = subtens(node.inputs[0], *sl_ins)
out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
# Copy over previous output stacktrace
copy_stack_trace(node.outputs, out)
return [out]
else:
# Subtensor is not needed at all
return [node.inputs[0]]
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
......@@ -747,7 +758,13 @@ def local_subtensor_make_vector(fgraph, node):
make_vector_op = x.owner.op
if isinstance(node.op, Subtensor):
(idx,) = node.op.idx_list
idxs = node.op.idx_list
# Subtensor has no indexes, return make_vector
if not idxs:
return [x]
(idx,) = idxs
if isinstance(idx, (aes.ScalarType, TensorType)):
old_idx, idx = idx, node.inputs[1]
......@@ -903,7 +920,11 @@ def local_set_to_inc_subtensor(fgraph, node):
@node_rewriter([Subtensor])
def local_useless_subtensor(fgraph, node):
"""Remove `Subtensor` if it takes the full input."""
# This optimization needs ShapeOpt and fgraph.shape_feature
if not node.op.idx_list:
return [node.inputs[0]]
# The more elaborate optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(fgraph, "shape_feature"):
return
......
......@@ -9,6 +9,7 @@ from pytensor.compile.function import function
from pytensor.compile.mode import Mode, get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Constant, Variable, ancestors
from pytensor.graph.rewriting.basic import check_stack_trace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
......@@ -21,6 +22,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, add, dot, exp, sqr
from pytensor.tensor.rewriting.subtensor import (
local_replace_AdvancedSubtensor,
local_subtensor_make_vector,
local_subtensor_shape_constant,
)
from pytensor.tensor.shape import (
......@@ -764,6 +766,17 @@ class TestLocalSubtensorMakeVector:
f = function([x, y, z], v_subtensor, mode=mode)
assert check_stack_trace(f, ops_to_check="all")
def test_empty_subtensor(self):
x, y = lscalars("xy")
v = make_vector(x, y)
out = v[()]
fgraph = FunctionGraph(outputs=[out], clone=False)
node = fgraph.outputs[0].owner
assert isinstance(node.op, Subtensor)
assert local_subtensor_make_vector.transform(fgraph, node) == [v]
class TestLocalSubtensorLift:
def test_basic(self):
......
......@@ -389,7 +389,8 @@ class TestSubtensor(utt.OptimizationTestMixin):
t = Subtensor([])(n)
assert isinstance(t.owner.op, Subtensor)
self.eval_output_and_check(
t, mode=self.mode.excluding("local_useless_subtensor")
t,
mode=self.mode.excluding("local_useless_subtensor", "local_useless_slice"),
)
def test_err_invalid_2(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论