Commit 26657372 authored by Ricardo Vieira, committed by Ricardo Vieira

Flag Ops whose output types depend on input values

These nodes must always be rebuilt in non-strict mode
Parent 54f4b200
...@@ -266,14 +266,24 @@ class Apply(Node, Generic[OpType]): ...@@ -266,14 +266,24 @@ class Apply(Node, Generic[OpType]):
assert isinstance(inputs, (list, tuple)) assert isinstance(inputs, (list, tuple))
remake_node = False remake_node = False
new_inputs: List["Variable"] = list(inputs) new_inputs: List["Variable"] = list(inputs)
# Some Ops like Alloc require the node to always be rebuilt in non-strict mode
# as the output type depends on the input values and not just their types
output_type_depends_on_input_value = self.op._output_type_depends_on_input_value
for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)): for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)):
if curr.type != new.type: # Check if the input type changed or if the Op has output types that depend on input values
if (curr.type != new.type) or output_type_depends_on_input_value:
# In strict mode, the cloned graph is assumed to be mathematically equivalent to the original one.
# We only need to rebuild a node when the new input has a different, but compatible, type.
# This can happen e.g., when we provide a new input with a more specialized static shape.
if strict: if strict:
new_i = curr.type.filter_variable(new) new_i = curr.type.filter_variable(new)
new_inputs[i] = new_i new_inputs[i] = new_i
if curr.type != new_i.type: if curr.type != new_i.type:
remake_node = True remake_node = True
# Otherwise, we always rebuild the node
else: else:
remake_node = True remake_node = True
......
...@@ -207,6 +207,15 @@ class Op(MetaObject): ...@@ -207,6 +207,15 @@ class Op(MetaObject):
otypes: Optional[Sequence["Type"]] = None otypes: Optional[Sequence["Type"]] = None
params_type: Optional[ParamsType] = None params_type: Optional[ParamsType] = None
_output_type_depends_on_input_value = False
"""
Whether the static output type depends on the inferred value of one of the inputs.
(e.g, via constant folding or static shape inference).
This information is needed when rebuilding a graph with new inputs,
as nodes with these Ops must be rebuilt even if the input types haven't changed.
"""
def make_node(self, *inputs: Variable) -> Apply: def make_node(self, *inputs: Variable) -> Apply:
"""Construct an `Apply` node that represent the application of this operation to the given inputs. """Construct an `Apply` node that represent the application of this operation to the given inputs.
......
...@@ -1418,6 +1418,8 @@ class Alloc(COp): ...@@ -1418,6 +1418,8 @@ class Alloc(COp):
""" """
_f16_ok = True _f16_ok = True
_output_type_depends_on_input_value = True
__props__ = () __props__ = ()
def make_node(self, value, *shape): def make_node(self, value, *shape):
...@@ -3819,6 +3821,8 @@ class Choose(Op): ...@@ -3819,6 +3821,8 @@ class Choose(Op):
class AllocEmpty(COp): class AllocEmpty(COp):
"""Implement Alloc on the cpu, but without initializing memory.""" """Implement Alloc on the cpu, but without initializing memory."""
_output_type_depends_on_input_value = True
__props__ = ("dtype",) __props__ = ("dtype",)
params_type = ParamsType(typecode=int32) params_type = ParamsType(typecode=int32)
......
...@@ -1561,6 +1561,8 @@ def broadcast_shape_iter( ...@@ -1561,6 +1561,8 @@ def broadcast_shape_iter(
class BroadcastTo(COp): class BroadcastTo(COp):
"""An `Op` for `numpy.broadcast_to`.""" """An `Op` for `numpy.broadcast_to`."""
_output_type_depends_on_input_value = True
__props__ = () __props__ = ()
view_map = {0: [0]} view_map = {0: [0]}
......
...@@ -91,6 +91,8 @@ class RandomVariable(Op): ...@@ -91,6 +91,8 @@ class RandomVariable(Op):
""" """
_output_type_depends_on_input_value = True
__props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace") __props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace")
default_output = 1 default_output = 1
......
...@@ -388,6 +388,7 @@ class SpecifyShape(COp): ...@@ -388,6 +388,7 @@ class SpecifyShape(COp):
view_map = {0: [0]} view_map = {0: [0]}
__props__ = () __props__ = ()
_f16_ok = True _f16_ok = True
_output_type_depends_on_input_value = True
def make_node(self, x, *shape): def make_node(self, x, *shape):
from pytensor.tensor.basic import get_underlying_scalar_constant_value from pytensor.tensor.basic import get_underlying_scalar_constant_value
...@@ -587,6 +588,7 @@ class Reshape(COp): ...@@ -587,6 +588,7 @@ class Reshape(COp):
view_map = {0: [0]} # output 0 is potentially aliased to inputs [0] view_map = {0: [0]} # output 0 is potentially aliased to inputs [0]
_f16_ok = True _f16_ok = True
_output_type_depends_on_input_value = True
check_input = False check_input = False
__props__ = ("ndim",) __props__ = ("ndim",)
......
...@@ -14,6 +14,7 @@ from pytensor.configdefaults import config ...@@ -14,6 +14,7 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import Constant, Variable, graph_inputs from pytensor.graph.basic import Constant, Variable, graph_inputs
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value from pytensor.graph.op import get_test_value
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor.random.basic import ( from pytensor.tensor.random.basic import (
bernoulli, bernoulli,
...@@ -57,7 +58,7 @@ from pytensor.tensor.random.basic import ( ...@@ -57,7 +58,7 @@ from pytensor.tensor.random.basic import (
weibull, weibull,
) )
from pytensor.tensor.rewriting.shape import ShapeFeature from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.type import iscalar, scalar, tensor from pytensor.tensor.type import iscalar, scalar, tensor, vector
from tests.unittest_tools import create_pytensor_param from tests.unittest_tools import create_pytensor_param
...@@ -1422,3 +1423,19 @@ def test_pickle(): ...@@ -1422,3 +1423,19 @@ def test_pickle():
a_unpkl = pickle.loads(a_pkl) a_unpkl = pickle.loads(a_pkl)
assert a_unpkl.owner.op._props() == sample_a.owner.op._props() assert a_unpkl.owner.op._props() == sample_a.owner.op._props()
def test_rebuild():
    """Non-strict cloning must rebuild a `RandomVariable` node so that its
    static output shape is re-inferred from the replacement input.
    """
    inp = vector(shape=(50,))
    inp_val = np.zeros((50,), dtype=config.floatX)
    out = normal(size=inp.shape)
    assert out.type.shape == (50,)
    assert out.shape.eval({inp: inp_val}) == (50,)
    assert out.eval({inp: inp_val}).shape == (50,)

    # Swap in an input with a different static shape; the node must be rebuilt
    # because the RV's output type depends on the value of its `size` input.
    repl = vector(shape=(100,))
    repl_val = np.zeros((100,), dtype=config.floatX)
    out_new = clone_replace(out, {inp: repl}, rebuild_strict=False)
    assert out_new.type.shape == (100,)
    assert out_new.shape.eval({repl: repl_val}) == (100,)
    assert out_new.eval({repl: repl_val}).shape == (100,)
...@@ -16,6 +16,7 @@ from pytensor.compile.ops import DeepCopyOp ...@@ -16,6 +16,7 @@ from pytensor.compile.ops import DeepCopyOp
from pytensor.gradient import grad, hessian from pytensor.gradient import grad, hessian
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import clone_replace
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.scalar import autocast_float, autocast_float_as from pytensor.scalar import autocast_float, autocast_float_as
...@@ -818,6 +819,22 @@ class TestAlloc: ...@@ -818,6 +819,22 @@ class TestAlloc:
res = pytensor.function([], full_at, mode=self.mode)() res = pytensor.function([], full_at, mode=self.mode)()
assert np.array_equal(res, np.full((2, 3), 3, dtype="int64")) assert np.array_equal(res, np.full((2, 3), 3, dtype="int64"))
@pytest.mark.parametrize("func", (at.zeros, at.empty))
def test_rebuild(self, func):
    """Alloc/AllocEmpty outputs must be re-typed when the graph is cloned
    non-strictly with an input of a different static shape.
    """
    inp = vector(shape=(50,))
    inp_val = np.zeros((50,), dtype=config.floatX)
    out = func(inp.shape)
    assert out.type.shape == (50,)
    assert out.shape.eval({inp: inp_val}) == (50,)
    assert out.eval({inp: inp_val}).shape == (50,)

    # A new input with a different static shape forces a node rebuild,
    # since the allocation's output type depends on the shape *values*.
    repl = vector(shape=(100,))
    repl_val = np.zeros((100,), dtype=config.floatX)
    out_new = clone_replace(out, {inp: repl}, rebuild_strict=False)
    assert out_new.type.shape == (100,)
    assert out_new.shape.eval({repl: repl_val}) == (100,)
    assert out_new.eval({repl: repl_val}).shape == (100,)
def test_infer_shape(): def test_infer_shape():
with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"): with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
......
...@@ -9,6 +9,7 @@ from pytensor import tensor as at ...@@ -9,6 +9,7 @@ from pytensor import tensor as at
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Constant, applys_between from pytensor.graph.basic import Constant, applys_between
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
...@@ -1399,6 +1400,22 @@ class TestBroadcastTo(utt.InferShapeTester): ...@@ -1399,6 +1400,22 @@ class TestBroadcastTo(utt.InferShapeTester):
assert advincsub_node.op.inplace is False assert advincsub_node.op.inplace is False
def test_rebuild(self):
    """`BroadcastTo` must recompute its static output shape when cloned
    non-strictly with a shape input of a different length.
    """
    inp = vector(shape=(50,))
    inp_val = np.zeros((50,), dtype=config.floatX)
    fill_value = 0
    out = broadcast_to(fill_value, inp.shape)
    assert out.type.shape == (50,)
    assert out.shape.eval({inp: inp_val}) == (50,)
    assert out.eval({inp: inp_val}).shape == (50,)

    # Replacing the shape source must trigger a rebuild of the node,
    # as BroadcastTo's output type depends on the shape input's value.
    repl = vector(shape=(100,))
    repl_val = np.zeros((100,), dtype=config.floatX)
    out_new = clone_replace(out, {inp: repl}, rebuild_strict=False)
    assert out_new.type.shape == (100,)
    assert out_new.shape.eval({repl: repl_val}) == (100,)
    assert out_new.eval({repl: repl_val}).shape == (100,)
def test_broadcast_arrays(): def test_broadcast_arrays():
x, y = at.tensor(shape=(1,), dtype="float64"), at.dmatrix() x, y = at.tensor(shape=(1,), dtype="float64"), at.dmatrix()
......
...@@ -7,6 +7,7 @@ from pytensor.compile.ops import DeepCopyOp ...@@ -7,6 +7,7 @@ from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Variable from pytensor.graph.basic import Variable
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar.basic import ScalarConstant from pytensor.scalar.basic import ScalarConstant
...@@ -337,6 +338,21 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin): ...@@ -337,6 +338,21 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
Reshape, Reshape,
) )
def test_rebuild(self):
    """`Reshape` must re-infer its static output shape when a constant shape
    input is replaced during non-strict cloning.
    """
    dim = as_tensor_variable(50)
    data = vector("i")
    data_val = np.zeros((100,), dtype=config.floatX)
    out = reshape(data, (100 // dim, dim))
    assert out.type.shape == (2, 50)
    assert tuple(out.shape.eval({data: data_val})) == (2, 50)
    assert out.eval({data: data_val}).shape == (2, 50)

    # Substituting a different constant must rebuild the node and yield
    # the new statically-inferred shape (100 // 25, 25) == (4, 25).
    dim_new = as_tensor_variable(25)
    out_new = clone_replace(out, {dim: dim_new}, rebuild_strict=False)
    assert out_new.type.shape == (4, 25)
    assert tuple(out_new.shape.eval({data: data_val})) == (4, 25)
    assert out_new.eval({data: data_val}).shape == (4, 25)
def test_shape_i_hash(): def test_shape_i_hash():
assert isinstance(Shape_i(np.int64(1)).__hash__(), int) assert isinstance(Shape_i(np.int64(1)).__hash__(), int)
...@@ -524,6 +540,22 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -524,6 +540,22 @@ class TestSpecifyShape(utt.InferShapeTester):
z_grad = grad(z.sum(), wrt=x) z_grad = grad(z.sum(), wrt=x)
assert isinstance(z_grad.owner.op, SpecifyShape) assert isinstance(z_grad.owner.op, SpecifyShape)
def test_rebuild(self):
    """`SpecifyShape` must re-derive its static output shape from the new
    constant shape input during non-strict cloning.
    """
    dim = as_tensor_variable(50)
    data = matrix("i")
    data_val = np.zeros((4, 50), dtype=config.floatX)
    out = specify_shape(data, (None, dim))
    assert out.type.shape == (None, 50)
    assert tuple(out.shape.eval({data: data_val})) == (4, 50)
    assert out.eval({data: data_val}).shape == (4, 50)

    # Replacing the constant shape input must rebuild the node so the
    # specified static dimension follows the new value.
    dim_new = as_tensor_variable(100)
    data_val = np.zeros((4, 100), dtype=config.floatX)
    out_new = clone_replace(out, {dim: dim_new}, rebuild_strict=False)
    assert out_new.type.shape == (None, 100)
    assert tuple(out_new.shape.eval({data: data_val})) == (4, 100)
    assert out_new.eval({data: data_val}).shape == (4, 100)
class TestSpecifyBroadcastable: class TestSpecifyBroadcastable:
def test_basic(self): def test_basic(self):
......
Markdown is supported
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment