Commit c55a97bb authored by ricardoV94, committed by Ricardo Vieira

Make sure `infer_shape` can handle xtensor lowering

Parent 22663d9f
......@@ -605,9 +605,9 @@ class ShapeFeature(Feature):
# 2) we are putting things back after a failed transaction.
# In case 1, if r has a shape_i client, we will want to
# replace the shape_i of r with the shape of new_r. Say that
# r is *scheduled*.
# replace the shape_i of r with the shape of new_r. Say that r is *scheduled*.
# At that point, node is no longer a client of r, but of new_r
# This schedule is processed by `local_track_shape_i`.
for shpnode, idx in fgraph.clients[r] + [(node, i)]:
if isinstance(shpnode.op, Shape_i):
idx = shpnode.op.i
......@@ -1271,13 +1271,39 @@ def local_shape_to_shape_i(fgraph, node):
return [ret]
@register_infer_shape
@register_specialize
@register_canonicalize
@node_rewriter([Shape_i])
def local_track_shape_i(fgraph, node):
if not isinstance(node.op, Shape_i):
return False
"""
Update `Shape_i` nodes to match `ShapeFeature`'s internal state.
This rewrite is essential for propagating shape information during graph
transformations (like lowering). When a node is replaced or updated,
`ShapeFeature` calculates the shape of the new node and "schedules"
dependent `Shape_i` nodes for update, so they use the latest inferred graph.
If we start with an fgraph containing the two nodes below:
>> out = OpWithoutInferShape(a, b)
>> out_shape_i = Shape_i(out)
And then rewrite
>> new_out = OpWithInferShape(a, b)
>> fgraph.replace(out, new_out)
We end up with
>> out_shape_i == Shape_i(new_out)
If installed, ShapeFeature will do this work in the background
>> new_out_shape = infer_shape(new_out) # Usually some f(a, b)
>> fgraph.shape_feature.scheduled[out_shape_i.owner] = new_out_shape
And this rewrite will ultimately propagate the inference back to the fgraph
>> new_out_shape_i = fgraph.shape_feature.scheduled[out_shape_i.owner][i]
>> fgraph.replace(out_shape_i, new_out_shape_i)
"""
try:
shape_feature = fgraph.shape_feature
except AttributeError:
......
......@@ -4,6 +4,7 @@ from collections.abc import Sequence
from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import NodeRewriter, dfs_rewriter
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
from pytensor.tensor.basic import infer_shape_db
from pytensor.tensor.rewriting.ofg import inline_ofg_expansion
from pytensor.tensor.variable import TensorVariable
from pytensor.xtensor.type import XTensorVariable
......@@ -11,6 +12,12 @@ from pytensor.xtensor.type import XTensorVariable
lower_xtensor_db = EquilibriumDB(ignore_newtrees=False)
infer_shape_db.register(
"lower_xtensor",
lower_xtensor_db,
"infer_shape",
)
optdb.register(
"lower_xtensor",
lower_xtensor_db,
......@@ -50,6 +57,7 @@ def register_lower_xtensor(
"fast_run",
"fast_compile",
"minimum_compile",
"infer_shape",
*tags,
**kwargs,
)
......
from pytensor.graph import FunctionGraph
from pytensor.tensor.basic import infer_shape_db
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.shape import Shape_i
from pytensor.xtensor import xtensor
from tests.unittest_tools import assert_equal_computations
def test_infer_shape_db_handles_xtensor_lowering():
    """Check that the infer_shape rewrite query can lower xtensor graphs.

    The shape of a reduced xtensor should be recoverable both with and
    without a ``ShapeFeature`` attached to the ``FunctionGraph``.
    """
    # Graph: reduce over dim "a"; the requested shape entry corresponds to
    # the surviving dim "b" (axis 1 of the underlying tensor).
    x = xtensor("x", dims=("a", "b"))
    reduced = x.sum(dim="a")
    shape_graph = reduced.shape[0]

    # Case 1: no ShapeFeature — the query must still lower the xtensor op,
    # leaving the shape computed on the lowered tensor graph.
    fgraph = FunctionGraph([x], [shape_graph], features=[], copy_inputs=False)
    infer_shape_db.default_query.rewrite(fgraph)
    [result] = fgraph.outputs
    assert_equal_computations([result], [(x.values.sum(0)).shape[0]])

    # Case 2: with ShapeFeature — shape inference should propagate through
    # the lowering, reading the dimension straight off the input.
    fgraph = FunctionGraph(
        [x], [shape_graph], features=[ShapeFeature()], copy_inputs=False
    )
    infer_shape_db.default_query.rewrite(fgraph)
    [result] = fgraph.outputs
    assert_equal_computations([result], [Shape_i(1)(x)])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论