提交 a110e82b authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Thomas Wiecki

Add `StandardNormalRV` JAX implementation

上级 58df5640
......@@ -114,6 +114,7 @@ def jax_sample_fn_generic(op):
@jax_sample_fn.register(aer.LaplaceRV)
@jax_sample_fn.register(aer.LogisticRV)
@jax_sample_fn.register(aer.NormalRV)
@jax_sample_fn.register(aer.StandardNormalRV)
def jax_sample_fn_loc_scale(op):
"""JAX implementation of random variables in the loc-scale families.
......
......@@ -221,6 +221,13 @@ def test_random_updates(rng_ctor):
"randint",
lambda *args: args,
),
(
aer.standard_normal,
[],
(2,),
"norm",
lambda *args: args,
),
(
aer.t,
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论