Commit 98d73d78, authored by Ricardo Vieira, committed by Ricardo Vieira

Remove RandomVariable dtype input

Parent commit: df32683c
......@@ -114,7 +114,7 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
if None in static_size:
assert_size_argument_jax_compatible(node)
def sample_fn(rng, size, dtype, *parameters):
def sample_fn(rng, size, *parameters):
# PyTensor uses empty size to represent size = None
if jax.numpy.asarray(size).shape == (0,):
size = None
......@@ -122,7 +122,7 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
else:
def sample_fn(rng, size, *parameters):
    # Static size is fully known here, so `size` is ignored and the
    # sampling function is driven by the closed-over `static_size` and
    # the Op's own output dtype (`out_dtype`) — the `dtype` input was
    # removed from the RandomVariable signature.
    return jax_sample_fn(op, node=node)(
        rng, static_size, out_dtype, *parameters
    )
......
......@@ -123,7 +123,6 @@ def make_numba_random_fn(node, np_random_func):
"size_dims",
"rng",
"size",
"dtype",
],
suffix_sep="_",
)
......@@ -146,7 +145,7 @@ def {bcast_fn_name}({bcast_fn_input_names}):
)
random_fn_input_names = ", ".join(
["rng", "size", "dtype"] + [unique_names(i) for i in dist_params]
["rng", "size"] + [unique_names(i) for i in dist_params]
)
# Now, create a Numba JITable function that implements the `size` parameter
......@@ -243,7 +242,7 @@ def create_numba_random_fn(
np_global_env["numba_vectorize"] = numba_basic.numba_vectorize
unique_names = unique_name_generator(
[np_random_fn_name, *np_global_env.keys(), "rng", "size", "dtype"],
[np_random_fn_name, *np_global_env.keys(), "rng", "size"],
suffix_sep="_",
)
......@@ -310,7 +309,7 @@ def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs):
p_ndim = node.inputs[-1].ndim
@numba_basic.numba_njit
def categorical_rv(rng, size, dtype, p):
def categorical_rv(rng, size, p):
if not size_len:
size_tpl = p.shape[:-1]
else:
......@@ -342,7 +341,7 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
if alphas_ndim > 1:
@numba_basic.numba_njit
def dirichlet_rv(rng, size, dtype, alphas):
def dirichlet_rv(rng, size, alphas):
if size_len > 0:
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
if (
......@@ -365,7 +364,7 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
else:
@numba_basic.numba_njit
def dirichlet_rv(rng, size, alphas):
    # Non-batched `alphas` case: `size` arrives as a vector, so convert
    # it to a static tuple (length known at compile time via `size_len`)
    # before handing it to NumPy.
    size = numba_ndarray.to_fixed_tuple(size, size_len)
    return (rng, np.random.dirichlet(alphas, size))
......@@ -388,14 +387,14 @@ def numba_funcify_choice_without_replacement(op, node, **kwargs):
if op.has_p_param:
@numba_basic.numba_njit
def choice_without_replacement_rv(rng, size, a, p, core_shape):
    # `core_shape` arrives as a tensor; convert to a static tuple
    # (length `core_shape_len` known at compile time) for np.random.choice.
    core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
    samples = np.random.choice(a, size=core_shape, replace=False, p=p)
    return (rng, samples)
else:
@numba_basic.numba_njit
def choice_without_replacement_rv(rng, size, dtype, a, core_shape):
def choice_without_replacement_rv(rng, size, a, core_shape):
core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
samples = np.random.choice(a, size=core_shape, replace=False)
return (rng, samples)
......@@ -411,7 +410,7 @@ def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs):
x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
@numba_basic.numba_njit
def permutation_rv(rng, size, dtype, x):
def permutation_rv(rng, size, x):
if batch_ndim:
x_core_shape = x.shape[x_batch_ndim:]
if size_is_none:
......
......@@ -27,7 +27,7 @@ from pytensor.tensor.random.utils import (
normalize_size_param,
)
from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.type import TensorType, all_dtypes
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
from pytensor.tensor.variable import TensorVariable
......@@ -65,7 +65,7 @@ class RandomVariable(Op):
signature: str
Numpy-like vectorized signature of the random variable.
dtype: str (optional)
The dtype of the sampled output. If the value ``"floatX"`` is
The default dtype of the sampled output. If the value ``"floatX"`` is
given, then ``dtype`` is set to ``pytensor.config.floatX``. If
``None`` (the default), the `dtype` keyword must be set when
`RandomVariable.make_node` is called.
......@@ -287,8 +287,8 @@ class RandomVariable(Op):
return shape
def infer_shape(self, fgraph, node, input_shapes):
_, size, _, *dist_params = node.inputs
_, size_shape, _, *param_shapes = input_shapes
_, size, *dist_params = node.inputs
_, size_shape, *param_shapes = input_shapes
try:
size_len = get_vector_length(size)
......@@ -302,14 +302,34 @@ class RandomVariable(Op):
return [None, list(shape)]
def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
    """Create a random-variable output for this Op.

    Since the `dtype` graph input was removed, the dtype is a property
    of the Op itself: if the requested `dtype` differs from
    ``self.dtype``, a new Op with the right dtype is created and called
    instead. Switching between float and non-float dtypes is rejected.

    Raises
    ------
    ValueError
        If `dtype` would change the RV between float and non-float.
    """
    if dtype is None:
        dtype = self.dtype
    if dtype == "floatX":
        dtype = config.floatX
    # We need to recreate the Op with the right dtype
    if dtype != self.dtype:
        # Check we are not switching from float to int
        if self.dtype is not None:
            if dtype.startswith("float") != self.dtype.startswith("float"):
                raise ValueError(
                    f"Cannot change the dtype of a {self.name} RV from {self.dtype} to {dtype}"
                )
        props = self._props_dict()
        props["dtype"] = dtype
        new_op = type(self)(**props)
        # Recurse on the new Op; its self.dtype now matches `dtype`,
        # so this branch is not taken again.
        return new_op.__call__(
            *args, size=size, name=name, rng=rng, dtype=dtype, **kwargs
        )
    res = super().__call__(rng, size, *args, **kwargs)
    if name is not None:
        res.name = name
    return res
