提交 58df5640 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Thomas Wiecki

Add `GumbelRV` JAX implementation

上级 fc0452f1
......@@ -110,6 +110,7 @@ def jax_sample_fn_generic(op):
@jax_sample_fn.register(aer.CauchyRV)
@jax_sample_fn.register(aer.GumbelRV)
@jax_sample_fn.register(aer.LaplaceRV)
@jax_sample_fn.register(aer.LogisticRV)
@jax_sample_fn.register(aer.NormalRV)
......
......@@ -123,6 +123,22 @@ def test_random_updates(rng_ctor):
"gamma",
lambda a, b: (a, 0.0, b),
),
(
aer.gumbel,
[
set_test_value(
at.lvector(),
np.array([1, 2], dtype=np.int64),
),
set_test_value(
at.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
(2,),
"gumbel_r",
lambda *args: args,
),
(
aer.laplace,
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论