提交 90021636 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a Numba implementation for BernoulliRV

上级 febb92ea
......@@ -4,7 +4,7 @@ import warnings
from functools import reduce, singledispatch
from numbers import Number
from textwrap import dedent, indent
from typing import List, Union
from typing import Any, Callable, Dict, List, Optional, Union
import numba
import numpy as np
......@@ -23,6 +23,7 @@ import aesara.tensor.random.basic as aer
from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.type import Type
from aesara.link.utils import (
compile_function_src,
......@@ -2163,15 +2164,33 @@ def numba_funcify_RandomVariable(op, node, **kwargs):
return make_numba_random_fn(node, np_random_func)
@numba_funcify.register(aer.HalfNormalRV)
def numba_funcify_HalfNormalRV(op, node, **kwargs):
def create_numba_random_fn(
op: Op,
node: Apply,
scalar_fn: Callable[[str], str],
global_env: Optional[Dict[str, Any]] = None,
) -> Callable:
"""Create a vectorized function from a callable that generates the ``str`` function body.
TODO: This could/should be generalized for other simple function
construction cases that need unique-ified symbol names.
"""
np_random_fn_name = f"aesara_random_{get_name_for_object(op.name)}"
if global_env:
np_global_env = global_env.copy()
else:
np_global_env = {}
np_global_env["np"] = np
np_global_env["numba_vectorize"] = numba.vectorize
unique_names = unique_name_generator(
[
np_random_fn_name,
"numba_vectorize",
"np_standard_norm",
]
+ list(np_global_env.keys())
+ [
"rng",
"size",
"dtype",
......@@ -2181,17 +2200,38 @@ def numba_funcify_HalfNormalRV(op, node, **kwargs):
np_names = [unique_names(i, force_unique=True) for i in node.inputs[3:]]
np_input_names = ", ".join(np_names)
np_global_env = {
"np_standard_norm": np.random.standard_normal,
"numba_vectorize": numba.vectorize,
}
np_random_fn_src = f"""
@numba_vectorize
def {np_random_fn_name}({np_input_names}):
return {np_names[0]} + {np_names[1]} * abs(np_standard_norm())
{scalar_fn(*np_names)}
"""
np_random_fn = compile_function_src(
np_random_fn_src, np_random_fn_name, np_global_env
)
return make_numba_random_fn(node, np_random_fn)
@numba_funcify.register(aer.HalfNormalRV)
def numba_funcify_HalfNormalRV(op, node, **kwargs):
def body_fn(a, b):
return f" return {a} + {b} * abs(np.random.normal(0, 1))"
return create_numba_random_fn(op, node, body_fn)
@numba_funcify.register(aer.BernoulliRV)
def numba_funcify_BernoulliRV(op, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype
def body_fn(a):
return f"""
if {a} < np.random.uniform(0, 1):
return direct_cast(0, out_dtype)
else:
return direct_cast(1, out_dtype)
"""
return create_numba_random_fn(
op, node, body_fn, {"out_dtype": out_dtype, "direct_cast": direct_cast}
)
......@@ -2806,6 +2806,16 @@ def test_shared():
],
None,
),
(
aer.bernoulli,
[
set_test_value(
aet.dvector(),
np.array([0.1, 0.9], dtype=np.float64),
),
],
None,
),
(
aer.randint,
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论