Commit 26657372 authored by Ricardo Vieira, committed by Ricardo Vieira

Flag Ops whose output types depend on input values

These nodes must always be rebuilt in non-strict mode
Parent 54f4b200
......@@ -266,14 +266,24 @@ class Apply(Node, Generic[OpType]):
assert isinstance(inputs, (list, tuple))
remake_node = False
new_inputs: List["Variable"] = list(inputs)
# Some Ops like Alloc require the node to always be rebuilt in non-strict mode
# as the output type depends on the input values and not just their types
output_type_depends_on_input_value = self.op._output_type_depends_on_input_value
for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)):
if curr.type != new.type:
# Check if the input type changed or if the Op has output types that depend on input values
if (curr.type != new.type) or output_type_depends_on_input_value:
# In strict mode, the cloned graph is assumed to be mathematically equivalent to the original one.
# We only need to rebuild a node when the new input has a different, but compatible, type.
# This can happen e.g., when we provide a new input with a more specialized static shape.
if strict:
new_i = curr.type.filter_variable(new)
new_inputs[i] = new_i
if curr.type != new_i.type:
remake_node = True
# Otherwise, we always rebuild the node
else:
remake_node = True
......
......@@ -207,6 +207,15 @@ class Op(MetaObject):
otypes: Optional[Sequence["Type"]] = None
params_type: Optional[ParamsType] = None
_output_type_depends_on_input_value = False
"""
Whether the static output type depends on the inferred value of one of the inputs.
(e.g, via constant folding or static shape inference).
This information is needed when rebuilding a graph with new inputs,
as nodes with these Ops must be rebuilt even if the input types haven't changed.
"""
def make_node(self, *inputs: Variable) -> Apply:
"""Construct an `Apply` node that represent the application of this operation to the given inputs.
......
......@@ -1418,6 +1418,8 @@ class Alloc(COp):
"""
_f16_ok = True
_output_type_depends_on_input_value = True
__props__ = ()
def make_node(self, value, *shape):
......@@ -3819,6 +3821,8 @@ class Choose(Op):
class AllocEmpty(COp):
"""Implement Alloc on the cpu, but without initializing memory."""
_output_type_depends_on_input_value = True
__props__ = ("dtype",)
params_type = ParamsType(typecode=int32)
......
......@@ -1561,6 +1561,8 @@ def broadcast_shape_iter(
class BroadcastTo(COp):
"""An `Op` for `numpy.broadcast_to`."""
_output_type_depends_on_input_value = True
__props__ = ()
view_map = {0: [0]}
......
......@@ -91,6 +91,8 @@ class RandomVariable(Op):
"""
_output_type_depends_on_input_value = True
__props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace")
default_output = 1
......
......@@ -388,6 +388,7 @@ class SpecifyShape(COp):
view_map = {0: [0]}
__props__ = ()
_f16_ok = True
_output_type_depends_on_input_value = True
def make_node(self, x, *shape):
from pytensor.tensor.basic import get_underlying_scalar_constant_value
......@@ -587,6 +588,7 @@ class Reshape(COp):
view_map = {0: [0]} # output 0 is potentially aliased to inputs [0]
_f16_ok = True
_output_type_depends_on_input_value = True
check_input = False
__props__ = ("ndim",)
......
......@@ -14,6 +14,7 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import Constant, Variable, graph_inputs
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor.random.basic import (
bernoulli,
......@@ -57,7 +58,7 @@ from pytensor.tensor.random.basic import (
weibull,
)
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.type import iscalar, scalar, tensor
from pytensor.tensor.type import iscalar, scalar, tensor, vector
from tests.unittest_tools import create_pytensor_param
......@@ -1422,3 +1423,19 @@ def test_pickle():
a_unpkl = pickle.loads(a_pkl)
assert a_unpkl.owner.op._props() == sample_a.owner.op._props()
def test_rebuild():
    """Cloning a `RandomVariable` graph with a differently-shaped input in
    non-strict mode must rebuild the node, so the static output type follows
    the new input's shape.
    """
    x = vector(shape=(50,))
    x_test = np.zeros((50,), dtype=config.floatX)
    y = normal(size=x.shape)
    assert y.type.shape == (50,)
    # Wrap in tuple(...) for consistency with the other rebuild tests and to
    # avoid relying on ndarray-vs-tuple comparison broadcasting semantics.
    assert tuple(y.shape.eval({x: x_test})) == (50,)
    assert y.eval({x: x_test}).shape == (50,)

    x_new = vector(shape=(100,))
    x_new_test = np.zeros((100,), dtype=config.floatX)
    # Non-strict rebuild: the RV's output type depends on the input value
    # (its size), so the node must be remade even though only the static
    # shape of the replaced input changed.
    y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
    assert y_new.type.shape == (100,)
    assert tuple(y_new.shape.eval({x_new: x_new_test})) == (100,)
    assert y_new.eval({x_new: x_new_test}).shape == (100,)
......@@ -16,6 +16,7 @@ from pytensor.compile.ops import DeepCopyOp
from pytensor.gradient import grad, hessian
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.graph.replace import clone_replace
from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import Assert
from pytensor.scalar import autocast_float, autocast_float_as
......@@ -818,6 +819,22 @@ class TestAlloc:
res = pytensor.function([], full_at, mode=self.mode)()
assert np.array_equal(res, np.full((2, 3), 3, dtype="int64"))
@pytest.mark.parametrize("func", (at.zeros, at.empty))
def test_rebuild(self, func):
    """An Alloc-backed graph cloned with ``rebuild_strict=False`` must pick
    up the static shape of the replacement input.
    """
    old_x = vector(shape=(50,))
    old_val = np.zeros((50,), dtype=config.floatX)

    out = func(old_x.shape)
    assert out.type.shape == (50,)
    assert out.shape.eval({old_x: old_val}) == (50,)
    assert out.eval({old_x: old_val}).shape == (50,)

    new_x = vector(shape=(100,))
    new_val = np.zeros((100,), dtype=config.floatX)
    rebuilt = clone_replace(out, {old_x: new_x}, rebuild_strict=False)
    assert rebuilt.type.shape == (100,)
    assert rebuilt.shape.eval({new_x: new_val}) == (100,)
    assert rebuilt.eval({new_x: new_val}).shape == (100,)
def test_infer_shape():
with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
......
......@@ -9,6 +9,7 @@ from pytensor import tensor as at
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.graph.basic import Constant, applys_between
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.raise_op import Assert
from pytensor.tensor.elemwise import DimShuffle
......@@ -1399,6 +1400,22 @@ class TestBroadcastTo(utt.InferShapeTester):
assert advincsub_node.op.inplace is False
def test_rebuild(self):
    """`BroadcastTo`'s static output type must track the new input's shape
    when the graph is cloned with ``rebuild_strict=False``.
    """
    shape_src = vector(shape=(50,))
    shape_src_val = np.zeros((50,), dtype=config.floatX)
    fill_value = 0

    out = broadcast_to(fill_value, shape_src.shape)
    assert out.type.shape == (50,)
    assert out.shape.eval({shape_src: shape_src_val}) == (50,)
    assert out.eval({shape_src: shape_src_val}).shape == (50,)

    new_src = vector(shape=(100,))
    new_src_val = np.zeros((100,), dtype=config.floatX)
    rebuilt = clone_replace(out, {shape_src: new_src}, rebuild_strict=False)
    assert rebuilt.type.shape == (100,)
    assert rebuilt.shape.eval({new_src: new_src_val}) == (100,)
    assert rebuilt.eval({new_src: new_src_val}).shape == (100,)
def test_broadcast_arrays():
x, y = at.tensor(shape=(1,), dtype="float64"), at.dmatrix()
......
......@@ -7,6 +7,7 @@ from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config
from pytensor.graph.basic import Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.graph.type import Type
from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar.basic import ScalarConstant
......@@ -337,6 +338,21 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
Reshape,
)
def test_rebuild(self):
    """`Reshape` must re-derive its static output shape when a constant in
    the target shape is swapped out during a non-strict clone.
    """
    dim = as_tensor_variable(50)
    data = vector("i")
    data_val = np.zeros((100,), dtype=config.floatX)

    out = reshape(data, (100 // dim, dim))
    assert out.type.shape == (2, 50)
    assert tuple(out.shape.eval({data: data_val})) == (2, 50)
    assert out.eval({data: data_val}).shape == (2, 50)

    new_dim = as_tensor_variable(25)
    rebuilt = clone_replace(out, {dim: new_dim}, rebuild_strict=False)
    assert rebuilt.type.shape == (4, 25)
    assert tuple(rebuilt.shape.eval({data: data_val})) == (4, 25)
    assert rebuilt.eval({data: data_val}).shape == (4, 25)
def test_shape_i_hash():
assert isinstance(Shape_i(np.int64(1)).__hash__(), int)
......@@ -524,6 +540,22 @@ class TestSpecifyShape(utt.InferShapeTester):
z_grad = grad(z.sum(), wrt=x)
assert isinstance(z_grad.owner.op, SpecifyShape)
def test_rebuild(self):
    """`SpecifyShape` must re-derive its static shape from the replaced
    constant when the graph is cloned with ``rebuild_strict=False``.
    """
    fixed_dim = as_tensor_variable(50)
    data = matrix("i")
    data_val = np.zeros((4, 50), dtype=config.floatX)

    out = specify_shape(data, (None, fixed_dim))
    assert out.type.shape == (None, 50)
    assert tuple(out.shape.eval({data: data_val})) == (4, 50)
    assert out.eval({data: data_val}).shape == (4, 50)

    new_fixed_dim = as_tensor_variable(100)
    data_val = np.zeros((4, 100), dtype=config.floatX)
    rebuilt = clone_replace(out, {fixed_dim: new_fixed_dim}, rebuild_strict=False)
    assert rebuilt.type.shape == (None, 100)
    assert tuple(rebuilt.shape.eval({data: data_val})) == (4, 100)
    assert rebuilt.eval({data: data_val}).shape == (4, 100)
class TestSpecifyBroadcastable:
def test_basic(self):
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to comment