def make_node(self, rng, size, dtype, *dist_params):
def make_node(self, rng, size, *dist_params):
"""Create a random variable node.
Parameters
......@@ -349,23 +369,10 @@ class RandomVariable(Op):
shape = self._infer_shape(size, dist_params)
_, static_shape = infer_static_shape(shape)
dtype = self.dtype or dtype
if dtype == "floatX":
dtype = config.floatX
elif dtype is None or (isinstance(dtype, str) and dtype not in all_dtypes):
raise TypeError("dtype is unspecified")
if isinstance(dtype, str):
dtype_idx = constant(all_dtypes.index(dtype), dtype="int64")
else:
dtype_idx = constant(dtype, dtype="int64")
dtype = all_dtypes[dtype_idx.data]
inputs = (rng, size, dtype_idx, *dist_params)
out_var = TensorType(dtype=dtype, shape=static_shape)()
outputs = (rng.type(), out_var)
inputs = (rng, size, *dist_params)
out_type = TensorType(dtype=self.dtype, shape=static_shape)
outputs = (rng.type(), out_type())
return Apply(self, inputs, outputs)
......@@ -382,14 +389,12 @@ class RandomVariable(Op):
def dist_params(self, node) -> "Sequence[Variable]":
    """Return the node inputs corresponding to the distribution parameters.

    With the `dtype` input removed, a RandomVariable node's inputs are
    ``(rng, size, *dist_params)``, so the parameters start at index 2.
    """
    return node.inputs[2:]
def perform(self, node, inputs, outputs):
rng_var_out, smpl_out = outputs
rng, size, dtype, *args = inputs
out_var = node.outputs[1]
rng, size, *args = inputs
# If `size == []`, that means no size is enforced, and NumPy is trusted
# to draw the appropriate number of samples, NumPy uses `size=None` to
......@@ -408,11 +413,8 @@ class RandomVariable(Op):
smpl_val = self.rng_fn(rng, *([*args, size]))
if (
not isinstance(smpl_val, np.ndarray)
or str(smpl_val.dtype) != out_var.type.dtype
):
smpl_val = _asarray(smpl_val, dtype=out_var.type.dtype)
if not isinstance(smpl_val, np.ndarray) or str(smpl_val.dtype) != self.dtype:
smpl_val = _asarray(smpl_val, dtype=self.dtype)
smpl_out[0] = smpl_val
......@@ -463,7 +465,7 @@ default_rng = DefaultGeneratorMakerOp()
@_vectorize_node.register(RandomVariable)
def vectorize_random_variable(
op: RandomVariable, node: Apply, rng, size, dtype, *dist_params
op: RandomVariable, node: Apply, rng, size, *dist_params
) -> Apply:
# If size was provided originally and a new size hasn't been provided,
# We extend it to accommodate the new input batch dimensions.
......@@ -491,4 +493,4 @@ def vectorize_random_variable(
new_size_dims = broadcasted_batch_shape[:new_ndim]
size = concatenate([new_size_dims, size])
return op.make_node(rng, size, dtype, *dist_params)
return op.make_node(rng, size, *dist_params)
......@@ -81,7 +81,7 @@ def local_rv_size_lift(fgraph, node):
if not isinstance(node.op, RandomVariable):
return
rng, size, dtype, *dist_params = node.inputs
rng, size, *dist_params = node.inputs
dist_params = broadcast_params(dist_params, node.op.ndims_params)
......@@ -105,7 +105,7 @@ def local_rv_size_lift(fgraph, node):
else:
return
new_node = node.op.make_node(rng, None, dtype, *dist_params)
new_node = node.op.make_node(rng, None, *dist_params)
if config.compute_test_value != "off":
compute_test_value(new_node)
......@@ -141,7 +141,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
return False
rv_op = rv_node.op
rng, size, dtype, *dist_params = rv_node.inputs
rng, size, *dist_params = rv_node.inputs
rv = rv_node.default_output()
# Check that Dimshuffle does not affect support dims
......@@ -185,7 +185,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
)
new_dist_params.append(param.dimshuffle(param_new_order))
new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
new_node = rv_op.make_node(rng, new_size, *new_dist_params)
if config.compute_test_value != "off":
compute_test_value(new_node)
......@@ -233,7 +233,7 @@ def local_subtensor_rv_lift(fgraph, node):
return None
rv_op = rv_node.op
rng, size, dtype, *dist_params = rv_node.inputs
rng, size, *dist_params = rv_node.inputs
# Parse indices
idx_list = getattr(subtensor_op, "idx_list", None)
......@@ -346,7 +346,7 @@ def local_subtensor_rv_lift(fgraph, node):
new_dist_params.append(batch_param[tuple(batch_indices)])
# Create new RV
new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
new_node = rv_op.make_node(rng, new_size, *new_dist_params)
new_rv = new_node.default_output()
copy_stack_trace(rv, new_rv)
......
......@@ -12,6 +12,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor import constant
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.random.basic import (
NormalRV,
categorical,
dirichlet,
multinomial,
......@@ -397,7 +398,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
)
if lifted:
assert new_out.owner.op == dist_op
assert isinstance(new_out.owner.op, type(dist_op))
assert all(
isinstance(i.owner.op, DimShuffle)
for i in new_out.owner.op.dist_params(new_out.owner)
......@@ -832,7 +833,7 @@ def test_Subtensor_lift_restrictions():
subtensor_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner
assert subtensor_node == y.owner
assert isinstance(subtensor_node.op, Subtensor)
assert subtensor_node.inputs[0].owner.op == normal
assert isinstance(subtensor_node.inputs[0].owner.op, NormalRV)
z = pt.ones(x.shape) - x[1]
......@@ -850,7 +851,7 @@ def test_Subtensor_lift_restrictions():
EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
rv_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner
assert rv_node.op == normal
assert isinstance(rv_node.op, NormalRV)
assert isinstance(rv_node.inputs[-1].owner.op, Subtensor)
assert isinstance(rv_node.inputs[-2].owner.op, Subtensor)
......@@ -872,7 +873,7 @@ def test_Dimshuffle_lift_restrictions():
dimshuffle_node = fg.outputs[0].owner.inputs[1].owner
assert dimshuffle_node == y.owner
assert isinstance(dimshuffle_node.op, DimShuffle)
assert dimshuffle_node.inputs[0].owner.op == normal
assert isinstance(dimshuffle_node.inputs[0].owner.op, NormalRV)
z = pt.ones(x.shape) - y
......@@ -890,7 +891,7 @@ def test_Dimshuffle_lift_restrictions():
EquilibriumGraphRewriter([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg)
rv_node = fg.outputs[0].owner.inputs[1].owner
assert rv_node.op == normal
assert isinstance(rv_node.op, NormalRV)
assert isinstance(rv_node.inputs[-1].owner.op, DimShuffle)
assert isinstance(rv_node.inputs[-2].owner.op, DimShuffle)
......
......@@ -3,14 +3,14 @@ import pytest
import pytensor.tensor as pt
from pytensor import config, function
from pytensor.gradient import NullTypeGradError, grad
from pytensor.graph.replace import vectorize_node
from pytensor.graph.replace import vectorize_graph
from pytensor.raise_op import Assert
from pytensor.tensor.math import eq
from pytensor.tensor.random import normal
from pytensor.tensor.random.basic import NormalRV
from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng
from pytensor.tensor.shape import specify_shape
from pytensor.tensor.type import all_dtypes, iscalar, tensor
from pytensor.tensor.type import iscalar, tensor
@pytest.fixture(scope="function", autouse=False)
......@@ -51,15 +51,6 @@ def test_RandomVariable_basics(strict_test_value_flags):
inplace=True,
)(0, 1, size={1, 2})
# No dtype
with pytest.raises(TypeError, match="^dtype*"):
RandomVariable(
"normal",
0,
[0, 0],
inplace=True,
)(0, 1)
# Confirm that `inplace` works
rv = RandomVariable(
"normal",
......@@ -80,16 +71,19 @@ def test_RandomVariable_basics(strict_test_value_flags):
rv_shape = rv._infer_shape(pt.constant([]), (), [])
assert rv_shape.equals(pt.constant([], dtype="int64"))
# Integer-specified `dtype`
dtype_1 = all_dtypes[1]
rv_node = rv.make_node(None, None, 1)
rv_out = rv_node.outputs[1]
rv_out.tag.test_value = 1
# `dtype` is respected
rv = RandomVariable("normal", signature="(),()->()", dtype="int32")
with config.change_flags(compute_test_value="off"):
rv_out = rv()
assert rv_out.dtype == "int32"
rv_out = rv(dtype="int64")
assert rv_out.dtype == "int64"
assert rv_out.dtype == dtype_1
with pytest.raises(NullTypeGradError):
grad(rv_out, [rv_node.inputs[0]])
with pytest.raises(
ValueError,
match="Cannot change the dtype of a normal RV from int32 to float32",
):
assert rv(dtype="float32").dtype == "float32"
def test_RandomVariable_bcast(strict_test_value_flags):
......@@ -238,70 +232,70 @@ def test_multivariate_rv_infer_static_shape():
assert mv_op(param1, param2, size=(10, 2)).type.shape == (10, 2, 3)
def test_vectorize():
    """Vectorizing a RandomVariable graph preserves the Op type, params, and size."""
    vec = tensor(shape=(None,))
    mat = tensor(shape=(None, None))

    # Test without size
    out = normal(vec)
    vect_node = vectorize_graph(out, {vec: mat}).owner
    assert isinstance(vect_node.op, NormalRV)
    assert vect_node.op.dist_params(vect_node)[0] is mat

    # Test with size, new size provided
    size = pt.as_tensor(np.array((3,), dtype="int64"))
    out = normal(vec, size=size)
    vect_node = vectorize_graph(out, {vec: mat, size: (2, 3)}).owner
    assert isinstance(vect_node.op, NormalRV)
    assert tuple(vect_node.op.size_param(vect_node).eval()) == (2, 3)
    assert vect_node.op.dist_params(vect_node)[0] is mat

    # Test with size, new size not provided
    out = normal(vec, size=(3,))
    vect_node = vectorize_graph(out, {vec: mat}).owner
    assert isinstance(vect_node.op, NormalRV)
    assert vect_node.op.dist_params(vect_node)[0] is mat
    assert tuple(
        vect_node.op.size_param(vect_node).eval(
            {mat: np.zeros((2, 3), dtype=config.floatX)}
        )
    ) == (2, 3)

    # Test parameter broadcasting
    mu = vec
    sigma = pt.as_tensor(np.array(1.0))
    out = normal(mu, sigma)
    new_mu = tensor("mu", shape=(10, 5))
    new_sigma = tensor("sigma", shape=(10,))
    vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
    assert isinstance(vect_node.op, NormalRV)
    assert vect_node.default_output().type.shape == (10, 5)

    # Test parameter broadcasting with non-expanding size
    mu = vec
    sigma = pt.as_tensor(np.array(1.0))
    out = normal(mu, sigma, size=(5,))
    new_mu = tensor("mu", shape=(10, 5))
    new_sigma = tensor("sigma", shape=(10,))
    vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
    assert isinstance(vect_node.op, NormalRV)
    assert vect_node.default_output().type.shape == (10, 5)

    mu = vec
    sigma = pt.as_tensor(np.array(1.0))
    out = normal(mu, sigma, size=(5,))
    new_mu = tensor("mu", shape=(1, 5))
    new_sigma = tensor("sigma", shape=(10,))
    vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
    assert isinstance(vect_node.op, NormalRV)
    assert vect_node.default_output().type.shape == (10, 5)

    # Test parameter broadcasting with expanding size
    mu = vec
    sigma = pt.as_tensor(np.array(1.0))
    out = normal(mu, sigma, size=(2, 5))
    new_mu = tensor("mu", shape=(1, 5))
    new_sigma = tensor("sigma", shape=(10,))
    vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
    assert isinstance(vect_node.op, NormalRV)
    assert vect_node.default_output().type.shape == (10, 2, 5)
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment