Commit cfa76f5d authored by Ricardo Vieira, committed by Ricardo Vieira

Handle non-constant NoneTypeT variables

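In short: `NoneConst.equals(x)` only matches the `NoneConst` singleton, so a non-constant variable whose type is `NoneTypeT` (e.g. a named symbolic `size` input) slipped past these checks. Replacing the equality test with `isinstance(x.type, NoneTypeT)` covers both cases. A minimal sketch of the previously failing case, adapted from the regression test added below (the variable name `none_size` is just illustrative):

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.random.basic import normal
from pytensor.tensor.type_other import none_type_t

# A symbolic, non-constant None: `NoneConst.equals(size)` is False for it,
# so code paths that only recognized the constant mishandled this graph.
loc = pt.vector("loc", dtype="float64")
size = none_type_t("none_size")
rv = normal(loc, size=size)

# After this change, `None` can be supplied for `size` at evaluation time.
rv.eval({loc: np.arange(5, dtype="float64"), size: None}, mode="FAST_COMPILE")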
Parent 4e4f237a
@@ -385,7 +385,9 @@ class RandomVariable(RNGConsumerOp):
         dist_params = explicit_expand_dims(
             dist_params,
             self.ndims_params,
-            size_length=None if NoneConst.equals(size) else get_vector_length(size),
+            size_length=None
+            if isinstance(size.type, NoneTypeT)
+            else get_vector_length(size),
         )

         inputs = (rng, size, *dist_params)
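The gist of the new check, as a hedged illustration (`sym_none` is a made-up name; behavior follows from the types used in this commit):

from pytensor.tensor import NoneConst
from pytensor.tensor.type_other import NoneTypeT, none_type_t

sym_none = none_type_t("sym_none")  # non-constant variable of NoneTypeT

assert not NoneConst.equals(sym_none)         # old check: misses symbolic None
assert isinstance(sym_none.type, NoneTypeT)   # new check: catches it
assert isinstance(NoneConst.type, NoneTypeT)  # and still covers the constant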
@@ -9,7 +9,7 @@ from pytensor.graph.rewriting.basic import (
     dfs_rewriter,
     node_rewriter,
 )
-from pytensor.tensor import NoneConst, TensorVariable
+from pytensor.tensor import TensorVariable
 from pytensor.tensor.basic import constant
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.extra_ops import broadcast_to
@@ -20,7 +20,7 @@ from pytensor.tensor.subtensor import (
     AdvancedSubtensor,
     AdvancedSubtensor1,
     Subtensor,
-    get_idx_list,
+    indices_from_subtensor,
 )
 from pytensor.tensor.type import integer_dtypes
 from pytensor.tensor.type_other import NoneTypeT, SliceType
@@ -237,17 +237,20 @@ def local_subtensor_rv_lift(fgraph, node):
         return False

     # Parse indices
-    indices = get_idx_list(node.inputs, getattr(subtensor_op, "idx_list", None))
-    # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
-    # Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
-    # If we wanted to support that we could rewrite it as subtensor + dimshuffle
-    # and make use of the dimshuffle lift rewrite
-    if any(
-        is_nd_advanced_idx(idx, integer_dtypes) or NoneConst.equals(idx)
-        for idx in indices
-    ):
-        return False
+    if isinstance(subtensor_op, Subtensor):
+        indices = indices_from_subtensor(node.inputs[1:], subtensor_op.idx_list)
+    else:
+        indices = node.inputs[1:]
+    # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
+    # Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
+    # If we wanted to support that we could rewrite it as subtensor + dimshuffle
+    # and make use of the dimshuffle lift rewrite
+    # TODO: This rewrite aborts on dummy indexing dimensions, which aren't actually a problem
+    if any(
+        is_nd_advanced_idx(idx, integer_dtypes) or isinstance(idx.type, NoneTypeT)
+        for idx in indices
+    ):
+        return False

     # Check that indexing does not act on support dims
     batch_ndims = rv_op.batch_ndim(rv_node)
@@ -267,7 +270,7 @@ def local_subtensor_rv_lift(fgraph, node):
     for idx in supp_indices:
         if not (
             isinstance(idx.type, SliceType)
-            and all(NoneConst.equals(i) for i in idx.owner.inputs)
+            and all(isinstance(i.type, NoneTypeT) for i in idx.owner.inputs)
         ):
             return False
     n_discarded_idxs = len(supp_indices)
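For context, `indices_from_subtensor` rebuilds the symbolic index entries of a `Subtensor` from the node's scalar inputs and the op's static `idx_list` template, which is what the old `get_idx_list` call did from the full input list. A rough usage sketch (assumes standard PyTensor behavior; not part of this commit):

import pytensor.tensor as pt
from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor

x = pt.matrix("x")
node = x[1:3, 0].owner
assert isinstance(node.op, Subtensor)

# Recovers entries equivalent to (slice(1, 3, None), 0), with the scalar
# constants drawn from node.inputs[1:], as the rewrite above now does.
indices = indices_from_subtensor(node.inputs[1:], node.op.idx_list)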
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING
 import numpy as np

 from pytensor.compile.sharedvalue import shared
-from pytensor.graph.basic import Constant, Variable
+from pytensor.graph.basic import Variable
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor import NoneConst, get_vector_length
 from pytensor.tensor.basic import as_tensor_variable, cast
@@ -15,6 +15,7 @@ from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to
 from pytensor.tensor.math import maximum
 from pytensor.tensor.shape import shape_padleft, specify_shape
 from pytensor.tensor.type import int_dtypes
+from pytensor.tensor.type_other import NoneTypeT
 from pytensor.tensor.utils import faster_broadcast_to
 from pytensor.tensor.variable import TensorVariable
@@ -178,24 +179,26 @@ def normalize_size_param(
     shape: int | np.ndarray | Variable | Sequence | None,
 ) -> Variable:
     """Create a PyTensor value for a ``RandomVariable`` ``size`` parameter."""
-    if shape is None or NoneConst.equals(shape):
+    if shape is None:
         return NoneConst
-    elif isinstance(shape, int):
+    if isinstance(shape, Variable) and isinstance(shape.type, NoneTypeT):
+        return shape
+    if isinstance(shape, int):
         shape = as_tensor_variable([shape], ndim=1)
-    elif not isinstance(shape, np.ndarray | Variable | Sequence):
-        raise TypeError(
-            "Parameter size must be None, an integer, or a sequence with integers."
-        )
     else:
+        if not isinstance(shape, Sequence | Variable | np.ndarray):
+            raise TypeError(
+                "Parameter size must be None, an integer, or a sequence with integers."
+            )
         shape = cast(as_tensor_variable(shape, ndim=1, dtype="int64"), "int64")

-    if not isinstance(shape, Constant):
+    if shape.type.shape == (None,):
         # This should help ensure that the length of non-constant `size`s
-        # will be available after certain types of cloning (e.g. the kind
-        # `Scan` performs)
+        # will be available after certain types of cloning (e.g. the kind `Scan` performs)
         shape = specify_shape(shape, (get_vector_length(shape),))

-    assert not any(s is None for s in shape.type.shape)
+    assert shape.type.shape != (None,)
     assert shape.dtype in int_dtypes

     return shape
@@ -47,7 +47,7 @@ from pytensor.tensor.shape import (
 )
 from pytensor.tensor.subtensor import Subtensor, get_idx_list
 from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
-from pytensor.tensor.type_other import NoneConst, NoneTypeT
+from pytensor.tensor.type_other import NoneTypeT
 from pytensor.tensor.variable import TensorVariable
@@ -1137,7 +1137,7 @@ def local_merge_consecutive_specify_shape(fgraph, node):
     inner_obj, *shape = obj.owner.inputs
     for dim, sh in enumerate(node.inputs[1:]):
-        if not NoneConst.equals(sh):
+        if not isinstance(sh.type, NoneTypeT):
             shape[dim] = sh

     # TODO: We could make sure that the overlapping shapes of the two `SpecifyShape`s are
@@ -1183,7 +1183,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):
     # Replace `NoneConst` by `shape_i`
     for i, sh in enumerate(shape):
-        if NoneConst.equals(sh):
+        if isinstance(sh.type, NoneTypeT):
             shape[i] = x.shape[i]

     return [stack(shape).astype(np.int64)]
@@ -1219,7 +1219,7 @@ def local_specify_shape_lift(fgraph, node):
         for i, (dim, bcast) in enumerate(
             zip(shape, out_broadcastable, strict=True)
         )
-        if (not bcast and not NoneConst.equals(dim))
+        if (not bcast and not isinstance(dim.type, NoneTypeT))
     }

     new_elem_inps = elem_inps.copy()
     for i, elem_inp in enumerate(elem_inps):
@@ -408,7 +408,9 @@ class SpecifyShape(COp):
         shape = tuple(
             NoneConst
-            if (s is None or NoneConst.equals(s))
+            if (
+                s is None or (isinstance(s, Variable) and isinstance(s.type, NoneTypeT))
+            )
             else ptb.as_tensor_variable(s, ndim=0)
             for s in shape
         )
@@ -506,7 +508,7 @@ class SpecifyShape(COp):
         for i, (shp_name, shp) in enumerate(
             zip(shape_names, node.inputs[1:], strict=True)
         ):
-            if NoneConst.equals(shp):
+            if isinstance(shp.type, NoneTypeT):
                 continue
             code += dedent(
                 f"""
@@ -594,7 +596,10 @@ def _vectorize_specify_shape(op, node, x, *shape):
     if any(
         as_tensor_variable(dim).type.ndim != 0
         for dim in shape
-        if not (NoneConst.equals(dim) or dim is None)
+        if not (
+            (isinstance(dim, Variable) and isinstance(dim.type, NoneTypeT))
+            or dim is None
+        )
     ):
         raise NotImplementedError(
             "It is not possible to vectorize the shape argument of SpecifyShape"
@@ -11,6 +11,7 @@ from pytensor.tensor.random.basic import NormalRV
 from pytensor.tensor.random.op import RandomVariable, default_rng
 from pytensor.tensor.shape import specify_shape
 from pytensor.tensor.type import iscalar, tensor
+from pytensor.tensor.type_other import none_type_t


 @pytest.fixture(scope="function", autouse=False)
@@ -317,3 +318,12 @@ def test_size_none_vs_empty():
         ValueError, match="Size length is incompatible with batched dimensions"
     ):
         rv([0], [1], size=())
+
+
+def test_non_constant_none_size():
+    # Regression test for https://github.com/pymc-devs/pymc/issues/7901#issuecomment-3528479876
+    loc = pt.vector("loc", dtype="float64")
+    size = none_type_t("none_size")
+    rv = normal(loc, size=size)
+
+    rv.eval({loc: np.arange(5, dtype="float64"), size: None}, mode="FAST_COMPILE")
@@ -7,9 +7,11 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.tensor.random.utils import (
     RandomStream,
     broadcast_params,
+    normalize_size_param,
     supp_shape_from_ref_param_shape,
 )
-from pytensor.tensor.type import matrix, tensor
+from pytensor.tensor.type import TensorType, matrix, tensor
+from pytensor.tensor.type_other import NoneTypeT, none_type_t
 from tests import unittest_tools as utt
@@ -327,3 +329,22 @@ def test_supp_shape_from_ref_param_shape():
         ref_param_idx=1,
     )
     assert res == (3, 4)
+
+
+def test_normalize_size_param():
+    assert normalize_size_param(None).type == NoneTypeT()
+
+    sym_none_size = none_type_t()
+    assert normalize_size_param(sym_none_size) is sym_none_size
+
+    empty_size = normalize_size_param(())
+    assert empty_size.type == TensorType(dtype="int64", shape=(0,))
+
+    int_size = normalize_size_param(5)
+    assert int_size.type == TensorType(dtype="int64", shape=(1,))
+
+    seq_int_size = normalize_size_param((5, 3, 4))
+    assert seq_int_size.type == TensorType(dtype="int64", shape=(3,))
+
+    sym_tensor_size = tensor(shape=(3,), dtype="int64")
+    assert normalize_size_param(sym_tensor_size) is sym_tensor_size