提交 f091b86b authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Numba Dirichlet: Fix case with discrete alpha parameter

Also handle core op caching key to support per-op invalidation
上级 409a4b2a
...@@ -216,14 +216,16 @@ def core_MvNormalRV(op, node): ...@@ -216,14 +216,16 @@ def core_MvNormalRV(op, node):
@numba_core_rv_funcify.register(ptr.DirichletRV)
def core_DirichletRV(op, node):
    """Build the Numba core sampling function for a Dirichlet RV.

    Returns a tuple ``(random_fn, cache_key)`` where ``cache_key`` is an
    integer used for per-op invalidation of the compiled-function cache.
    """
    # Allocate the output with the op's declared output dtype. Relying on
    # np.empty_like(alpha) alone is wrong when alpha is a discrete (integer)
    # array: the gamma draws would be truncated to integers.
    dtype = op.dtype

    @numba_basic.numba_njit
    def random_fn(rng, alpha):
        # Draw Gamma(alpha[i], 1) variates and normalize; the normalized
        # vector is Dirichlet(alpha)-distributed.
        y = np.empty_like(alpha, dtype=dtype)
        for i in range(len(alpha)):
            y[i] = rng.gamma(alpha[i], 1.0)
        return y / y.sum()

    # Cache-key version of this core implementation; bump to invalidate
    # previously cached compilations of this op.
    return random_fn, 1
@numba_core_rv_funcify.register(ptr.GumbelRV) @numba_core_rv_funcify.register(ptr.GumbelRV)
...@@ -410,7 +412,7 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs ...@@ -410,7 +412,7 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
rv_op: RandomVariable = rv_node.op rv_op: RandomVariable = rv_node.op
try: try:
core_rv_fn = numba_core_rv_funcify(rv_op, rv_node) core_rv_fn_and_cache_key = numba_core_rv_funcify(rv_op, rv_node)
except NotImplementedError: except NotImplementedError:
py_impl = generate_fallback_impl(rv_op, node=rv_node, **kwargs) py_impl = generate_fallback_impl(rv_op, node=rv_node, **kwargs)
...@@ -420,6 +422,16 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs ...@@ -420,6 +422,16 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
return fallback_rv, None return fallback_rv, None
match core_rv_fn_and_cache_key:
case (core_rv_fn, (int() | None) as core_cache_key):
pass
case (_core_rv_fn, invalid_core_cache_key):
raise ValueError(
f"Invalid core_cache_key returned from numba_core_rv_funcify: {invalid_core_cache_key}. Must be int or None."
)
case core_rv_fn:
core_cache_key = "__None__"
size = rv_op.size_param(rv_node) size = rv_op.size_param(rv_node)
dist_params = rv_op.dist_params(rv_node) dist_params = rv_op.dist_params(rv_node)
size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size) size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
...@@ -469,16 +481,21 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs ...@@ -469,16 +481,21 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
return impl return impl
rv_op_props_dict = rv_op.props_dict() if hasattr(rv_op, "props_dict") else {} if core_cache_key is None:
random_rv_key_contents = ( # If the core RV can't be cached, then the whole RV can't be cached
type(op), random_rv_key = None # type: ignore[unreachable]
type(rv_op), else:
rv_op, rv_op_props_dict = rv_op.props_dict() if hasattr(rv_op, "props_dict") else {}
tuple(rv_op_props_dict.items()), random_rv_key_contents = (
size_len, type(op),
core_shape_len, type(rv_op),
inplace, rv_op,
input_bc_patterns, tuple(rv_op_props_dict.items()),
) size_len,
random_rv_key = sha256(str(random_rv_key_contents).encode()).hexdigest() core_shape_len,
inplace,
input_bc_patterns,
core_cache_key,
)
random_rv_key = sha256(str(random_rv_key_contents).encode()).hexdigest()
return random, random_rv_key return random, random_rv_key
...@@ -679,6 +679,16 @@ def test_DirichletRV(a, size, cm): ...@@ -679,6 +679,16 @@ def test_DirichletRV(a, size, cm):
assert np.allclose(res, exp_res, atol=1e-4) assert np.allclose(res, exp_res, atol=1e-4)
def test_dirichlet_discrete_alpha():
    """Dirichlet sampling from an integer-typed alpha must yield float draws."""
    concentration = pt.lvector()
    draws_graph = ptr.dirichlet(concentration, size=100)
    compiled = function([concentration], draws_graph, mode=numba_mode)

    draws = compiled(np.array([1, 1, 1], dtype=np.int64))

    # Output is upcast to float64 even though the input alpha is int64.
    assert draws.dtype == np.float64
    # Every sampled vector lies on the probability simplex.
    np.testing.assert_allclose(draws.sum(-1), 1.0)
    # An integer buffer would have truncated the draws to nothing but 0s and 1s.
    assert np.unique(draws).size > 2
def test_rv_inside_ofg(): def test_rv_inside_ofg():
rng_np = np.random.default_rng(562) rng_np = np.random.default_rng(562)
rng = shared(rng_np) rng = shared(rng_np)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论