提交 cfa76f5d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Handle non-constant NoneTypeT variables

上级 4e4f237a
...@@ -385,7 +385,9 @@ class RandomVariable(RNGConsumerOp): ...@@ -385,7 +385,9 @@ class RandomVariable(RNGConsumerOp):
dist_params = explicit_expand_dims( dist_params = explicit_expand_dims(
dist_params, dist_params,
self.ndims_params, self.ndims_params,
size_length=None if NoneConst.equals(size) else get_vector_length(size), size_length=None
if isinstance(size.type, NoneTypeT)
else get_vector_length(size),
) )
inputs = (rng, size, *dist_params) inputs = (rng, size, *dist_params)
......
...@@ -9,7 +9,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -9,7 +9,7 @@ from pytensor.graph.rewriting.basic import (
dfs_rewriter, dfs_rewriter,
node_rewriter, node_rewriter,
) )
from pytensor.tensor import NoneConst, TensorVariable from pytensor.tensor import TensorVariable
from pytensor.tensor.basic import constant from pytensor.tensor.basic import constant
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.extra_ops import broadcast_to from pytensor.tensor.extra_ops import broadcast_to
...@@ -20,7 +20,7 @@ from pytensor.tensor.subtensor import ( ...@@ -20,7 +20,7 @@ from pytensor.tensor.subtensor import (
AdvancedSubtensor, AdvancedSubtensor,
AdvancedSubtensor1, AdvancedSubtensor1,
Subtensor, Subtensor,
get_idx_list, indices_from_subtensor,
) )
from pytensor.tensor.type import integer_dtypes from pytensor.tensor.type import integer_dtypes
from pytensor.tensor.type_other import NoneTypeT, SliceType from pytensor.tensor.type_other import NoneTypeT, SliceType
...@@ -237,14 +237,17 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -237,14 +237,17 @@ def local_subtensor_rv_lift(fgraph, node):
return False return False
# Parse indices # Parse indices
indices = get_idx_list(node.inputs, getattr(subtensor_op, "idx_list", None)) if isinstance(subtensor_op, Subtensor):
indices = indices_from_subtensor(node.inputs[1:], subtensor_op.idx_list)
else:
indices = node.inputs[1:]
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates) # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis). # Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# If we wanted to support that we could rewrite it as subtensor + dimshuffle # If we wanted to support that we could rewrite it as subtensor + dimshuffle
# and make use of the dimshuffle lift rewrite # and make use of the dimshuffle lift rewrite
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
if any( if any(
is_nd_advanced_idx(idx, integer_dtypes) or NoneConst.equals(idx) is_nd_advanced_idx(idx, integer_dtypes) or isinstance(idx.type, NoneTypeT)
for idx in indices for idx in indices
): ):
return False return False
...@@ -267,7 +270,7 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -267,7 +270,7 @@ def local_subtensor_rv_lift(fgraph, node):
for idx in supp_indices: for idx in supp_indices:
if not ( if not (
isinstance(idx.type, SliceType) isinstance(idx.type, SliceType)
and all(NoneConst.equals(i) for i in idx.owner.inputs) and all(isinstance(i.type, NoneTypeT) for i in idx.owner.inputs)
): ):
return False return False
n_discarded_idxs = len(supp_indices) n_discarded_idxs = len(supp_indices)
......
...@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING ...@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING
import numpy as np import numpy as np
from pytensor.compile.sharedvalue import shared from pytensor.compile.sharedvalue import shared
from pytensor.graph.basic import Constant, Variable from pytensor.graph.basic import Variable
from pytensor.scalar import ScalarVariable from pytensor.scalar import ScalarVariable
from pytensor.tensor import NoneConst, get_vector_length from pytensor.tensor import NoneConst, get_vector_length
from pytensor.tensor.basic import as_tensor_variable, cast from pytensor.tensor.basic import as_tensor_variable, cast
...@@ -15,6 +15,7 @@ from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to ...@@ -15,6 +15,7 @@ from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to
from pytensor.tensor.math import maximum from pytensor.tensor.math import maximum
from pytensor.tensor.shape import shape_padleft, specify_shape from pytensor.tensor.shape import shape_padleft, specify_shape
from pytensor.tensor.type import int_dtypes from pytensor.tensor.type import int_dtypes
from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.utils import faster_broadcast_to from pytensor.tensor.utils import faster_broadcast_to
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -178,24 +179,26 @@ def normalize_size_param( ...@@ -178,24 +179,26 @@ def normalize_size_param(
shape: int | np.ndarray | Variable | Sequence | None, shape: int | np.ndarray | Variable | Sequence | None,
) -> Variable: ) -> Variable:
"""Create an PyTensor value for a ``RandomVariable`` ``size`` parameter.""" """Create an PyTensor value for a ``RandomVariable`` ``size`` parameter."""
if shape is None or NoneConst.equals(shape): if shape is None:
return NoneConst return NoneConst
elif isinstance(shape, int): if isinstance(shape, Variable) and isinstance(shape.type, NoneTypeT):
return shape
if isinstance(shape, int):
shape = as_tensor_variable([shape], ndim=1) shape = as_tensor_variable([shape], ndim=1)
elif not isinstance(shape, np.ndarray | Variable | Sequence): else:
if not isinstance(shape, Sequence | Variable | np.ndarray):
raise TypeError( raise TypeError(
"Parameter size must be None, an integer, or a sequence with integers." "Parameter size must be None, an integer, or a sequence with integers."
) )
else:
shape = cast(as_tensor_variable(shape, ndim=1, dtype="int64"), "int64") shape = cast(as_tensor_variable(shape, ndim=1, dtype="int64"), "int64")
if not isinstance(shape, Constant): if shape.type.shape == (None,):
# This should help ensure that the length of non-constant `size`s # This should help ensure that the length of non-constant `size`s
# will be available after certain types of cloning (e.g. the kind # will be available after certain types of cloning (e.g. the kind `Scan` performs)
# `Scan` performs)
shape = specify_shape(shape, (get_vector_length(shape),)) shape = specify_shape(shape, (get_vector_length(shape),))
assert not any(s is None for s in shape.type.shape) assert shape.type.shape != (None,)
assert shape.dtype in int_dtypes assert shape.dtype in int_dtypes
return shape return shape
......
...@@ -47,7 +47,7 @@ from pytensor.tensor.shape import ( ...@@ -47,7 +47,7 @@ from pytensor.tensor.shape import (
) )
from pytensor.tensor.subtensor import Subtensor, get_idx_list from pytensor.tensor.subtensor import Subtensor, get_idx_list
from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
from pytensor.tensor.type_other import NoneConst, NoneTypeT from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -1137,7 +1137,7 @@ def local_merge_consecutive_specify_shape(fgraph, node): ...@@ -1137,7 +1137,7 @@ def local_merge_consecutive_specify_shape(fgraph, node):
inner_obj, *shape = obj.owner.inputs inner_obj, *shape = obj.owner.inputs
for dim, sh in enumerate(node.inputs[1:]): for dim, sh in enumerate(node.inputs[1:]):
if not NoneConst.equals(sh): if not isinstance(sh.type, NoneTypeT):
shape[dim] = sh shape[dim] = sh
# TODO: We could make sure that the overlapping shapes of the two `SpecifyShape`s are # TODO: We could make sure that the overlapping shapes of the two `SpecifyShape`s are
...@@ -1183,7 +1183,7 @@ def local_Shape_of_SpecifyShape(fgraph, node): ...@@ -1183,7 +1183,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):
# Replace `NoneConst` by `shape_i` # Replace `NoneConst` by `shape_i`
for i, sh in enumerate(shape): for i, sh in enumerate(shape):
if NoneConst.equals(sh): if isinstance(sh.type, NoneTypeT):
shape[i] = x.shape[i] shape[i] = x.shape[i]
return [stack(shape).astype(np.int64)] return [stack(shape).astype(np.int64)]
...@@ -1219,7 +1219,7 @@ def local_specify_shape_lift(fgraph, node): ...@@ -1219,7 +1219,7 @@ def local_specify_shape_lift(fgraph, node):
for i, (dim, bcast) in enumerate( for i, (dim, bcast) in enumerate(
zip(shape, out_broadcastable, strict=True) zip(shape, out_broadcastable, strict=True)
) )
if (not bcast and not NoneConst.equals(dim)) if (not bcast and not isinstance(dim.type, NoneTypeT))
} }
new_elem_inps = elem_inps.copy() new_elem_inps = elem_inps.copy()
for i, elem_inp in enumerate(elem_inps): for i, elem_inp in enumerate(elem_inps):
......
...@@ -408,7 +408,9 @@ class SpecifyShape(COp): ...@@ -408,7 +408,9 @@ class SpecifyShape(COp):
shape = tuple( shape = tuple(
NoneConst NoneConst
if (s is None or NoneConst.equals(s)) if (
s is None or (isinstance(s, Variable) and isinstance(s.type, NoneTypeT))
)
else ptb.as_tensor_variable(s, ndim=0) else ptb.as_tensor_variable(s, ndim=0)
for s in shape for s in shape
) )
...@@ -506,7 +508,7 @@ class SpecifyShape(COp): ...@@ -506,7 +508,7 @@ class SpecifyShape(COp):
for i, (shp_name, shp) in enumerate( for i, (shp_name, shp) in enumerate(
zip(shape_names, node.inputs[1:], strict=True) zip(shape_names, node.inputs[1:], strict=True)
): ):
if NoneConst.equals(shp): if isinstance(shp.type, NoneTypeT):
continue continue
code += dedent( code += dedent(
f""" f"""
...@@ -594,7 +596,10 @@ def _vectorize_specify_shape(op, node, x, *shape): ...@@ -594,7 +596,10 @@ def _vectorize_specify_shape(op, node, x, *shape):
if any( if any(
as_tensor_variable(dim).type.ndim != 0 as_tensor_variable(dim).type.ndim != 0
for dim in shape for dim in shape
if not (NoneConst.equals(dim) or dim is None) if not (
(isinstance(dim, Variable) and isinstance(dim.type, NoneTypeT))
or dim is None
)
): ):
raise NotImplementedError( raise NotImplementedError(
"It is not possible to vectorize the shape argument of SpecifyShape" "It is not possible to vectorize the shape argument of SpecifyShape"
......
...@@ -11,6 +11,7 @@ from pytensor.tensor.random.basic import NormalRV ...@@ -11,6 +11,7 @@ from pytensor.tensor.random.basic import NormalRV
from pytensor.tensor.random.op import RandomVariable, default_rng from pytensor.tensor.random.op import RandomVariable, default_rng
from pytensor.tensor.shape import specify_shape from pytensor.tensor.shape import specify_shape
from pytensor.tensor.type import iscalar, tensor from pytensor.tensor.type import iscalar, tensor
from pytensor.tensor.type_other import none_type_t
@pytest.fixture(scope="function", autouse=False) @pytest.fixture(scope="function", autouse=False)
...@@ -317,3 +318,12 @@ def test_size_none_vs_empty(): ...@@ -317,3 +318,12 @@ def test_size_none_vs_empty():
ValueError, match="Size length is incompatible with batched dimensions" ValueError, match="Size length is incompatible with batched dimensions"
): ):
rv([0], [1], size=()) rv([0], [1], size=())
def test_non_constant_none_size():
    """A symbolic (non-constant) ``NoneTypeT`` size must be accepted by ``normal``.

    Regression test for
    https://github.com/pymc-devs/pymc/issues/7901#issuecomment-3528479876
    """
    loc = pt.vector("loc", dtype="float64")
    # A free variable of the none type, as opposed to the ``NoneConst`` constant.
    size = none_type_t("none_size")
    rv = normal(loc, size=size)
    draws = rv.eval(
        {loc: np.arange(5, dtype="float64"), size: None}, mode="FAST_COMPILE"
    )
    # With no explicit size, the draw shape is dictated by the parameters alone.
    assert draws.shape == (5,)
...@@ -7,9 +7,11 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery ...@@ -7,9 +7,11 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor.random.utils import ( from pytensor.tensor.random.utils import (
RandomStream, RandomStream,
broadcast_params, broadcast_params,
normalize_size_param,
supp_shape_from_ref_param_shape, supp_shape_from_ref_param_shape,
) )
from pytensor.tensor.type import matrix, tensor from pytensor.tensor.type import TensorType, matrix, tensor
from pytensor.tensor.type_other import NoneTypeT, none_type_t
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -327,3 +329,22 @@ def test_supp_shape_from_ref_param_shape(): ...@@ -327,3 +329,22 @@ def test_supp_shape_from_ref_param_shape():
ref_param_idx=1, ref_param_idx=1,
) )
assert res == (3, 4) assert res == (3, 4)
def test_normalize_size_param():
    """Check the type (or identity) of what ``normalize_size_param`` returns."""
    # A Python ``None`` yields a variable whose type is ``NoneTypeT``.
    assert normalize_size_param(None).type == NoneTypeT()

    # A symbolic none-typed variable is handed back as the very same object.
    symbolic_none = none_type_t()
    assert normalize_size_param(symbolic_none) is symbolic_none

    # Python-level sizes become int64 vectors with a known static length.
    for raw_size, expected_length in (((), 0), (5, 1), ((5, 3, 4), 3)):
        normalized = normalize_size_param(raw_size)
        assert normalized.type == TensorType(dtype="int64", shape=(expected_length,))

    # An int64 vector that already has a static length is returned untouched.
    static_vector = tensor(shape=(3,), dtype="int64")
    assert normalize_size_param(static_vector) is static_vector
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论