提交 7b7ab9e7 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Handle no-op Subtensors in rewrites

上级 0088d03b
...@@ -336,35 +336,46 @@ def local_subtensor_of_dot(fgraph, node): ...@@ -336,35 +336,46 @@ def local_subtensor_of_dot(fgraph, node):
@node_rewriter([Subtensor]) @node_rewriter([Subtensor])
def local_useless_slice(fgraph, node): def local_useless_slice(fgraph, node):
""" """
Remove Subtensor of the form X[0, :] -> X[0] Remove Subtensor of the form:
1. X[0, :] -> X[0]
2. X[:] -> X
""" """
if isinstance(node.op, Subtensor): idxs = get_idx_list(node.inputs, node.op.idx_list)
slices = get_idx_list(node.inputs, node.op.idx_list)
last_slice = len(slices) if not idxs:
for s in slices[::-1]: return [node.inputs[0]]
# check if slice and then check slice indices
if ( last_useless_slice = len(idxs)
isinstance(s, slice) for s in idxs[::-1]:
and s.start is None # check if slice and then check slice indices
and s.stop is None if (
and ( isinstance(s, slice)
s.step is None and s.start is None
or extract_constant(s.step, only_process_constants=True) == 1 and s.stop is None
) and (
): s.step is None
last_slice -= 1 or extract_constant(s.step, only_process_constants=True) == 1
else: )
break ):
# check if we removed something last_useless_slice -= 1
if last_slice < len(slices): else:
subtens = Subtensor(slices[:last_slice]) break
sl_ins = get_slice_elements( # check if we removed something
slices[:last_slice], lambda x: isinstance(x, Variable) if last_useless_slice < len(idxs):
new_idxs = idxs[:last_useless_slice]
if new_idxs:
new_subtensor = Subtensor(new_idxs)
new_subtensor_inputs = get_slice_elements(
new_idxs, lambda x: isinstance(x, Variable)
) )
out = subtens(node.inputs[0], *sl_ins) out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
# Copy over previous output stacktrace # Copy over previous output stacktrace
copy_stack_trace(node.outputs, out) copy_stack_trace(node.outputs, out)
return [out] return [out]
else:
# Subtensor is not needed at all
return [node.inputs[0]]
# fast_compile to allow opt subtensor(cast{float32}(make_vector)) # fast_compile to allow opt subtensor(cast{float32}(make_vector))
...@@ -747,7 +758,13 @@ def local_subtensor_make_vector(fgraph, node): ...@@ -747,7 +758,13 @@ def local_subtensor_make_vector(fgraph, node):
make_vector_op = x.owner.op make_vector_op = x.owner.op
if isinstance(node.op, Subtensor): if isinstance(node.op, Subtensor):
(idx,) = node.op.idx_list idxs = node.op.idx_list
# Subtensor has no indexes, return make_vector
if not idxs:
return [x]
(idx,) = idxs
if isinstance(idx, (aes.ScalarType, TensorType)): if isinstance(idx, (aes.ScalarType, TensorType)):
old_idx, idx = idx, node.inputs[1] old_idx, idx = idx, node.inputs[1]
...@@ -903,7 +920,11 @@ def local_set_to_inc_subtensor(fgraph, node): ...@@ -903,7 +920,11 @@ def local_set_to_inc_subtensor(fgraph, node):
@node_rewriter([Subtensor]) @node_rewriter([Subtensor])
def local_useless_subtensor(fgraph, node): def local_useless_subtensor(fgraph, node):
"""Remove `Subtensor` if it takes the full input.""" """Remove `Subtensor` if it takes the full input."""
# This optimization needs ShapeOpt and fgraph.shape_feature
if not node.op.idx_list:
return [node.inputs[0]]
# The more elaborate optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(fgraph, "shape_feature"): if not hasattr(fgraph, "shape_feature"):
return return
......
...@@ -9,6 +9,7 @@ from pytensor.compile.function import function ...@@ -9,6 +9,7 @@ from pytensor.compile.function import function
from pytensor.compile.mode import Mode, get_default_mode, get_mode from pytensor.compile.mode import Mode, get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Constant, Variable, ancestors from pytensor.graph.basic import Constant, Variable, ancestors
from pytensor.graph.rewriting.basic import check_stack_trace from pytensor.graph.rewriting.basic import check_stack_trace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
...@@ -21,6 +22,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise ...@@ -21,6 +22,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, add, dot, exp, sqr from pytensor.tensor.math import Dot, add, dot, exp, sqr
from pytensor.tensor.rewriting.subtensor import ( from pytensor.tensor.rewriting.subtensor import (
local_replace_AdvancedSubtensor, local_replace_AdvancedSubtensor,
local_subtensor_make_vector,
local_subtensor_shape_constant, local_subtensor_shape_constant,
) )
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
...@@ -764,6 +766,17 @@ class TestLocalSubtensorMakeVector: ...@@ -764,6 +766,17 @@ class TestLocalSubtensorMakeVector:
f = function([x, y, z], v_subtensor, mode=mode) f = function([x, y, z], v_subtensor, mode=mode)
assert check_stack_trace(f, ops_to_check="all") assert check_stack_trace(f, ops_to_check="all")
def test_empty_subtensor(self):
x, y = lscalars("xy")
v = make_vector(x, y)
out = v[()]
fgraph = FunctionGraph(outputs=[out], clone=False)
node = fgraph.outputs[0].owner
assert isinstance(node.op, Subtensor)
assert local_subtensor_make_vector.transform(fgraph, node) == [v]
class TestLocalSubtensorLift: class TestLocalSubtensorLift:
def test_basic(self): def test_basic(self):
......
...@@ -389,7 +389,8 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -389,7 +389,8 @@ class TestSubtensor(utt.OptimizationTestMixin):
t = Subtensor([])(n) t = Subtensor([])(n)
assert isinstance(t.owner.op, Subtensor) assert isinstance(t.owner.op, Subtensor)
self.eval_output_and_check( self.eval_output_and_check(
t, mode=self.mode.excluding("local_useless_subtensor") t,
mode=self.mode.excluding("local_useless_subtensor", "local_useless_slice"),
) )
def test_err_invalid_2(self): def test_err_invalid_2(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论