提交 3e9c6a4f 作者: Ricardo Vieira 提交者: Ricardo Vieira

Introduce signature instead of ndim_supp and ndims_params

上级 a576fa2c
import warnings
from collections.abc import Sequence from collections.abc import Sequence
from copy import copy from copy import copy
from typing import cast from typing import cast
...@@ -28,6 +29,7 @@ from pytensor.tensor.random.utils import ( ...@@ -28,6 +29,7 @@ from pytensor.tensor.random.utils import (
from pytensor.tensor.shape import shape_tuple from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.type import TensorType, all_dtypes from pytensor.tensor.type import TensorType, all_dtypes
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -42,7 +44,7 @@ class RandomVariable(Op): ...@@ -42,7 +44,7 @@ class RandomVariable(Op):
_output_type_depends_on_input_value = True _output_type_depends_on_input_value = True
__props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace") __props__ = ("name", "signature", "dtype", "inplace")
default_output = 1 default_output = 1
def __init__( def __init__(
...@@ -50,8 +52,9 @@ class RandomVariable(Op): ...@@ -50,8 +52,9 @@ class RandomVariable(Op):
name=None, name=None,
ndim_supp=None, ndim_supp=None,
ndims_params=None, ndims_params=None,
dtype=None, dtype: str | None = None,
inplace=None, inplace=None,
signature: str | None = None,
): ):
"""Create a random variable `Op`. """Create a random variable `Op`.
...@@ -59,44 +62,63 @@ class RandomVariable(Op): ...@@ -59,44 +62,63 @@ class RandomVariable(Op):
---------- ----------
name: str name: str
The `Op`'s display name. The `Op`'s display name.
ndim_supp: int signature: str
Total number of dimensions for a single draw of the random variable Numpy-like vectorized signature of the random variable.
(e.g. a multivariate normal draw is 1D, so ``ndim_supp = 1``).
ndims_params: list of int
Number of dimensions for each distribution parameter when the
parameters only specify a single drawn of the random variable
(e.g. a multivariate normal's mean is 1D and covariance is 2D, so
``ndims_params = [1, 2]``).
dtype: str (optional) dtype: str (optional)
The dtype of the sampled output. If the value ``"floatX"`` is The dtype of the sampled output. If the value ``"floatX"`` is
given, then ``dtype`` is set to ``pytensor.config.floatX``. If given, then ``dtype`` is set to ``pytensor.config.floatX``. If
``None`` (the default), the `dtype` keyword must be set when ``None`` (the default), the `dtype` keyword must be set when
`RandomVariable.make_node` is called. `RandomVariable.make_node` is called.
inplace: boolean (optional) inplace: boolean (optional)
Determine whether or not the underlying rng state is updated Determine whether the underlying rng state is mutated or copied.
in-place or not (i.e. copied).
""" """
super().__init__() super().__init__()
self.name = name or getattr(self, "name") self.name = name or getattr(self, "name")
self.ndim_supp = (
ndim_supp if ndim_supp is not None else getattr(self, "ndim_supp") ndim_supp = (
ndim_supp if ndim_supp is not None else getattr(self, "ndim_supp", None)
) )
self.ndims_params = ( if ndim_supp is not None:
ndims_params if ndims_params is not None else getattr(self, "ndims_params") warnings.warn(
"ndim_supp is deprecated. Provide signature instead.", FutureWarning
)
self.ndim_supp = ndim_supp
ndims_params = (
ndims_params
if ndims_params is not None
else getattr(self, "ndims_params", None)
) )
if ndims_params is not None:
warnings.warn(
"ndims_params is deprecated. Provide signature instead.", FutureWarning
)
if not isinstance(ndims_params, Sequence):
raise TypeError("Parameter ndims_params must be sequence type.")
self.ndims_params = tuple(ndims_params)
self.signature = signature or getattr(self, "signature", None)
if self.signature is not None:
# Assume a single output. Several methods need to be updated to handle multiple outputs.
self.inputs_sig, [self.output_sig] = _parse_gufunc_signature(self.signature)
self.ndims_params = [len(input_sig) for input_sig in self.inputs_sig]
self.ndim_supp = len(self.output_sig)
else:
if (
getattr(self, "ndim_supp", None) is None
or getattr(self, "ndims_params", None) is None
):
raise ValueError("signature must be provided")
else:
self.signature = safe_signature(self.ndims_params, [self.ndim_supp])
self.dtype = dtype or getattr(self, "dtype", None) self.dtype = dtype or getattr(self, "dtype", None)
self.inplace = ( self.inplace = (
inplace if inplace is not None else getattr(self, "inplace", False) inplace if inplace is not None else getattr(self, "inplace", False)
) )
if not isinstance(self.ndims_params, Sequence):
raise TypeError("Parameter ndims_params must be sequence type.")
self.ndims_params = tuple(self.ndims_params)
if self.inplace: if self.inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
...@@ -120,8 +142,31 @@ class RandomVariable(Op): ...@@ -120,8 +142,31 @@ class RandomVariable(Op):
values (not shapes) of some parameters. For instance, a `gaussian_random_walk(steps, size=(2,))`, values (not shapes) of some parameters. For instance, a `gaussian_random_walk(steps, size=(2,))`,
might have `support_shape=(steps,)`. might have `support_shape=(steps,)`.
""" """
if self.signature is not None:
# Signature could indicate fixed numerical shapes
# As per https://numpy.org/neps/nep-0020-gufunc-signature-enhancement.html
output_sig = self.output_sig
core_out_shape = {
dim: int(dim) if str.isnumeric(dim) else None for dim in self.output_sig
}
# Try to infer missing support dims from signature of params
for param, param_sig, ndim_params in zip(
dist_params, self.inputs_sig, self.ndims_params
):
if ndim_params == 0:
continue
for param_dim, dim in zip(param.shape[-ndim_params:], param_sig):
if dim in core_out_shape and core_out_shape[dim] is None:
core_out_shape[dim] = param_dim
if all(dim is not None for dim in core_out_shape.values()):
# We have all we need
return [core_out_shape[dim] for dim in output_sig]
raise NotImplementedError( raise NotImplementedError(
"`_supp_shape_from_params` must be implemented for multivariate RVs" "`_supp_shape_from_params` must be implemented for multivariate RVs "
"when signature is not sufficient to infer the support shape"
) )
def rng_fn(self, rng, *args, **kwargs) -> int | float | np.ndarray: def rng_fn(self, rng, *args, **kwargs) -> int | float | np.ndarray:
...@@ -129,7 +174,24 @@ class RandomVariable(Op): ...@@ -129,7 +174,24 @@ class RandomVariable(Op):
return getattr(rng, self.name)(*args, **kwargs) return getattr(rng, self.name)(*args, **kwargs)
def __str__(self): def __str__(self):
props_str = ", ".join(f"{getattr(self, prop)}" for prop in self.__props__[1:]) # Only show signature from core props
if signature := self.signature:
# inp, out = signature.split("->")
# extended_signature = f"[rng],[size],{inp}->[rng],{out}"
# core_props = [extended_signature]
core_props = [f'"{signature}"']
else:
# Far back compat
core_props = [str(self.ndim_supp), str(self.ndims_params)]
# Add any extra props that the subclass may have
extra_props = [
str(getattr(self, prop))
for prop in self.__props__
if prop not in RandomVariable.__props__
]
props_str = ", ".join(core_props + extra_props)
return f"{self.name}_rv{{{props_str}}}" return f"{self.name}_rv{{{props_str}}}"
def _infer_shape( def _infer_shape(
...@@ -298,11 +360,11 @@ class RandomVariable(Op): ...@@ -298,11 +360,11 @@ class RandomVariable(Op):
dtype_idx = constant(all_dtypes.index(dtype), dtype="int64") dtype_idx = constant(all_dtypes.index(dtype), dtype="int64")
else: else:
dtype_idx = constant(dtype, dtype="int64") dtype_idx = constant(dtype, dtype="int64")
dtype = all_dtypes[dtype_idx.data]
outtype = TensorType(dtype=dtype, shape=static_shape) dtype = all_dtypes[dtype_idx.data]
out_var = outtype()
inputs = (rng, size, dtype_idx, *dist_params) inputs = (rng, size, dtype_idx, *dist_params)
out_var = TensorType(dtype=dtype, shape=static_shape)()
outputs = (rng.type(), out_var) outputs = (rng.type(), out_var)
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
...@@ -395,9 +457,8 @@ def vectorize_random_variable( ...@@ -395,9 +457,8 @@ def vectorize_random_variable(
# We extend it to accommodate the new input batch dimensions. # We extend it to accommodate the new input batch dimensions.
# Otherwise, we assume the new size already has the right values # Otherwise, we assume the new size already has the right values
# Need to make parameters implicit broadcasting explicit original_dist_params = op.dist_params(node)
original_dist_params = node.inputs[3:] old_size = op.size_param(node)
old_size = node.inputs[1]
len_old_size = get_vector_length(old_size) len_old_size = get_vector_length(old_size)
original_expanded_dist_params = explicit_expand_dims( original_expanded_dist_params = explicit_expand_dims(
......
import re
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import in2out, node_rewriter from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.db import SequenceDB
...@@ -164,9 +166,9 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node): ...@@ -164,9 +166,9 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
a_vector_param = arange(a_scalar_param) a_vector_param = arange(a_scalar_param)
new_props_dict = op._props_dict().copy() new_props_dict = op._props_dict().copy()
new_ndims_params = list(op.ndims_params) # Signature changes from something like "(),(a),(2)->(s0, s1)" to "(a),(a),(2)->(s0, s1)"
new_ndims_params[0] += 1 # I.e., we substitute the first `()` by `(a)`
new_props_dict["ndims_params"] = new_ndims_params new_props_dict["signature"] = re.sub(r"\(\)", "(a)", op.signature, 1)
new_op = type(op)(**new_props_dict) new_op = type(op)(**new_props_dict)
return new_op.make_node(rng, size, dtype, a_vector_param, *other_params).outputs return new_op.make_node(rng, size, dtype, a_vector_param, *other_params).outputs
......
...@@ -123,7 +123,7 @@ def broadcast_params(params, ndims_params): ...@@ -123,7 +123,7 @@ def broadcast_params(params, ndims_params):
def explicit_expand_dims( def explicit_expand_dims(
params: Sequence[TensorVariable], params: Sequence[TensorVariable],
ndim_params: tuple[int], ndim_params: Sequence[int],
size_length: int = 0, size_length: int = 0,
) -> list[TensorVariable]: ) -> list[TensorVariable]:
"""Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size.""" """Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
...@@ -137,7 +137,7 @@ def explicit_expand_dims( ...@@ -137,7 +137,7 @@ def explicit_expand_dims(
# See: https://github.com/pymc-devs/pytensor/issues/568 # See: https://github.com/pymc-devs/pytensor/issues/568
max_batch_dims = size_length max_batch_dims = size_length
else: else:
max_batch_dims = max(batch_dims) max_batch_dims = max(batch_dims, default=0)
new_params = [] new_params = []
for new_param, batch_dim in zip(params, batch_dims): for new_param, batch_dim in zip(params, batch_dims):
...@@ -354,6 +354,11 @@ def supp_shape_from_ref_param_shape( ...@@ -354,6 +354,11 @@ def supp_shape_from_ref_param_shape(
out: tuple out: tuple
Representing the support shape for a `RandomVariable` with the given `dist_params`. Representing the support shape for a `RandomVariable` with the given `dist_params`.
Notes
_____
This helper is no longer necessary when using signatures in `RandomVariable` subclasses.
""" """
if ndim_supp <= 0: if ndim_supp <= 0:
raise ValueError("ndim_supp must be greater than 0") raise ValueError("ndim_supp must be greater than 0")
......
...@@ -169,7 +169,8 @@ _DIMENSION_NAME = r"\w+" ...@@ -169,7 +169,8 @@ _DIMENSION_NAME = r"\w+"
_CORE_DIMENSION_LIST = f"(?:{_DIMENSION_NAME}(?:,{_DIMENSION_NAME})*)?" _CORE_DIMENSION_LIST = f"(?:{_DIMENSION_NAME}(?:,{_DIMENSION_NAME})*)?"
_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)" _ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)"
_ARGUMENT_LIST = f"{_ARGUMENT}(?:,{_ARGUMENT})*" _ARGUMENT_LIST = f"{_ARGUMENT}(?:,{_ARGUMENT})*"
_SIGNATURE = f"^{_ARGUMENT_LIST}->{_ARGUMENT_LIST}$" # Allow no inputs
_SIGNATURE = f"^(?:{_ARGUMENT_LIST})?->{_ARGUMENT_LIST}$"
def _parse_gufunc_signature( def _parse_gufunc_signature(
...@@ -200,6 +201,8 @@ def _parse_gufunc_signature( ...@@ -200,6 +201,8 @@ def _parse_gufunc_signature(
tuple(re.findall(_DIMENSION_NAME, arg)) tuple(re.findall(_DIMENSION_NAME, arg))
for arg in re.findall(_ARGUMENT, arg_list) for arg in re.findall(_ARGUMENT, arg_list)
] ]
if arg_list # ignore no inputs
else []
for arg_list in signature.split("->") for arg_list in signature.split("->")
) )
......
...@@ -771,8 +771,7 @@ def test_random_unimplemented(): ...@@ -771,8 +771,7 @@ def test_random_unimplemented():
class NonExistentRV(RandomVariable): class NonExistentRV(RandomVariable):
name = "non-existent" name = "non-existent"
ndim_supp = 0 signature = "->()"
ndims_params = []
dtype = "floatX" dtype = "floatX"
def __call__(self, size=None, **kwargs): def __call__(self, size=None, **kwargs):
...@@ -798,8 +797,7 @@ def test_random_custom_implementation(): ...@@ -798,8 +797,7 @@ def test_random_custom_implementation():
class CustomRV(RandomVariable): class CustomRV(RandomVariable):
name = "non-existent" name = "non-existent"
ndim_supp = 0 signature = "->()"
ndims_params = []
dtype = "floatX" dtype = "floatX"
def __call__(self, size=None, **kwargs): def __call__(self, size=None, **kwargs):
......
...@@ -74,52 +74,28 @@ def apply_local_rewrite_to_rv( ...@@ -74,52 +74,28 @@ def apply_local_rewrite_to_rv(
return new_out, f_inputs, dist_st, f_rewritten return new_out, f_inputs, dist_st, f_rewritten
def test_inplace_rewrites(): class TestRVExpraProps(RandomVariable):
out = normal(0, 1) name = "test"
out.owner.inputs[0].default_update = out.owner.outputs[0] signature = "()->()"
__props__ = ("name", "signature", "dtype", "inplace", "extra")
dtype = "floatX"
_print_name = ("TestExtraProps", "\\operatorname{TestExtra_props}")
assert out.owner.op.inplace is False def __init__(self, extra, *args, **kwargs):
self.extra = extra
super().__init__(*args, **kwargs)
f = function( def rng_fn(self, rng, dtype, sigma, size):
[], return rng.normal(scale=sigma, size=size)
out,
mode="FAST_RUN",
)
(new_out, new_rng) = f.maker.fgraph.outputs
assert new_out.type == out.type
assert isinstance(new_out.owner.op, type(out.owner.op))
assert new_out.owner.op.inplace is True
assert all(
np.array_equal(a.data, b.data)
for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])
)
assert np.array_equal(new_out.owner.inputs[1].data, [])
def test_inplace_rewrites_extra_props():
class Test(RandomVariable):
name = "test"
ndim_supp = 0
ndims_params = [0]
__props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace", "extra")
dtype = "floatX"
_print_name = ("Test", "\\operatorname{Test}")
def __init__(self, extra, *args, **kwargs):
self.extra = extra
super().__init__(*args, **kwargs)
def make_node(self, rng, size, dtype, sigma):
return super().make_node(rng, size, dtype, sigma)
def rng_fn(self, rng, sigma, size):
return rng.normal(scale=sigma, size=size)
out = Test(extra="some value")(1)
out.owner.inputs[0].default_update = out.owner.outputs[0]
assert out.owner.op.inplace is False @pytest.mark.parametrize("rv_op", [normal, TestRVExpraProps(extra="some value")])
def test_inplace_rewrites(rv_op):
out = rv_op(np.e)
node = out.owner
op = node.op
node.inputs[0].default_update = node.outputs[0]
assert op.inplace is False
f = function( f = function(
[], [],
...@@ -129,9 +105,10 @@ def test_inplace_rewrites_extra_props(): ...@@ -129,9 +105,10 @@ def test_inplace_rewrites_extra_props():
(new_out, new_rng) = f.maker.fgraph.outputs (new_out, new_rng) = f.maker.fgraph.outputs
assert new_out.type == out.type assert new_out.type == out.type
assert isinstance(new_out.owner.op, type(out.owner.op)) new_node = new_out.owner
assert new_out.owner.op.inplace is True new_op = new_node.op
assert new_out.owner.op.extra == out.owner.op.extra assert isinstance(new_op, type(op))
assert new_op._props_dict() == (op._props_dict() | {"inplace": True})
assert all( assert all(
np.array_equal(a.data, b.data) np.array_equal(a.data, b.data)
for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:]) for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])
......
...@@ -1463,11 +1463,8 @@ def batched_unweighted_choice_without_replacement_tester( ...@@ -1463,11 +1463,8 @@ def batched_unweighted_choice_without_replacement_tester(
rng = shared(rng_ctor()) rng = shared(rng_ctor())
# Batched a implicit size # Batched a implicit size
a_core_ndim = 2
core_shape_len = 1
rv_op = ChoiceWithoutReplacement( rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len, signature="(a0,a1),(1)->(s0,a1)",
ndims_params=[a_core_ndim, core_shape_len],
dtype="int64", dtype="int64",
) )
...@@ -1483,11 +1480,8 @@ def batched_unweighted_choice_without_replacement_tester( ...@@ -1483,11 +1480,8 @@ def batched_unweighted_choice_without_replacement_tester(
assert np.all((draw >= i * 10) & (draw < (i + 1) * 10)) assert np.all((draw >= i * 10) & (draw < (i + 1) * 10))
# Explicit size broadcasts beyond a # Explicit size broadcasts beyond a
a_core_ndim = 2
core_shape_len = 2
rv_op = ChoiceWithoutReplacement( rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len, signature="(a0,a1),(2)->(s0,s1,a1)",
ndims_params=[a_core_ndim, len(core_shape)],
dtype="int64", dtype="int64",
) )
...@@ -1515,12 +1509,8 @@ def batched_weighted_choice_without_replacement_tester( ...@@ -1515,12 +1509,8 @@ def batched_weighted_choice_without_replacement_tester(
""" """
rng = shared(rng_ctor()) rng = shared(rng_ctor())
# 3 ndims params indicates p is passed
a_core_ndim = 2
core_shape_len = 1
rv_op = ChoiceWithoutReplacement( rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len, signature="(a0,a1),(a0),(1)->(s0,a1)",
ndims_params=[a_core_ndim, 1, 1],
dtype="int64", dtype="int64",
) )
...@@ -1540,11 +1530,8 @@ def batched_weighted_choice_without_replacement_tester( ...@@ -1540,11 +1530,8 @@ def batched_weighted_choice_without_replacement_tester(
# p and a are batched # p and a are batched
# Test implicit arange # Test implicit arange
a_core_ndim = 0
core_shape_len = 2
rv_op = ChoiceWithoutReplacement( rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len, signature="(),(a),(2)->(s0,s1)",
ndims_params=[a_core_ndim, 1, 1],
dtype="int64", dtype="int64",
) )
a = 6 a = 6
...@@ -1566,11 +1553,8 @@ def batched_weighted_choice_without_replacement_tester( ...@@ -1566,11 +1553,8 @@ def batched_weighted_choice_without_replacement_tester(
assert set(draw) == set(range(i, 6, 2)) assert set(draw) == set(range(i, 6, 2))
# Size broadcasts beyond a # Size broadcasts beyond a
a_core_ndim = 2
core_shape_len = 1
rv_op = ChoiceWithoutReplacement( rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len, signature="(a0,a1),(a0),(1)->(s0,a1)",
ndims_params=[a_core_ndim, 1, 1],
dtype="int64", dtype="int64",
) )
a = np.arange(4 * 5 * 2).reshape((4, 5, 2)) a = np.arange(4 * 5 * 2).reshape((4, 5, 2))
......
...@@ -23,14 +23,13 @@ def test_RandomVariable_basics(strict_test_value_flags): ...@@ -23,14 +23,13 @@ def test_RandomVariable_basics(strict_test_value_flags):
str_res = str( str_res = str(
RandomVariable( RandomVariable(
"normal", "normal",
0, signature="(),()->()",
[0, 0], dtype="float32",
"float32", inplace=False,
inplace=True,
) )
) )
assert str_res == "normal_rv{0, (0, 0), float32, True}" assert str_res == 'normal_rv{"(),()->()"}'
# `ndims_params` should be a `Sequence` type # `ndims_params` should be a `Sequence` type
with pytest.raises(TypeError, match="^Parameter ndims_params*"): with pytest.raises(TypeError, match="^Parameter ndims_params*"):
...@@ -64,9 +63,7 @@ def test_RandomVariable_basics(strict_test_value_flags): ...@@ -64,9 +63,7 @@ def test_RandomVariable_basics(strict_test_value_flags):
# Confirm that `inplace` works # Confirm that `inplace` works
rv = RandomVariable( rv = RandomVariable(
"normal", "normal",
0, signature="(),()->()",
[0, 0],
"normal",
inplace=True, inplace=True,
) )
...@@ -74,7 +71,7 @@ def test_RandomVariable_basics(strict_test_value_flags): ...@@ -74,7 +71,7 @@ def test_RandomVariable_basics(strict_test_value_flags):
assert rv.destroy_map == {0: [0]} assert rv.destroy_map == {0: [0]}
# A no-params `RandomVariable` # A no-params `RandomVariable`
rv = RandomVariable(name="test_rv", ndim_supp=0, ndims_params=()) rv = RandomVariable(name="test_rv", signature="->()")
with pytest.raises(TypeError): with pytest.raises(TypeError):
rv.make_node(rng=1) rv.make_node(rng=1)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请先注册或者登录后发表评论