提交 fc0452f1 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Thomas Wiecki

Add `StudentTRV` JAX implementation

上级 383d4efe
...@@ -208,6 +208,24 @@ def jax_sample_fn_exponential(op): ...@@ -208,6 +208,24 @@ def jax_sample_fn_exponential(op):
return sample_fn return sample_fn
@jax_sample_fn.register(aer.StudentTRV)
def jax_sample_fn_t(op):
"""JAX implementation of `StudentTRV`."""
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
(
df,
loc,
scale,
) = parameters
sample = loc + jax.random.t(rng_key, df, size, dtype) * scale
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
return (rng, sample)
return sample_fn
@jax_sample_fn.register(aer.ChoiceRV) @jax_sample_fn.register(aer.ChoiceRV)
def jax_funcify_choice(op): def jax_funcify_choice(op):
"""JAX implementation of `ChoiceRV`.""" """JAX implementation of `ChoiceRV`."""
......
...@@ -205,6 +205,26 @@ def test_random_updates(rng_ctor): ...@@ -205,6 +205,26 @@ def test_random_updates(rng_ctor):
"randint", "randint",
lambda *args: args, lambda *args: args,
), ),
(
aer.t,
[
set_test_value(
at.dscalar(),
np.array(2.0, dtype=np.float64),
),
set_test_value(
at.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
at.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
(2,),
"t",
lambda *args: args,
),
( (
aer.uniform, aer.uniform,
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论