提交 df32683c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add RandomVariable Op helpers to retrieve rng, size, and dist_params from a node, for readability

上级 3e9c6a4f
......@@ -88,7 +88,7 @@ def jax_typify_Generator(rng, **kwargs):
@jax_funcify.register(ptr.RandomVariable)
def jax_funcify_RandomVariable(op, node, **kwargs):
def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
"""JAX implementation of random variables."""
rv = node.outputs[1]
out_dtype = rv.type.dtype
......@@ -101,7 +101,7 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
if None in static_size:
# Sometimes size can be constant folded during rewrites,
# without the RandomVariable node being updated with new static types
size_param = node.inputs[1]
size_param = op.size_param(node)
if isinstance(size_param, Constant):
size_tuple = tuple(size_param.data)
# PyTensor uses empty size to represent size = None
......@@ -304,11 +304,11 @@ def jax_sample_fn_t(op, node):
@jax_sample_fn.register(ptr.ChoiceWithoutReplacement)
def jax_funcify_choice(op, node):
def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
"""JAX implementation of `ChoiceRV`."""
batch_ndim = op.batch_ndim(node)
a, *p, core_shape = node.inputs[3:]
a, *p, core_shape = op.dist_params(node)
a_core_ndim, *p_core_ndim, _ = op.ndims_params
if batch_ndim and a_core_ndim == 0:
......
......@@ -96,11 +96,14 @@ def make_numba_random_fn(node, np_random_func):
The functions generated here add parameter broadcasting and the ``size``
argument to the Numba-supported scalar ``np.random`` functions.
"""
if not isinstance(node.inputs[0].type, RandomStateType):
op: ptr.RandomVariable = node.op
rng_param = op.rng_param(node)
if not isinstance(rng_param.type, RandomStateType):
raise TypeError("Numba does not support NumPy `Generator`s")
tuple_size = int(get_vector_length(node.inputs[1]))
size_dims = tuple_size - max(i.ndim for i in node.inputs[3:])
tuple_size = int(get_vector_length(op.size_param(node)))
dist_params = op.dist_params(node)
size_dims = tuple_size - max(i.ndim for i in dist_params)
# Make a broadcast-capable version of the Numba supported scalar sampling
# function
......@@ -126,7 +129,7 @@ def make_numba_random_fn(node, np_random_func):
)
bcast_fn_input_names = ", ".join(
[unique_names(i, force_unique=True) for i in node.inputs[3:]]
[unique_names(i, force_unique=True) for i in dist_params]
)
bcast_fn_global_env = {
"np_random_func": np_random_func,
......@@ -143,7 +146,7 @@ def {bcast_fn_name}({bcast_fn_input_names}):
)
random_fn_input_names = ", ".join(
["rng", "size", "dtype"] + [unique_names(i) for i in node.inputs[3:]]
["rng", "size", "dtype"] + [unique_names(i) for i in dist_params]
)
# Now, create a Numba JITable function that implements the `size` parameter
......@@ -244,7 +247,8 @@ def create_numba_random_fn(
suffix_sep="_",
)
np_names = [unique_names(i, force_unique=True) for i in node.inputs[3:]]
dist_params = op.dist_params(node)
np_names = [unique_names(i, force_unique=True) for i in dist_params]
np_input_names = ", ".join(np_names)
np_random_fn_src = f"""
@numba_vectorize
......@@ -300,9 +304,9 @@ def numba_funcify_BernoulliRV(op, node, **kwargs):
@numba_funcify.register(ptr.CategoricalRV)
def numba_funcify_CategoricalRV(op, node, **kwargs):
def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype
size_len = int(get_vector_length(node.inputs[1]))
size_len = int(get_vector_length(op.size_param(node)))
p_ndim = node.inputs[-1].ndim
@numba_basic.numba_njit
......@@ -331,9 +335,9 @@ def numba_funcify_CategoricalRV(op, node, **kwargs):
@numba_funcify.register(ptr.DirichletRV)
def numba_funcify_DirichletRV(op, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype
alphas_ndim = node.inputs[3].type.ndim
alphas_ndim = op.dist_params(node)[0].type.ndim
neg_ind_shape_len = -alphas_ndim + 1
size_len = int(get_vector_length(node.inputs[1]))
size_len = int(get_vector_length(op.size_param(node)))
if alphas_ndim > 1:
......@@ -400,9 +404,9 @@ def numba_funcify_choice_without_replacement(op, node, **kwargs):
@numba_funcify.register(ptr.PermutationRV)
def numba_funcify_permutation(op, node, **kwargs):
def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs):
# PyTensor uses size=() to represent size=None
size_is_none = node.inputs[1].type.shape == (0,)
size_is_none = op.size_param(node).type.shape == (0,)
batch_ndim = op.batch_ndim(node)
x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
......
......@@ -372,6 +372,18 @@ class RandomVariable(Op):
def batch_ndim(self, node: Apply) -> int:
return cast(int, node.default_output().type.ndim - self.ndim_supp)
def rng_param(self, node) -> Variable:
"""Return the node input corresponding to the rng"""
return node.inputs[0]
def size_param(self, node) -> Variable:
"""Return the node input corresponding to the size"""
return node.inputs[1]
def dist_params(self, node) -> Sequence[Variable]:
"""Return the node inpust corresponding to dist params"""
return node.inputs[3:]
def perform(self, node, inputs, outputs):
rng_var_out, smpl_out = outputs
......
......@@ -255,7 +255,7 @@ def local_subtensor_rv_lift(fgraph, node):
return False
# Check that indexing does not act on support dims
batch_ndims = rv.ndim - rv_op.ndim_supp
batch_ndims = rv_op.batch_ndim(rv_node)
# We decompose the boolean indexes, which makes it clear whether they act on support dims or not
non_bool_indices = tuple(
chain.from_iterable(
......
......@@ -111,9 +111,9 @@ def test_inplace_rewrites(rv_op):
assert new_op._props_dict() == (op._props_dict() | {"inplace": True})
assert all(
np.array_equal(a.data, b.data)
for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])
for a, b in zip(new_op.dist_params(new_node), op.dist_params(node))
)
assert np.array_equal(new_out.owner.inputs[1].data, [])
assert np.array_equal(new_op.size_param(new_node).data, op.size_param(node).data)
@config.change_flags(compute_test_value="raise")
......@@ -400,7 +400,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
assert new_out.owner.op == dist_op
assert all(
isinstance(i.owner.op, DimShuffle)
for i in new_out.owner.inputs[3:]
for i in new_out.owner.op.dist_params(new_out.owner)
if i.owner
)
else:
......@@ -793,7 +793,7 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
assert isinstance(new_out.owner.op, RandomVariable)
assert all(
isinstance(i.owner.op, AdvancedSubtensor | AdvancedSubtensor1 | Subtensor)
for i in new_out.owner.inputs[3:]
for i in new_out.owner.op.dist_params(new_out.owner)
if i.owner
)
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论