提交 3e9c6a4f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Introduce signature instead of ndim_supp and ndims_params

上级 a576fa2c
import warnings
from collections.abc import Sequence
from copy import copy
from typing import cast
......@@ -28,6 +29,7 @@ from pytensor.tensor.random.utils import (
from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.type import TensorType, all_dtypes
from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
from pytensor.tensor.variable import TensorVariable
......@@ -42,7 +44,7 @@ class RandomVariable(Op):
_output_type_depends_on_input_value = True
__props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace")
__props__ = ("name", "signature", "dtype", "inplace")
default_output = 1
def __init__(
......@@ -50,8 +52,9 @@ class RandomVariable(Op):
name=None,
ndim_supp=None,
ndims_params=None,
dtype=None,
dtype: str | None = None,
inplace=None,
signature: str | None = None,
):
"""Create a random variable `Op`.
......@@ -59,44 +62,63 @@ class RandomVariable(Op):
----------
name: str
The `Op`'s display name.
ndim_supp: int
Total number of dimensions for a single draw of the random variable
(e.g. a multivariate normal draw is 1D, so ``ndim_supp = 1``).
ndims_params: list of int
Number of dimensions for each distribution parameter when the
parameters only specify a single drawn of the random variable
(e.g. a multivariate normal's mean is 1D and covariance is 2D, so
``ndims_params = [1, 2]``).
signature: str
Numpy-like vectorized signature of the random variable.
dtype: str (optional)
The dtype of the sampled output. If the value ``"floatX"`` is
given, then ``dtype`` is set to ``pytensor.config.floatX``. If
``None`` (the default), the `dtype` keyword must be set when
`RandomVariable.make_node` is called.
inplace: boolean (optional)
Determine whether or not the underlying rng state is updated
in-place or not (i.e. copied).
Determine whether the underlying rng state is mutated or copied.
"""
super().__init__()
self.name = name or getattr(self, "name")
self.ndim_supp = (
ndim_supp if ndim_supp is not None else getattr(self, "ndim_supp")
ndim_supp = (
ndim_supp if ndim_supp is not None else getattr(self, "ndim_supp", None)
)
self.ndims_params = (
ndims_params if ndims_params is not None else getattr(self, "ndims_params")
if ndim_supp is not None:
warnings.warn(
"ndim_supp is deprecated. Provide signature instead.", FutureWarning
)
self.ndim_supp = ndim_supp
ndims_params = (
ndims_params
if ndims_params is not None
else getattr(self, "ndims_params", None)
)
if ndims_params is not None:
warnings.warn(
"ndims_params is deprecated. Provide signature instead.", FutureWarning
)
if not isinstance(ndims_params, Sequence):
raise TypeError("Parameter ndims_params must be sequence type.")
self.ndims_params = tuple(ndims_params)
self.signature = signature or getattr(self, "signature", None)
if self.signature is not None:
# Assume a single output. Several methods need to be updated to handle multiple outputs.
self.inputs_sig, [self.output_sig] = _parse_gufunc_signature(self.signature)
self.ndims_params = [len(input_sig) for input_sig in self.inputs_sig]
self.ndim_supp = len(self.output_sig)
else:
if (
getattr(self, "ndim_supp", None) is None
or getattr(self, "ndims_params", None) is None
):
raise ValueError("signature must be provided")
else:
self.signature = safe_signature(self.ndims_params, [self.ndim_supp])
self.dtype = dtype or getattr(self, "dtype", None)
self.inplace = (
inplace if inplace is not None else getattr(self, "inplace", False)
)
if not isinstance(self.ndims_params, Sequence):
raise TypeError("Parameter ndims_params must be sequence type.")
self.ndims_params = tuple(self.ndims_params)
if self.inplace:
self.destroy_map = {0: [0]}
......@@ -120,8 +142,31 @@ class RandomVariable(Op):
values (not shapes) of some parameters. For instance, a `gaussian_random_walk(steps, size=(2,))`,
might have `support_shape=(steps,)`.
"""
if self.signature is not None:
# Signature could indicate fixed numerical shapes
# As per https://numpy.org/neps/nep-0020-gufunc-signature-enhancement.html
output_sig = self.output_sig
core_out_shape = {
dim: int(dim) if str.isnumeric(dim) else None for dim in self.output_sig
}
# Try to infer missing support dims from signature of params
for param, param_sig, ndim_params in zip(
dist_params, self.inputs_sig, self.ndims_params
):
if ndim_params == 0:
continue
for param_dim, dim in zip(param.shape[-ndim_params:], param_sig):
if dim in core_out_shape and core_out_shape[dim] is None:
core_out_shape[dim] = param_dim
if all(dim is not None for dim in core_out_shape.values()):
# We have all we need
return [core_out_shape[dim] for dim in output_sig]
raise NotImplementedError(
"`_supp_shape_from_params` must be implemented for multivariate RVs"
"`_supp_shape_from_params` must be implemented for multivariate RVs "
"when signature is not sufficient to infer the support shape"
)
def rng_fn(self, rng, *args, **kwargs) -> int | float | np.ndarray:
    """Sample a numeric random variate.

    Dispatches to the method of ``rng`` whose name matches this `Op`'s
    ``name`` attribute, forwarding all distribution parameters.
    """
    # NOTE(review): assumes `rng` exposes a sampler named `self.name`
    # (e.g. a numpy Generator) — confirm against subclass overrides.
    return getattr(rng, self.name)(*args, **kwargs)
def __str__(self):
    """Return a readable representation such as ``normal_rv{"(),()->()"}``.

    Only the signature (or, for legacy subclasses, ``ndim_supp`` and
    ``ndims_params``) is shown from the core props; any extra props
    declared by a subclass are appended after it.
    """
    # Only show signature from core props
    if signature := self.signature:
        core_props = [f'"{signature}"']
    else:
        # For backwards compatibility with subclasses that never set a
        # signature: fall back to the two legacy dimension props.
        core_props = [str(self.ndim_supp), str(self.ndims_params)]

    # Add any extra props that the subclass may have
    extra_props = [
        str(getattr(self, prop))
        for prop in self.__props__
        if prop not in RandomVariable.__props__
    ]

    props_str = ", ".join(core_props + extra_props)
    return f"{self.name}_rv{{{props_str}}}"
def _infer_shape(
......@@ -298,11 +360,11 @@ class RandomVariable(Op):
dtype_idx = constant(all_dtypes.index(dtype), dtype="int64")
else:
dtype_idx = constant(dtype, dtype="int64")
dtype = all_dtypes[dtype_idx.data]
outtype = TensorType(dtype=dtype, shape=static_shape)
out_var = outtype()
dtype = all_dtypes[dtype_idx.data]
inputs = (rng, size, dtype_idx, *dist_params)
out_var = TensorType(dtype=dtype, shape=static_shape)()
outputs = (rng.type(), out_var)
return Apply(self, inputs, outputs)
......@@ -395,9 +457,8 @@ def vectorize_random_variable(
# We extend it to accommodate the new input batch dimensions.
# Otherwise, we assume the new size already has the right values
# Need to make parameters implicit broadcasting explicit
original_dist_params = node.inputs[3:]
old_size = node.inputs[1]
original_dist_params = op.dist_params(node)
old_size = op.size_param(node)
len_old_size = get_vector_length(old_size)
original_expanded_dist_params = explicit_expand_dims(
......
import re
from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.rewriting.db import SequenceDB
......@@ -164,9 +166,9 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
a_vector_param = arange(a_scalar_param)
new_props_dict = op._props_dict().copy()
new_ndims_params = list(op.ndims_params)
new_ndims_params[0] += 1
new_props_dict["ndims_params"] = new_ndims_params
# Signature changes from something like "(),(a),(2)->(s0, s1)" to "(a),(a),(2)->(s0, s1)"
# I.e., we substitute the first `()` by `(a)`
new_props_dict["signature"] = re.sub(r"\(\)", "(a)", op.signature, 1)
new_op = type(op)(**new_props_dict)
return new_op.make_node(rng, size, dtype, a_vector_param, *other_params).outputs
......
......@@ -123,7 +123,7 @@ def broadcast_params(params, ndims_params):
def explicit_expand_dims(
params: Sequence[TensorVariable],
ndim_params: tuple[int],
ndim_params: Sequence[int],
size_length: int = 0,
) -> list[TensorVariable]:
"""Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
......@@ -137,7 +137,7 @@ def explicit_expand_dims(
# See: https://github.com/pymc-devs/pytensor/issues/568
max_batch_dims = size_length
else:
max_batch_dims = max(batch_dims)
max_batch_dims = max(batch_dims, default=0)
new_params = []
for new_param, batch_dim in zip(params, batch_dims):
......@@ -354,6 +354,11 @@ def supp_shape_from_ref_param_shape(
out: tuple
Representing the support shape for a `RandomVariable` with the given `dist_params`.
Notes
_____
This helper is no longer necessary when using signatures in `RandomVariable` subclasses.
"""
if ndim_supp <= 0:
raise ValueError("ndim_supp must be greater than 0")
......
......@@ -169,7 +169,8 @@ _DIMENSION_NAME = r"\w+"
_CORE_DIMENSION_LIST = f"(?:{_DIMENSION_NAME}(?:,{_DIMENSION_NAME})*)?"
_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)"
_ARGUMENT_LIST = f"{_ARGUMENT}(?:,{_ARGUMENT})*"
_SIGNATURE = f"^{_ARGUMENT_LIST}->{_ARGUMENT_LIST}$"
# Allow no inputs
_SIGNATURE = f"^(?:{_ARGUMENT_LIST})?->{_ARGUMENT_LIST}$"
def _parse_gufunc_signature(
......@@ -200,6 +201,8 @@ def _parse_gufunc_signature(
tuple(re.findall(_DIMENSION_NAME, arg))
for arg in re.findall(_ARGUMENT, arg_list)
]
if arg_list # ignore no inputs
else []
for arg_list in signature.split("->")
)
......
......@@ -771,8 +771,7 @@ def test_random_unimplemented():
class NonExistentRV(RandomVariable):
name = "non-existent"
ndim_supp = 0
ndims_params = []
signature = "->()"
dtype = "floatX"
def __call__(self, size=None, **kwargs):
......@@ -798,8 +797,7 @@ def test_random_custom_implementation():
class CustomRV(RandomVariable):
name = "non-existent"
ndim_supp = 0
ndims_params = []
signature = "->()"
dtype = "floatX"
def __call__(self, size=None, **kwargs):
......
......@@ -74,52 +74,28 @@ def apply_local_rewrite_to_rv(
return new_out, f_inputs, dist_st, f_rewritten
def test_inplace_rewrites():
out = normal(0, 1)
out.owner.inputs[0].default_update = out.owner.outputs[0]
class TestRVExpraProps(RandomVariable):
name = "test"
signature = "()->()"
__props__ = ("name", "signature", "dtype", "inplace", "extra")
dtype = "floatX"
_print_name = ("TestExtraProps", "\\operatorname{TestExtra_props}")
assert out.owner.op.inplace is False
def __init__(self, extra, *args, **kwargs):
self.extra = extra
super().__init__(*args, **kwargs)
f = function(
[],
out,
mode="FAST_RUN",
)
(new_out, new_rng) = f.maker.fgraph.outputs
assert new_out.type == out.type
assert isinstance(new_out.owner.op, type(out.owner.op))
assert new_out.owner.op.inplace is True
assert all(
np.array_equal(a.data, b.data)
for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])
)
assert np.array_equal(new_out.owner.inputs[1].data, [])
def test_inplace_rewrites_extra_props():
class Test(RandomVariable):
name = "test"
ndim_supp = 0
ndims_params = [0]
__props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace", "extra")
dtype = "floatX"
_print_name = ("Test", "\\operatorname{Test}")
def __init__(self, extra, *args, **kwargs):
self.extra = extra
super().__init__(*args, **kwargs)
def make_node(self, rng, size, dtype, sigma):
return super().make_node(rng, size, dtype, sigma)
def rng_fn(self, rng, sigma, size):
return rng.normal(scale=sigma, size=size)
def rng_fn(self, rng, dtype, sigma, size):
return rng.normal(scale=sigma, size=size)
out = Test(extra="some value")(1)
out.owner.inputs[0].default_update = out.owner.outputs[0]
assert out.owner.op.inplace is False
@pytest.mark.parametrize("rv_op", [normal, TestRVExpraProps(extra="some value")])
def test_inplace_rewrites(rv_op):
out = rv_op(np.e)
node = out.owner
op = node.op
node.inputs[0].default_update = node.outputs[0]
assert op.inplace is False
f = function(
[],
......@@ -129,9 +105,10 @@ def test_inplace_rewrites_extra_props():
(new_out, new_rng) = f.maker.fgraph.outputs
assert new_out.type == out.type
assert isinstance(new_out.owner.op, type(out.owner.op))
assert new_out.owner.op.inplace is True
assert new_out.owner.op.extra == out.owner.op.extra
new_node = new_out.owner
new_op = new_node.op
assert isinstance(new_op, type(op))
assert new_op._props_dict() == (op._props_dict() | {"inplace": True})
assert all(
np.array_equal(a.data, b.data)
for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])
......
......@@ -1463,11 +1463,8 @@ def batched_unweighted_choice_without_replacement_tester(
rng = shared(rng_ctor())
# Batched a implicit size
a_core_ndim = 2
core_shape_len = 1
rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len,
ndims_params=[a_core_ndim, core_shape_len],
signature="(a0,a1),(1)->(s0,a1)",
dtype="int64",
)
......@@ -1483,11 +1480,8 @@ def batched_unweighted_choice_without_replacement_tester(
assert np.all((draw >= i * 10) & (draw < (i + 1) * 10))
# Explicit size broadcasts beyond a
a_core_ndim = 2
core_shape_len = 2
rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len,
ndims_params=[a_core_ndim, len(core_shape)],
signature="(a0,a1),(2)->(s0,s1,a1)",
dtype="int64",
)
......@@ -1515,12 +1509,8 @@ def batched_weighted_choice_without_replacement_tester(
"""
rng = shared(rng_ctor())
# 3 ndims params indicates p is passed
a_core_ndim = 2
core_shape_len = 1
rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len,
ndims_params=[a_core_ndim, 1, 1],
signature="(a0,a1),(a0),(1)->(s0,a1)",
dtype="int64",
)
......@@ -1540,11 +1530,8 @@ def batched_weighted_choice_without_replacement_tester(
# p and a are batched
# Test implicit arange
a_core_ndim = 0
core_shape_len = 2
rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len,
ndims_params=[a_core_ndim, 1, 1],
signature="(),(a),(2)->(s0,s1)",
dtype="int64",
)
a = 6
......@@ -1566,11 +1553,8 @@ def batched_weighted_choice_without_replacement_tester(
assert set(draw) == set(range(i, 6, 2))
# Size broadcasts beyond a
a_core_ndim = 2
core_shape_len = 1
rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len,
ndims_params=[a_core_ndim, 1, 1],
signature="(a0,a1),(a0),(1)->(s0,a1)",
dtype="int64",
)
a = np.arange(4 * 5 * 2).reshape((4, 5, 2))
......
......@@ -23,14 +23,13 @@ def test_RandomVariable_basics(strict_test_value_flags):
str_res = str(
RandomVariable(
"normal",
0,
[0, 0],
"float32",
inplace=True,
signature="(),()->()",
dtype="float32",
inplace=False,
)
)
assert str_res == "normal_rv{0, (0, 0), float32, True}"
assert str_res == 'normal_rv{"(),()->()"}'
# `ndims_params` should be a `Sequence` type
with pytest.raises(TypeError, match="^Parameter ndims_params*"):
......@@ -64,9 +63,7 @@ def test_RandomVariable_basics(strict_test_value_flags):
# Confirm that `inplace` works
rv = RandomVariable(
"normal",
0,
[0, 0],
"normal",
signature="(),()->()",
inplace=True,
)
......@@ -74,7 +71,7 @@ def test_RandomVariable_basics(strict_test_value_flags):
assert rv.destroy_map == {0: [0]}
# A no-params `RandomVariable`
rv = RandomVariable(name="test_rv", ndim_supp=0, ndims_params=())
rv = RandomVariable(name="test_rv", signature="->()")
with pytest.raises(TypeError):
rv.make_node(rng=1)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论