提交 24a2234b authored 作者: ricardoV94's avatar ricardoV94 提交者: Jesse Grabowski

Numba dispatch of StudentT

上级 7b0a3924
...@@ -102,6 +102,15 @@ def numba_core_BernoulliRV(op, node): ...@@ -102,6 +102,15 @@ def numba_core_BernoulliRV(op, node):
return random return random
@numba_core_rv_funcify.register(ptr.StudentTRV)
def numba_core_StudentTRV(op, node):
@numba_basic.numba_njit
def random_fn(rng, df, loc, scale):
return loc + scale * rng.standard_t(df)
return random_fn
@numba_core_rv_funcify.register(ptr.HalfNormalRV) @numba_core_rv_funcify.register(ptr.HalfNormalRV)
def numba_core_HalfNormalRV(op, node): def numba_core_HalfNormalRV(op, node):
@numba_basic.numba_njit @numba_basic.numba_njit
......
...@@ -592,6 +592,23 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): ...@@ -592,6 +592,23 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
"gumbel_r", "gumbel_r",
lambda *args: args, lambda *args: args,
), ),
(
ptr.t,
[
(pt.scalar(), np.array(np.e, dtype=np.float64)),
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
(
pt.dscalar(),
np.array(np.pi, dtype=np.float64),
),
],
(2,),
"t",
lambda *args: args,
),
], ],
) )
def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_conv): def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_conv):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论