提交 fc0452f1 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Thomas Wiecki

Add `StudentTRV` JAX implementation

上级 383d4efe
......@@ -208,6 +208,24 @@ def jax_sample_fn_exponential(op):
return sample_fn
@jax_sample_fn.register(aer.StudentTRV)
def jax_sample_fn_t(op):
"""JAX implementation of `StudentTRV`."""
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
(
df,
loc,
scale,
) = parameters
sample = loc + jax.random.t(rng_key, df, size, dtype) * scale
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
return (rng, sample)
return sample_fn
@jax_sample_fn.register(aer.ChoiceRV)
def jax_funcify_choice(op):
"""JAX implementation of `ChoiceRV`."""
......
......@@ -205,6 +205,26 @@ def test_random_updates(rng_ctor):
"randint",
lambda *args: args,
),
(
aer.t,
[
set_test_value(
at.dscalar(),
np.array(2.0, dtype=np.float64),
),
set_test_value(
at.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
at.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
(2,),
"t",
lambda *args: args,
),
(
aer.uniform,
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论