提交 1827703c authored 作者: juanitorduz's avatar juanitorduz 提交者: Thomas Wiecki

jax lognormal

上级 99510c34
......@@ -277,3 +277,18 @@ def jax_sample_fn_permutation(op):
return (rng, sample)
return sample_fn
@jax_sample_fn.register(aer.LogNormalRV)
def jax_sample_fn_lognormal(op):
"""JAX implementation of `LogNormalRV`."""
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
loc, scale = parameters
sample = loc + jax.random.normal(rng_key, size, dtype) * scale
sample_exp = jax.numpy.exp(sample)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
return (rng, sample_exp)
return sample_fn
......@@ -165,6 +165,22 @@ def test_random_updates(rng_ctor):
"logistic",
lambda *args: args,
),
(
aer.lognormal,
[
set_test_value(
at.lvector(),
np.array([0, 0], dtype=np.int64),
),
set_test_value(
at.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
(2,),
"lognorm",
lambda *args: args,
),
(
aer.normal,
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论