提交 a920c09f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Update RNG in numba Dirichlet test

上级 75789deb
...@@ -652,15 +652,11 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ ...@@ -652,15 +652,11 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
def test_DirichletRV(a, size, cm): def test_DirichletRV(a, size, cm):
a, a_val = a a, a_val = a
rng = shared(np.random.default_rng(29402)) rng = shared(np.random.default_rng(29402))
g = ptr.dirichlet(a, size=size, rng=rng) next_rng, g = ptr.dirichlet(a, size=size, rng=rng).owner.outputs
g_fn = function([a], g, mode=numba_mode) g_fn = function([a], g, mode=numba_mode, updates={rng: next_rng})
with cm: with cm:
all_samples = [] all_samples = [g_fn(a_val) for _ in range(1000)]
for i in range(1000):
samples = g_fn(a_val)
all_samples.append(samples)
exp_res = a_val / a_val.sum(-1) exp_res = a_val / a_val.sum(-1)
res = np.mean(all_samples, axis=tuple(range(0, a_val.ndim - 1))) res = np.mean(all_samples, axis=tuple(range(0, a_val.ndim - 1)))
assert np.allclose(res, exp_res, atol=1e-4) assert np.allclose(res, exp_res, atol=1e-4)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论