提交 df32683c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add RandomVariable Op helpers to retrieve rng, size, and dist_params from a node, for readability

上级 3e9c6a4f
...@@ -88,7 +88,7 @@ def jax_typify_Generator(rng, **kwargs): ...@@ -88,7 +88,7 @@ def jax_typify_Generator(rng, **kwargs):
@jax_funcify.register(ptr.RandomVariable) @jax_funcify.register(ptr.RandomVariable)
def jax_funcify_RandomVariable(op, node, **kwargs): def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
"""JAX implementation of random variables.""" """JAX implementation of random variables."""
rv = node.outputs[1] rv = node.outputs[1]
out_dtype = rv.type.dtype out_dtype = rv.type.dtype
...@@ -101,7 +101,7 @@ def jax_funcify_RandomVariable(op, node, **kwargs): ...@@ -101,7 +101,7 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
if None in static_size: if None in static_size:
# Sometimes size can be constant folded during rewrites, # Sometimes size can be constant folded during rewrites,
# without the RandomVariable node being updated with new static types # without the RandomVariable node being updated with new static types
size_param = node.inputs[1] size_param = op.size_param(node)
if isinstance(size_param, Constant): if isinstance(size_param, Constant):
size_tuple = tuple(size_param.data) size_tuple = tuple(size_param.data)
# PyTensor uses empty size to represent size = None # PyTensor uses empty size to represent size = None
...@@ -304,11 +304,11 @@ def jax_sample_fn_t(op, node): ...@@ -304,11 +304,11 @@ def jax_sample_fn_t(op, node):
@jax_sample_fn.register(ptr.ChoiceWithoutReplacement) @jax_sample_fn.register(ptr.ChoiceWithoutReplacement)
def jax_funcify_choice(op, node): def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
"""JAX implementation of `ChoiceRV`.""" """JAX implementation of `ChoiceRV`."""
batch_ndim = op.batch_ndim(node) batch_ndim = op.batch_ndim(node)
a, *p, core_shape = node.inputs[3:] a, *p, core_shape = op.dist_params(node)
a_core_ndim, *p_core_ndim, _ = op.ndims_params a_core_ndim, *p_core_ndim, _ = op.ndims_params
if batch_ndim and a_core_ndim == 0: if batch_ndim and a_core_ndim == 0:
......
...@@ -96,11 +96,14 @@ def make_numba_random_fn(node, np_random_func): ...@@ -96,11 +96,14 @@ def make_numba_random_fn(node, np_random_func):
The functions generated here add parameter broadcasting and the ``size`` The functions generated here add parameter broadcasting and the ``size``
argument to the Numba-supported scalar ``np.random`` functions. argument to the Numba-supported scalar ``np.random`` functions.
""" """
if not isinstance(node.inputs[0].type, RandomStateType): op: ptr.RandomVariable = node.op
rng_param = op.rng_param(node)
if not isinstance(rng_param.type, RandomStateType):
raise TypeError("Numba does not support NumPy `Generator`s") raise TypeError("Numba does not support NumPy `Generator`s")
tuple_size = int(get_vector_length(node.inputs[1])) tuple_size = int(get_vector_length(op.size_param(node)))
size_dims = tuple_size - max(i.ndim for i in node.inputs[3:]) dist_params = op.dist_params(node)
size_dims = tuple_size - max(i.ndim for i in dist_params)
# Make a broadcast-capable version of the Numba supported scalar sampling # Make a broadcast-capable version of the Numba supported scalar sampling
# function # function
...@@ -126,7 +129,7 @@ def make_numba_random_fn(node, np_random_func): ...@@ -126,7 +129,7 @@ def make_numba_random_fn(node, np_random_func):
) )
bcast_fn_input_names = ", ".join( bcast_fn_input_names = ", ".join(
[unique_names(i, force_unique=True) for i in node.inputs[3:]] [unique_names(i, force_unique=True) for i in dist_params]
) )
bcast_fn_global_env = { bcast_fn_global_env = {
"np_random_func": np_random_func, "np_random_func": np_random_func,
...@@ -143,7 +146,7 @@ def {bcast_fn_name}({bcast_fn_input_names}): ...@@ -143,7 +146,7 @@ def {bcast_fn_name}({bcast_fn_input_names}):
) )
random_fn_input_names = ", ".join( random_fn_input_names = ", ".join(
["rng", "size", "dtype"] + [unique_names(i) for i in node.inputs[3:]] ["rng", "size", "dtype"] + [unique_names(i) for i in dist_params]
) )
# Now, create a Numba JITable function that implements the `size` parameter # Now, create a Numba JITable function that implements the `size` parameter
...@@ -244,7 +247,8 @@ def create_numba_random_fn( ...@@ -244,7 +247,8 @@ def create_numba_random_fn(
suffix_sep="_", suffix_sep="_",
) )
np_names = [unique_names(i, force_unique=True) for i in node.inputs[3:]] dist_params = op.dist_params(node)
np_names = [unique_names(i, force_unique=True) for i in dist_params]
np_input_names = ", ".join(np_names) np_input_names = ", ".join(np_names)
np_random_fn_src = f""" np_random_fn_src = f"""
@numba_vectorize @numba_vectorize
...@@ -300,9 +304,9 @@ def numba_funcify_BernoulliRV(op, node, **kwargs): ...@@ -300,9 +304,9 @@ def numba_funcify_BernoulliRV(op, node, **kwargs):
@numba_funcify.register(ptr.CategoricalRV) @numba_funcify.register(ptr.CategoricalRV)
def numba_funcify_CategoricalRV(op, node, **kwargs): def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype out_dtype = node.outputs[1].type.numpy_dtype
size_len = int(get_vector_length(node.inputs[1])) size_len = int(get_vector_length(op.size_param(node)))
p_ndim = node.inputs[-1].ndim p_ndim = node.inputs[-1].ndim
@numba_basic.numba_njit @numba_basic.numba_njit
...@@ -331,9 +335,9 @@ def numba_funcify_CategoricalRV(op, node, **kwargs): ...@@ -331,9 +335,9 @@ def numba_funcify_CategoricalRV(op, node, **kwargs):
@numba_funcify.register(ptr.DirichletRV) @numba_funcify.register(ptr.DirichletRV)
def numba_funcify_DirichletRV(op, node, **kwargs): def numba_funcify_DirichletRV(op, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype out_dtype = node.outputs[1].type.numpy_dtype
alphas_ndim = node.inputs[3].type.ndim alphas_ndim = op.dist_params(node)[0].type.ndim
neg_ind_shape_len = -alphas_ndim + 1 neg_ind_shape_len = -alphas_ndim + 1
size_len = int(get_vector_length(node.inputs[1])) size_len = int(get_vector_length(op.size_param(node)))
if alphas_ndim > 1: if alphas_ndim > 1:
...@@ -400,9 +404,9 @@ def numba_funcify_choice_without_replacement(op, node, **kwargs): ...@@ -400,9 +404,9 @@ def numba_funcify_choice_without_replacement(op, node, **kwargs):
@numba_funcify.register(ptr.PermutationRV) @numba_funcify.register(ptr.PermutationRV)
def numba_funcify_permutation(op, node, **kwargs): def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs):
# PyTensor uses size=() to represent size=None # PyTensor uses size=() to represent size=None
size_is_none = node.inputs[1].type.shape == (0,) size_is_none = op.size_param(node).type.shape == (0,)
batch_ndim = op.batch_ndim(node) batch_ndim = op.batch_ndim(node)
x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0] x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
......
...@@ -372,6 +372,18 @@ class RandomVariable(Op): ...@@ -372,6 +372,18 @@ class RandomVariable(Op):
def batch_ndim(self, node: Apply) -> int: def batch_ndim(self, node: Apply) -> int:
return cast(int, node.default_output().type.ndim - self.ndim_supp) return cast(int, node.default_output().type.ndim - self.ndim_supp)
def rng_param(self, node) -> Variable:
    """Return the node input corresponding to the rng.

    By RandomVariable convention this is ``node.inputs[0]``.
    """
    return node.inputs[0]
