提交 dcd24a36 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement JAX dispatch for IntegersRV

上级 bf519076
...@@ -179,6 +179,7 @@ def jax_sample_fn_no_dtype(op): ...@@ -179,6 +179,7 @@ def jax_sample_fn_no_dtype(op):
@jax_sample_fn.register(aer.RandIntRV) @jax_sample_fn.register(aer.RandIntRV)
@jax_sample_fn.register(aer.IntegersRV)
@jax_sample_fn.register(aer.UniformRV) @jax_sample_fn.register(aer.UniformRV)
def jax_sample_fn_uniform(op): def jax_sample_fn_uniform(op):
"""JAX implementation of random variables with uniform density. """JAX implementation of random variables with uniform density.
...@@ -188,6 +189,9 @@ def jax_sample_fn_uniform(op): ...@@ -188,6 +189,9 @@ def jax_sample_fn_uniform(op):
""" """
name = op.name name = op.name
# IntegersRV is equivalent to RandintRV
if isinstance(op, aer.IntegersRV):
name = "randint"
jax_op = getattr(jax.random, name) jax_op = getattr(jax.random, name)
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, dtype, *parameters):
......
...@@ -237,6 +237,22 @@ def test_random_updates(rng_ctor): ...@@ -237,6 +237,22 @@ def test_random_updates(rng_ctor):
"randint", "randint",
lambda *args: args, lambda *args: args,
), ),
(
aer.integers,
[
set_test_value(
at.lscalar(),
np.array(0, dtype=np.int64),
),
set_test_value( # high-value necessary since test on cdf
at.lscalar(),
np.array(1000, dtype=np.int64),
),
],
(),
"randint",
lambda *args: args,
),
( (
aer.standard_normal, aer.standard_normal,
[], [],
...@@ -376,7 +392,11 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c ...@@ -376,7 +392,11 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
The parameters passed to the op. The parameters passed to the op.
""" """
rng = shared(np.random.RandomState(29402)) if rv_op is aer.integers:
# Integers only accepts Generator, not RandomState
rng = shared(np.random.default_rng(29402))
else:
rng = shared(np.random.RandomState(29402))
g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng) g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng)
g_fn = function(dist_params, g, mode=jax_mode) g_fn = function(dist_params, g, mode=jax_mode)
samples = g_fn( samples = g_fn(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论