提交 f091b86b authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Numba Dirichlet: Fix cas with discrete alpha parameter

Also handle core op caching key to support per-op invalidation
上级 409a4b2a
......@@ -216,14 +216,16 @@ def core_MvNormalRV(op, node):
@numba_core_rv_funcify.register(ptr.DirichletRV)
def core_DirichletRV(op, node):
dtype = op.dtype
@numba_basic.numba_njit
def random_fn(rng, alpha):
y = np.empty_like(alpha)
y = np.empty_like(alpha, dtype=dtype)
for i in range(len(alpha)):
y[i] = rng.gamma(alpha[i], 1.0)
return y / y.sum()
return random_fn
return random_fn, 1
@numba_core_rv_funcify.register(ptr.GumbelRV)
......@@ -410,7 +412,7 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
rv_op: RandomVariable = rv_node.op
try:
core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
core_rv_fn_and_cache_key = numba_core_rv_funcify(rv_op, rv_node)
except NotImplementedError:
py_impl = generate_fallback_impl(rv_op, node=rv_node, **kwargs)
......@@ -420,6 +422,16 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
return fallback_rv, None
match core_rv_fn_and_cache_key:
case (core_rv_fn, (int() | None) as core_cache_key):
pass
case (_core_rv_fn, invalid_core_cache_key):
raise ValueError(
f"Invalid core_cache_key returned from numba_core_rv_funcify: {invalid_core_cache_key}. Must be int or None."
)
case core_rv_fn:
core_cache_key = "__None__"
size = rv_op.size_param(rv_node)
dist_params = rv_op.dist_params(rv_node)
size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
......@@ -469,6 +481,10 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
return impl
if core_cache_key is None:
# If the core RV can't be cached, then the whole RV can't be cached
random_rv_key = None # type: ignore[unreachable]
else:
rv_op_props_dict = rv_op.props_dict() if hasattr(rv_op, "props_dict") else {}
random_rv_key_contents = (
type(op),
......@@ -479,6 +495,7 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
core_shape_len,
inplace,
input_bc_patterns,
core_cache_key,
)
random_rv_key = sha256(str(random_rv_key_contents).encode()).hexdigest()
return random, random_rv_key
......@@ -679,6 +679,16 @@ def test_DirichletRV(a, size, cm):
assert np.allclose(res, exp_res, atol=1e-4)
def test_dirichlet_discrete_alpha():
alpha = pt.lvector()
g = ptr.dirichlet(alpha, size=100)
fn = function([alpha], g, mode=numba_mode)
res = fn(np.array([1, 1, 1], dtype=np.int64))
assert res.dtype == np.float64
np.testing.assert_allclose(res.sum(-1), 1.0)
assert np.unique(res).size > 2 # Make sure we have more than just 0s and 1s
def test_rv_inside_ofg():
rng_np = np.random.default_rng(562)
rng = shared(rng_np)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论