提交 f49a6c51 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba fallback non-implemented RVs

上级 abaf1239
...@@ -20,6 +20,7 @@ from pytensor.link.utils import ( ...@@ -20,6 +20,7 @@ from pytensor.link.utils import (
) )
from pytensor.scalar.basic import ScalarType from pytensor.scalar.basic import ScalarType
from pytensor.sparse import SparseTensorType from pytensor.sparse import SparseTensorType
from pytensor.tensor.random.type import RandomGeneratorType
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from pytensor.tensor.utils import hash_from_ndarray from pytensor.tensor.utils import hash_from_ndarray
...@@ -129,8 +130,8 @@ def get_numba_type( ...@@ -129,8 +130,8 @@ def get_numba_type(
return CSRMatrixType(numba_dtype) return CSRMatrixType(numba_dtype)
if pytensor_type.format == "csc": if pytensor_type.format == "csc":
return CSCMatrixType(numba_dtype) return CSCMatrixType(numba_dtype)
elif isinstance(pytensor_type, RandomGeneratorType):
raise NotImplementedError() return numba.types.NumPyRandomGeneratorType("NumPyRandomGeneratorType")
else: else:
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}") raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
......
...@@ -16,6 +16,7 @@ from pytensor.graph.op import Op ...@@ -16,6 +16,7 @@ from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
direct_cast, direct_cast,
generate_fallback_impl,
numba_funcify, numba_funcify,
register_funcify_and_cache_key, register_funcify_and_cache_key,
) )
...@@ -406,13 +407,24 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs ...@@ -406,13 +407,24 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
[rv_node] = op.fgraph.apply_nodes [rv_node] = op.fgraph.apply_nodes
rv_op: RandomVariable = rv_node.op rv_op: RandomVariable = rv_node.op
try:
core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
except NotImplementedError:
py_impl = generate_fallback_impl(rv_op, node=rv_node, **kwargs)
@numba_basic.numba_njit
def fallback_rv(_core_shape, *args):
return py_impl(*args)
return fallback_rv, None
size = rv_op.size_param(rv_node) size = rv_op.size_param(rv_node)
dist_params = rv_op.dist_params(rv_node) dist_params = rv_op.dist_params(rv_node)
size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size) size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
core_shape_len = get_vector_length(core_shape) core_shape_len = get_vector_length(core_shape)
inplace = rv_op.inplace inplace = rv_op.inplace
core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
nin = 1 + len(dist_params) # rng + params nin = 1 + len(dist_params) # rng + params
core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1) core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)
......
...@@ -257,7 +257,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -257,7 +257,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
], ],
pt.as_tensor([3, 2]), pt.as_tensor([3, 2]),
), ),
pytest.param( (
ptr.hypergeometric, ptr.hypergeometric,
[ [
( (
...@@ -274,7 +274,6 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -274,7 +274,6 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
), ),
], ],
pt.as_tensor([3, 2]), pt.as_tensor([3, 2]),
marks=pytest.mark.xfail, # Not implemented
), ),
( (
ptr.wald, ptr.wald,
...@@ -722,3 +721,34 @@ def test_repeated_args(): ...@@ -722,3 +721,34 @@ def test_repeated_args():
final_node = fn.maker.fgraph.outputs[0].owner final_node = fn.maker.fgraph.outputs[0].owner
assert isinstance(final_node.op, RandomVariableWithCoreShape) assert isinstance(final_node.op, RandomVariableWithCoreShape)
assert final_node.inputs[-2] is final_node.inputs[-1] assert final_node.inputs[-2] is final_node.inputs[-1]
def test_rv_fallback():
    """Test that random variables can fallback to object mode."""

    class CustomRV(ptr.RandomVariable):
        name = "custom"
        signature = "()->()"
        dtype = "float64"

        def rng_fn(self, rng, value, size=None):
            # Offset the input value by a standard-normal draw.
            return value + rng.standard_normal(size=size)

    rng_state = shared(np.random.default_rng(123))
    n_draws = pt.scalar("size", dtype=int)
    next_state, draws = CustomRV()(np.pi, size=(n_draws,), rng=rng_state).owner.outputs

    compiled = function([n_draws], draws, updates={rng_state: next_state}, mode="NUMBA")

    # Consecutive calls must advance the shared rng state, so draws differ.
    first = compiled(1)
    second = compiled(1)
    assert first.shape == (1,)
    assert first != second

    # A large sample should match the N(pi, 1) distribution implied by rng_fn.
    sample = compiled(1000)
    assert sample.shape == (1000,)
    np.testing.assert_allclose(sample.mean(), np.pi, rtol=1e-2)
    np.testing.assert_allclose(sample.std(), 1, rtol=1e-2)
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论