Unverified 提交 4eded292 authored 作者: Thomas Wiecki's avatar Thomas Wiecki 提交者: GitHub

🔄 From Aesara: 1362: "Add `HalfNormalRV` JAX implementation" (#129)

* Add `HalfNormalRV` JAX implementation (#1362) Co-authored-by: 's avatartheorashid <theoaorashid@gmail.com> Co-authored-by: 's avatarRicardo Vieira <28983449+ricardoV94@users.noreply.github.com>
上级 e2ae99e1
...@@ -260,6 +260,23 @@ def jax_sample_fn_t(op): ...@@ -260,6 +260,23 @@ def jax_sample_fn_t(op):
return sample_fn return sample_fn
@jax_sample_fn.register(aer.HalfNormalRV)
def jax_sample_fn_halfnormal(op):
"""JAX implementation of `HalfNormalRV`."""
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
loc, scale = parameters
sample = (
loc + jax.numpy.abs(jax.random.normal(sampling_key, size, dtype)) * scale
)
rng["jax_state"] = rng_key
return (rng, sample)
return sample_fn
@jax_sample_fn.register(aer.ChoiceRV) @jax_sample_fn.register(aer.ChoiceRV)
def jax_funcify_choice(op): def jax_funcify_choice(op):
"""JAX implementation of `ChoiceRV`.""" """JAX implementation of `ChoiceRV`."""
......
...@@ -280,6 +280,22 @@ def test_random_updates(rng_ctor): ...@@ -280,6 +280,22 @@ def test_random_updates(rng_ctor):
"uniform", "uniform",
lambda *args: args, lambda *args: args,
), ),
(
aer.halfnormal,
[
set_test_value(
at.dvector(),
np.array([-1.0, 2.0], dtype=np.float64),
),
set_test_value(
at.dscalar(),
np.array(1000.0, dtype=np.float64),
),
],
(2,),
"halfnorm",
lambda *args: args,
),
], ],
) )
def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv): def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论