提交 26d81f81 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Respect introduction of ShapeFeature

上级 9bee61d5
......@@ -729,11 +729,8 @@ class ShapeFeature(Feature):
class ShapeOptimizer(GraphRewriter):
    """Rewriter whose sole job is to attach a `ShapeFeature` to the graph.

    The diff-scraped original left a dangling ``attach_feature`` call after
    ``pass``; reconstructed here as a coherent rewriter. All of the actual
    shape bookkeeping is performed by the feature itself once attached, so
    ``apply`` is intentionally a no-op.
    """

    def add_requirements(self, fgraph):
        # Attaching the feature is the whole effect of this rewriter:
        # ShapeFeature hooks into fgraph callbacks to track shapes.
        fgraph.attach_feature(ShapeFeature())

    def apply(self, fgraph):
        # Nothing to do here — the attached ShapeFeature does the work.
        pass
class UnShapeOptimizer(GraphRewriter):
......
......@@ -24,6 +24,9 @@ class XTypeCastOp(TypeCastingOp):
This is like a `ViewOp` but without the expectation the input and output have identical types.
"""
def infer_shape(self, fgraph, node, input_shapes):
    """Propagate shapes unchanged: a type cast does not alter shape."""
    # Output i has exactly the shape of input i, so the incoming
    # shape list can be returned as-is.
    return input_shapes
class TensorFromXTensor(XTypeCastOp):
__props__ = ()
......
......@@ -17,7 +17,7 @@ optdb.register(
"fast_run",
"fast_compile",
"minimum_compile",
position=0.1,
position=0.09, # before ShapeOpt, so we don't accidentally reintroduce xtensor Ops
)
# Register OFG inline again after lowering xtensor
......@@ -26,7 +26,7 @@ optdb.register(
dfs_rewriter(inline_ofg_expansion),
"fast_run",
"fast_compile",
position=0.11,
position=0.091,
)
......
import numpy as np
from pytensor import function
from pytensor.xtensor.basic import Rename
from pytensor.xtensor.type import xtensor
def test_shape_feature_does_not_see_xop():
    """ShapeFeature must never call ``infer_shape`` on an xtensor op.

    Lowering of xtensor ops is registered *before* ShapeOpt, so by the time
    the shape feature runs no xtensor op should remain. The spy op below
    raises (and records the call) if its ``infer_shape`` is ever reached.
    """
    # BUG FIX: the original used a function-local ``CALLED = False`` but
    # declared ``global CALLED`` inside the nested method, so the method
    # wrote a module-level name while the final assert read the untouched
    # local — the check could never fail. ``nonlocal`` binds the flag the
    # assertion actually inspects.
    called = False

    x = xtensor("x", dims=("a",), dtype="int64")

    class XOpWithBadInferShape(Rename):
        def infer_shape(self, node, inputs, outputs):
            nonlocal called
            called = True
            raise NotImplementedError()

    test_xop = XOpWithBadInferShape(new_dims=("b",))
    out = test_xop(x) - test_xop(x)
    assert out.dims == ("b",)

    fn = function([x], out)
    np.testing.assert_allclose(fn([1, 2, 3]), [0, 0, 0])
    # If the shape feature had touched the xtensor op, the flag would be set
    # (or compilation would have failed with NotImplementedError).
    assert not called
Markdown 格式
0%
您向此讨论添加了 0 个项目。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论