提交 26d81f81 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Respect introduction of ShapeFeature

上级 9bee61d5
...@@ -729,11 +729,8 @@ class ShapeFeature(Feature): ...@@ -729,11 +729,8 @@ class ShapeFeature(Feature):
class ShapeOptimizer(GraphRewriter): class ShapeOptimizer(GraphRewriter):
"""Rewriter that adds `ShapeFeature` as a feature.""" """Rewriter that adds `ShapeFeature` as a feature."""
def add_requirements(self, fgraph):
fgraph.attach_feature(ShapeFeature())
def apply(self, fgraph): def apply(self, fgraph):
pass fgraph.attach_feature(ShapeFeature())
class UnShapeOptimizer(GraphRewriter): class UnShapeOptimizer(GraphRewriter):
......
...@@ -24,6 +24,9 @@ class XTypeCastOp(TypeCastingOp): ...@@ -24,6 +24,9 @@ class XTypeCastOp(TypeCastingOp):
This is like a `ViewOp` but without the expectation the input and output have identical types. This is like a `ViewOp` but without the expectation the input and output have identical types.
""" """
def infer_shape(self, fgraph, node, input_shapes):
return input_shapes
class TensorFromXTensor(XTypeCastOp): class TensorFromXTensor(XTypeCastOp):
__props__ = () __props__ = ()
......
...@@ -17,7 +17,7 @@ optdb.register( ...@@ -17,7 +17,7 @@ optdb.register(
"fast_run", "fast_run",
"fast_compile", "fast_compile",
"minimum_compile", "minimum_compile",
position=0.1, position=0.09, # before ShapeOpt, so we don't accidentally reintroduce xtensor Ops
) )
# Register OFG inline again after lowering xtensor # Register OFG inline again after lowering xtensor
...@@ -26,7 +26,7 @@ optdb.register( ...@@ -26,7 +26,7 @@ optdb.register(
dfs_rewriter(inline_ofg_expansion), dfs_rewriter(inline_ofg_expansion),
"fast_run", "fast_run",
"fast_compile", "fast_compile",
position=0.11, position=0.091,
) )
......
import numpy as np
from pytensor import function
from pytensor.xtensor.basic import Rename
from pytensor.xtensor.type import xtensor
def test_shape_feature_does_not_see_xop():
CALLED = False
x = xtensor("x", dims=("a",), dtype="int64")
class XOpWithBadInferShape(Rename):
def infer_shape(self, node, inputs, outputs):
global CALLED
CALLED = True
raise NotImplementedError()
test_xop = XOpWithBadInferShape(new_dims=("b",))
out = test_xop(x) - test_xop(x)
assert out.dims == ("b",)
fn = function([x], out)
np.testing.assert_allclose(fn([1, 2, 3]), [0, 0, 0])
assert not CALLED
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论