提交 38c04c96 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add explicit expand_dims when building RandomVariable nodes

上级 591c47e6
......@@ -304,7 +304,6 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
"""JAX implementation of `ChoiceRV`."""
batch_ndim = op.batch_ndim(node)
a, *p, core_shape = op.dist_params(node)
a_core_ndim, *p_core_ndim, _ = op.ndims_params
if batch_ndim and a_core_ndim == 0:
......@@ -313,12 +312,6 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
"A default JAX rewrite should have materialized the implicit arange"
)
a_batch_ndim = a.type.ndim - a_core_ndim
if op.has_p_param:
[p] = p
[p_core_ndim] = p_core_ndim
p_batch_ndim = p.type.ndim - p_core_ndim
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
......@@ -328,7 +321,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
else:
a, core_shape = parameters
p = None
core_shape = tuple(np.asarray(core_shape))
core_shape = tuple(np.asarray(core_shape)[(0,) * batch_ndim])
if batch_ndim == 0:
sample = jax.random.choice(
......@@ -338,16 +331,16 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
else:
if size is None:
if p is None:
size = a.shape[:a_batch_ndim]
size = a.shape[:batch_ndim]
else:
size = jax.numpy.broadcast_shapes(
a.shape[:a_batch_ndim],
p.shape[:p_batch_ndim],
a.shape[:batch_ndim],
p.shape[:batch_ndim],
)
a = jax.numpy.broadcast_to(a, size + a.shape[a_batch_ndim:])
a = jax.numpy.broadcast_to(a, size + a.shape[batch_ndim:])
if p is not None:
p = jax.numpy.broadcast_to(p, size + p.shape[p_batch_ndim:])
p = jax.numpy.broadcast_to(p, size + p.shape[batch_ndim:])
batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
......@@ -381,7 +374,6 @@ def jax_sample_fn_permutation(op, node):
"""JAX implementation of `PermutationRV`."""
batch_ndim = op.batch_ndim(node)
x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
......@@ -389,11 +381,10 @@ def jax_sample_fn_permutation(op, node):
(x,) = parameters
if batch_ndim:
# jax.random.permutation has no concept of batch dims
x_core_shape = x.shape[x_batch_ndim:]
if size is None:
size = x.shape[:x_batch_ndim]
size = x.shape[:batch_ndim]
else:
x = jax.numpy.broadcast_to(x, size + x_core_shape)
x = jax.numpy.broadcast_to(x, size + x.shape[batch_ndim:])
batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
raveled_batch_x = x.reshape((-1,) + x.shape[batch_ndim:])
......
......@@ -347,7 +347,6 @@ def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs):
def numba_funcify_DirichletRV(op, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype
alphas_ndim = op.dist_params(node)[0].type.ndim
neg_ind_shape_len = -alphas_ndim + 1
size_param = op.size_param(node)
size_len = (
None
......@@ -363,11 +362,6 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
samples_shape = alphas.shape
else:
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
if (
0 < alphas.ndim - 1 <= len(size_tpl)
and size_tpl[neg_ind_shape_len:] != alphas.shape[:-1]
):
raise ValueError("Parameters shape and size do not match.")
samples_shape = size_tpl + alphas.shape[-1:]
res = np.empty(samples_shape, dtype=out_dtype)
......
......@@ -2002,6 +2002,11 @@ class ChoiceWithoutReplacement(RandomVariable):
a_shape = tuple(a.shape) if param_shapes is None else tuple(param_shapes[0])
a_batch_ndim = len(a_shape) - self.ndims_params[0]
a_core_shape = a_shape[a_batch_ndim:]
core_shape_ndim = core_shape.type.ndim
if core_shape_ndim > 1:
# Batch core shapes are only valid if homogeneous or broadcasted,
# as otherwise they would imply ragged choice arrays
core_shape = core_shape[(0,) * (core_shape_ndim - 1)]
return tuple(core_shape) + a_core_shape[1:]
def rng_fn(self, *params):
......@@ -2011,15 +2016,11 @@ class ChoiceWithoutReplacement(RandomVariable):
rng, a, core_shape, size = params
p = None
if core_shape.ndim > 1:
core_shape = core_shape[(0,) * (core_shape.ndim - 1)]
core_shape = tuple(core_shape)
# We don't have access to the node in rng_fn for easy computation of batch_ndim :(
a_batch_ndim = batch_ndim = a.ndim - self.ndims_params[0]
if p is not None:
p_batch_ndim = p.ndim - self.ndims_params[1]
batch_ndim = max(batch_ndim, p_batch_ndim)
size_ndim = 0 if size is None else len(size)
batch_ndim = max(batch_ndim, size_ndim)
batch_ndim = a.ndim - self.ndims_params[0]
if batch_ndim == 0:
# Numpy choice fails with size=() if a.ndim > 1 is batched
......@@ -2031,16 +2032,16 @@ class ChoiceWithoutReplacement(RandomVariable):
# Numpy choice doesn't have a concept of batch dims
if size is None:
if p is None:
size = a.shape[:a_batch_ndim]
size = a.shape[:batch_ndim]
else:
size = np.broadcast_shapes(
a.shape[:a_batch_ndim],
p.shape[:p_batch_ndim],
a.shape[:batch_ndim],
p.shape[:batch_ndim],
)
a = np.broadcast_to(a, size + a.shape[a_batch_ndim:])
a = np.broadcast_to(a, size + a.shape[batch_ndim:])
if p is not None:
p = np.broadcast_to(p, size + p.shape[p_batch_ndim:])
p = np.broadcast_to(p, size + p.shape[batch_ndim:])
a_indexed_shape = a.shape[len(size) + 1 :]
out = np.empty(size + core_shape + a_indexed_shape, dtype=a.dtype)
......@@ -2143,26 +2144,26 @@ class PermutationRV(RandomVariable):
def _supp_shape_from_params(self, dist_params, param_shapes=None):
[x] = dist_params
x_shape = tuple(x.shape if param_shapes is None else param_shapes[0])
if x.type.ndim == 0:
return (x,)
if self.ndims_params[0] == 0:
# Implicit arange, this is only valid for homogeneous arrays
# Otherwise it would imply a ragged permutation array.
return (x.ravel()[0],)
else:
batch_x_ndim = x.type.ndim - self.ndims_params[0]
return x_shape[batch_x_ndim:]
def rng_fn(self, rng, x, size):
# We don't have access to the node in rng_fn :(
x_batch_ndim = x.ndim - self.ndims_params[0]
batch_ndim = max(x_batch_ndim, 0 if size is None else len(size))
batch_ndim = x.ndim - self.ndims_params[0]
if batch_ndim:
# rng.permutation has no concept of batch dims
x_core_shape = x.shape[x_batch_ndim:]
if size is None:
size = x.shape[:x_batch_ndim]
size = x.shape[:batch_ndim]
else:
x = np.broadcast_to(x, size + x_core_shape)
x = np.broadcast_to(x, size + x.shape[batch_ndim:])
out = np.empty(size + x_core_shape, dtype=x.dtype)
out = np.empty(size + x.shape[batch_ndim:], dtype=x.dtype)
for idx in np.ndindex(size):
out[idx] = rng.permutation(x[idx])
return out
......
......@@ -9,7 +9,7 @@ import pytensor
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable, equal_computations
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node, vectorize_graph
from pytensor.graph.replace import _vectorize_node
from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar import ScalarVariable
from pytensor.tensor.basic import (
......@@ -359,6 +359,12 @@ class RandomVariable(Op):
inferred_shape = self._infer_shape(size, dist_params)
_, static_shape = infer_static_shape(inferred_shape)
dist_params = explicit_expand_dims(
dist_params,
self.ndims_params,
size_length=None if NoneConst.equals(size) else get_vector_length(size),
)
inputs = (rng, size, *dist_params)
out_type = TensorType(dtype=self.dtype, shape=static_shape)
outputs = (rng.type(), out_type())
......@@ -459,22 +465,14 @@ def vectorize_random_variable(
None if isinstance(old_size.type, NoneTypeT) else get_vector_length(old_size)
)
original_expanded_dist_params = explicit_expand_dims(
original_dist_params, op.ndims_params, len_old_size
)
# We call vectorize_graph to automatically handle any new explicit expand_dims
dist_params = vectorize_graph(
original_expanded_dist_params, dict(zip(original_dist_params, dist_params))
)
new_ndim = dist_params[0].type.ndim - original_expanded_dist_params[0].type.ndim
if new_ndim and len_old_size and equal_computations([old_size], [size]):
if len_old_size and equal_computations([old_size], [size]):
# If the original RV had a size variable and a new one has not been provided,
# we need to define a new size as the concatenation of the original size dimensions
# and the novel ones implied by new broadcasted batched parameters dimensions.
broadcasted_batch_shape = compute_batch_shape(dist_params, op.ndims_params)
new_size_dims = broadcasted_batch_shape[:new_ndim]
size = concatenate([new_size_dims, size])
new_ndim = dist_params[0].type.ndim - original_dist_params[0].type.ndim
if new_ndim >= 0:
new_size = compute_batch_shape(dist_params, ndims_params=op.ndims_params)
new_size_dims = new_size[:new_ndim]
size = concatenate([new_size_dims, size])
return op.make_node(rng, size, *dist_params)
import re
from pytensor.compile import optdb
from pytensor.graph import Constant
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.tensor import abs as abs_t
......@@ -159,12 +160,17 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
return None
rng, size, a_scalar_param, *other_params = node.inputs
if a_scalar_param.type.ndim > 0:
if not all(a_scalar_param.type.broadcastable):
# Automatic vectorization could have made this parameter batched,
# there is no nice way to materialize a batched arange
return None
a_vector_param = arange(a_scalar_param)
# We need to try and do an eager squeeze here because arange will fail in jax
# if there is an array leading to it, even if it's constant
if isinstance(a_scalar_param, Constant):
a_scalar_param = a_scalar_param.data
a_vector_param = arange(a_scalar_param.squeeze())
new_props_dict = op._props_dict().copy()
# Signature changes from something like "(),(a),(2)->(s0, s1)" to "(a),(a),(2)->(s0, s1)"
# I.e., we substitute the first `()` by `(a)`
......
......@@ -28,6 +28,9 @@ from tests.tensor.random.test_basic import (
rng = np.random.default_rng(42849)
@pytest.mark.xfail(
reason="Most RVs are not working correctly with explicit expand_dims"
)
@pytest.mark.parametrize(
"rv_op, dist_args, size",
[
......@@ -388,6 +391,7 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
)
@pytest.mark.xfail(reason="Test is not working correctly with explicit expand_dims")
@pytest.mark.parametrize(
"rv_op, dist_args, base_size, cdf_name, params_conv",
[
......@@ -633,7 +637,7 @@ def test_CategoricalRV(dist_args, size, cm):
),
),
(10, 4),
pytest.raises(ValueError, match="Parameters shape.*"),
pytest.raises(ValueError, match="operands could not be broadcast together"),
),
],
)
......@@ -658,6 +662,7 @@ def test_DirichletRV(a, size, cm):
assert np.allclose(res, exp_res, atol=1e-4)
@pytest.mark.xfail(reason="RandomState is not aligned with explicit expand_dims")
def test_RandomState_updates():
rng = shared(np.random.RandomState(1))
rng_new = shared(np.random.RandomState(2))
......
......@@ -796,13 +796,21 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
rng,
)
def is_subtensor_or_dimshuffle_subtensor(inp) -> bool:
subtensor_ops = Subtensor | AdvancedSubtensor | AdvancedSubtensor1
if isinstance(inp.owner.op, subtensor_ops):
return True
if isinstance(inp.owner.op, DimShuffle):
return isinstance(inp.owner.inputs[0].owner.op, subtensor_ops)
return False
if lifted:
assert isinstance(new_out.owner.op, RandomVariable)
assert all(
isinstance(i.owner.op, AdvancedSubtensor | AdvancedSubtensor1 | Subtensor)
is_subtensor_or_dimshuffle_subtensor(i)
for i in new_out.owner.op.dist_params(new_out.owner)
if i.owner
)
), new_out.dprint(depth=3, print_type=True)
else:
assert isinstance(
new_out.owner.op, AdvancedSubtensor | AdvancedSubtensor1 | Subtensor
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论