提交 a920c09f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Update RNG in numba Dirichlet test

上级 75789deb
......@@ -652,15 +652,11 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
def test_DirichletRV(a, size, cm):
a, a_val = a
rng = shared(np.random.default_rng(29402))
g = ptr.dirichlet(a, size=size, rng=rng)
g_fn = function([a], g, mode=numba_mode)
next_rng, g = ptr.dirichlet(a, size=size, rng=rng).owner.outputs
g_fn = function([a], g, mode=numba_mode, updates={rng: next_rng})
with cm:
all_samples = []
for i in range(1000):
samples = g_fn(a_val)
all_samples.append(samples)
all_samples = [g_fn(a_val) for _ in range(1000)]
exp_res = a_val / a_val.sum(-1)
res = np.mean(all_samples, axis=tuple(range(0, a_val.ndim - 1)))
assert np.allclose(res, exp_res, atol=1e-4)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论