提交 f49a6c51 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba fallback non-implemented RVs

上级 abaf1239
......@@ -20,6 +20,7 @@ from pytensor.link.utils import (
)
from pytensor.scalar.basic import ScalarType
from pytensor.sparse import SparseTensorType
from pytensor.tensor.random.type import RandomGeneratorType
from pytensor.tensor.type import TensorType
from pytensor.tensor.utils import hash_from_ndarray
......@@ -129,8 +130,8 @@ def get_numba_type(
return CSRMatrixType(numba_dtype)
if pytensor_type.format == "csc":
return CSCMatrixType(numba_dtype)
raise NotImplementedError()
elif isinstance(pytensor_type, RandomGeneratorType):
return numba.types.NumPyRandomGeneratorType("NumPyRandomGeneratorType")
else:
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
......
......@@ -16,6 +16,7 @@ from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
direct_cast,
generate_fallback_impl,
numba_funcify,
register_funcify_and_cache_key,
)
......@@ -406,13 +407,24 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
[rv_node] = op.fgraph.apply_nodes
rv_op: RandomVariable = rv_node.op
try:
core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
except NotImplementedError:
py_impl = generate_fallback_impl(rv_op, node=rv_node, **kwargs)
@numba_basic.numba_njit
def fallback_rv(_core_shape, *args):
return py_impl(*args)
return fallback_rv, None
size = rv_op.size_param(rv_node)
dist_params = rv_op.dist_params(rv_node)
size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
core_shape_len = get_vector_length(core_shape)
inplace = rv_op.inplace
core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
nin = 1 + len(dist_params) # rng + params
core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)
......
......@@ -257,7 +257,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
],
pt.as_tensor([3, 2]),
),
pytest.param(
(
ptr.hypergeometric,
[
(
......@@ -274,7 +274,6 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
),
],
pt.as_tensor([3, 2]),
marks=pytest.mark.xfail, # Not implemented
),
(
ptr.wald,
......@@ -722,3 +721,34 @@ def test_repeated_args():
final_node = fn.maker.fgraph.outputs[0].owner
assert isinstance(final_node.op, RandomVariableWithCoreShape)
assert final_node.inputs[-2] is final_node.inputs[-1]
def test_rv_fallback():
"""Test that random variables can fallback to object mode."""
class CustomRV(ptr.RandomVariable):
name = "custom"
signature = "()->()"
dtype = "float64"
def rng_fn(self, rng, value, size=None):
# Just return the value plus a random number
return value + rng.standard_normal(size=size)
custom_rv = CustomRV()
rng = shared(np.random.default_rng(123))
size = pt.scalar("size", dtype=int)
next_rng, x = custom_rv(np.pi, size=(size,), rng=rng).owner.outputs
fn = function([size], x, updates={rng: next_rng}, mode="NUMBA")
result1 = fn(1)
result2 = fn(1)
assert result1.shape == (1,)
assert result1 != result2
large_sample = fn(1000)
assert large_sample.shape == (1000,)
np.testing.assert_allclose(large_sample.mean(), np.pi, rtol=1e-2)
np.testing.assert_allclose(large_sample.std(), 1, rtol=1e-2)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论