def size_param(self, node) -> Variable:
    """Return the node input corresponding to the size.

    By RandomVariable convention this is ``node.inputs[1]``.
    """
    return node.inputs[1]
def dist_params(self, node) -> Sequence[Variable]:
    """Return the node inputs corresponding to the distribution parameters.

    By RandomVariable convention these are ``node.inputs[3:]`` (everything
    after the rng, size, and dtype inputs).
    """
    return node.inputs[3:]
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
rng_var_out, smpl_out = outputs rng_var_out, smpl_out = outputs
......
...@@ -255,7 +255,7 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -255,7 +255,7 @@ def local_subtensor_rv_lift(fgraph, node):
return False return False
# Check that indexing does not act on support dims # Check that indexing does not act on support dims
batch_ndims = rv.ndim - rv_op.ndim_supp batch_ndims = rv_op.batch_ndim(rv_node)
# We decompose the boolean indexes, which makes it clear whether they act on support dims or not # We decompose the boolean indexes, which makes it clear whether they act on support dims or not
non_bool_indices = tuple( non_bool_indices = tuple(
chain.from_iterable( chain.from_iterable(
......
...@@ -111,9 +111,9 @@ def test_inplace_rewrites(rv_op): ...@@ -111,9 +111,9 @@ def test_inplace_rewrites(rv_op):
assert new_op._props_dict() == (op._props_dict() | {"inplace": True}) assert new_op._props_dict() == (op._props_dict() | {"inplace": True})
assert all( assert all(
np.array_equal(a.data, b.data) np.array_equal(a.data, b.data)
for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:]) for a, b in zip(new_op.dist_params(new_node), op.dist_params(node))
) )
assert np.array_equal(new_out.owner.inputs[1].data, []) assert np.array_equal(new_op.size_param(new_node).data, op.size_param(node).data)
@config.change_flags(compute_test_value="raise") @config.change_flags(compute_test_value="raise")
...@@ -400,7 +400,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -400,7 +400,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
assert new_out.owner.op == dist_op assert new_out.owner.op == dist_op
assert all( assert all(
isinstance(i.owner.op, DimShuffle) isinstance(i.owner.op, DimShuffle)
for i in new_out.owner.inputs[3:] for i in new_out.owner.op.dist_params(new_out.owner)
if i.owner if i.owner
) )
else: else:
...@@ -793,7 +793,7 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size): ...@@ -793,7 +793,7 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
assert isinstance(new_out.owner.op, RandomVariable) assert isinstance(new_out.owner.op, RandomVariable)
assert all( assert all(
isinstance(i.owner.op, AdvancedSubtensor | AdvancedSubtensor1 | Subtensor) isinstance(i.owner.op, AdvancedSubtensor | AdvancedSubtensor1 | Subtensor)
for i in new_out.owner.inputs[3:] for i in new_out.owner.op.dist_params(new_out.owner)
if i.owner if i.owner
) )
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论