Commit 5213962b authored by kc611, committed by Brandon T. Willard

Added Numba support for RandomVariable Ops

Parent 46c772da
@@ -3,7 +3,7 @@ import operator
import warnings
from functools import reduce, singledispatch
from numbers import Number
from textwrap import indent
from textwrap import dedent, indent
from typing import List, Union
import numba
@@ -11,13 +11,15 @@ import numpy as np
import scipy
import scipy.special
from llvmlite.llvmpy.core import Type as llvm_Type
from numba import types
from numba import _helperlib, types
from numba.core.errors import TypingError
from numba.cpython.unsafe.tuple import tuple_setitem
from numba.extending import box
from numba.np.unsafe.ndarray import to_fixed_tuple
from numpy.core.multiarray import normalize_axis_index
from numpy.random import RandomState
import aesara.tensor.random.basic as aer
from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
@@ -53,6 +55,7 @@ from aesara.tensor.basic import (
Rebroadcast,
ScalarFromTensor,
TensorFromScalar,
get_vector_length,
)
from aesara.tensor.blas import BatchedDot
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -80,6 +83,8 @@ from aesara.tensor.nlinalg import (
QRFull,
)
from aesara.tensor.nnet.basic import LogSoftmax, Softmax
from aesara.tensor.random.type import RandomStateType
from aesara.tensor.random.var import RandomStateSharedVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.slinalg import Cholesky, Solve
from aesara.tensor.subtensor import (
@@ -301,6 +306,14 @@ def numba_typify(data, dtype=None, **kwargs):
return data
@numba_typify.register(RandomState)
def numba_typify_RandomState(state, **kwargs):
ints, index = state.get_state()[1:3]
ptr = _helperlib.rnd_get_np_state_ptr()
_helperlib.rnd_set_state(ptr, (index, [int(x) for x in ints]))
return ints
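
For context, a standalone sketch of the state handoff this registration performs (not part of the diff): Numba keeps its own global Mersenne Twister state, and `_helperlib.rnd_set_state` seeds it from the NumPy `RandomState` so that, for the basic samplers, the two generators should walk the same stream.

import numpy as np
import numba
from numba import _helperlib

state = np.random.RandomState(123)
ints, index = state.get_state()[1:3]
ptr = _helperlib.rnd_get_np_state_ptr()
_helperlib.rnd_set_state(ptr, (index, [int(x) for x in ints]))

@numba.njit
def draw():
    # Samples from Numba's internal generator, seeded above
    return np.random.uniform(0.0, 1.0)

# Should agree with a fresh NumPy generator under the same seed
print(draw(), np.random.RandomState(123).uniform(0.0, 1.0))
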
@singledispatch
def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`."""
@@ -1934,3 +1947,164 @@ def numba_funcify_BatchedDot(op, node, **kwargs):
# NOTE: The remaining `aesara.tensor.blas` `Op`s appear unnecessary, because
# they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM
# optimizations are apparently already performed by Numba
def make_numba_random_fn(node, np_random_func):
"""Create Numba implementations for existing Numba-supported ``np.random`` functions.
The functions generated here add parameter broadcasting and the ``size``
argument to the Numba-supported scalar ``np.random`` functions.
"""
tuple_size = get_vector_length(node.inputs[1])
size_dims = tuple_size - max(i.ndim for i in node.inputs[3:])
# Make a broadcast-capable version of the Numba supported scalar sampling
# function
bcast_fn_name = f"aesara_random_{get_name_for_object(np_random_func)}"
sized_fn_name = "sized_random_variable"
unique_names = unique_name_generator(
[
bcast_fn_name,
sized_fn_name,
"np",
"np_random_func",
"numba_vectorize",
"to_fixed_tuple",
"tuple_size",
"size_dims",
"rng",
"size",
"dtype",
],
suffix_sep="_",
)
bcast_fn_input_names = ", ".join(
[unique_names(i, force_unique=True) for i in node.inputs[3:]]
)
bcast_fn_global_env = {
"np_random_func": np_random_func,
"numba_vectorize": numba.vectorize,
}
bcast_fn_src = f"""
@numba_vectorize
def {bcast_fn_name}({bcast_fn_input_names}):
return np_random_func({bcast_fn_input_names})
"""
bcast_fn = compile_function_src(bcast_fn_src, bcast_fn_name, bcast_fn_global_env)
random_fn_input_names = ", ".join(
["rng", "size", "dtype"] + [unique_names(i) for i in node.inputs[3:]]
)
# Now, create a Numba JITable function that implements the `size` parameter
random_fn_global_env = {
bcast_fn_name: bcast_fn,
}
if tuple_size > 0:
random_fn_body = dedent(
f"""
size = to_fixed_tuple(size, tuple_size)
data = np.empty(size)
for i in np.ndindex(size[:size_dims]):
data[i] = {bcast_fn_name}({bcast_fn_input_names})
"""
)
random_fn_global_env.update(
{
"np": np,
"to_fixed_tuple": to_fixed_tuple,
"tuple_size": tuple_size,
"size_dims": size_dims,
}
)
else:
random_fn_body = f"""data = {bcast_fn_name}({bcast_fn_input_names})"""
sized_fn_src = dedent(
f"""
def {sized_fn_name}({random_fn_input_names}):
{indent(random_fn_body, " " * 4)}
return (rng, data)
"""
)
random_fn = compile_function_src(sized_fn_src, sized_fn_name, random_fn_global_env)
random_fn = numba.njit(random_fn)
return random_fn
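
To make the string templating above concrete, here is roughly the pair of functions it generates for a `normal(loc, scale)` node with a vector `loc`, a scalar `scale`, and a length-2 `size` (a schematic sketch; the real identifiers come from `unique_name_generator`):

import numpy as np
import numba
from numba.np.unsafe.ndarray import to_fixed_tuple

@numba.vectorize
def aesara_random_normal(loc, scale):
    return np.random.normal(loc, scale)

@numba.njit
def sized_random_variable(rng, size, dtype, loc, scale):
    # tuple_size == 2; size_dims == 2 - max(loc.ndim, scale.ndim) == 1
    size = to_fixed_tuple(size, 2)
    data = np.empty(size)
    for i in np.ndindex(size[:1]):
        # The @vectorize kernel broadcasts the parameters over each slice
        data[i] = aesara_random_normal(loc, scale)
    return (rng, data)

# e.g.: sized_random_variable(None, np.asarray([3, 2]), None,
#                             np.array([1.0, 2.0]), 1.0)
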
@numba_funcify.register(aer.UniformRV)
@numba_funcify.register(aer.TriangularRV)
@numba_funcify.register(aer.BetaRV)
@numba_funcify.register(aer.NormalRV)
@numba_funcify.register(aer.LogNormalRV)
@numba_funcify.register(aer.GammaRV)
@numba_funcify.register(aer.ChiSquareRV)
@numba_funcify.register(aer.ParetoRV)
@numba_funcify.register(aer.GumbelRV)
@numba_funcify.register(aer.ExponentialRV)
@numba_funcify.register(aer.WeibullRV)
@numba_funcify.register(aer.LogisticRV)
@numba_funcify.register(aer.VonMisesRV)
@numba_funcify.register(aer.PoissonRV)
@numba_funcify.register(aer.GeometricRV)
@numba_funcify.register(aer.HyperGeometricRV)
@numba_funcify.register(aer.CauchyRV)
@numba_funcify.register(aer.WaldRV)
@numba_funcify.register(aer.LaplaceRV)
@numba_funcify.register(aer.BinomialRV)
@numba_funcify.register(aer.NegBinomialRV)
@numba_funcify.register(aer.MultinomialRV)
@numba_funcify.register(aer.RandIntRV) # only the first two arguments are supported
@numba_funcify.register(aer.ChoiceRV) # the `p` argument is not supported
@numba_funcify.register(aer.PermutationRV)
def numba_funcify_RandomVariable(op, node, **kwargs):
name = op.name
np_random_func = getattr(np.random, name)
if not isinstance(node.inputs[0], (RandomStateType, RandomStateSharedVariable)):
raise TypeError("Numba does not support NumPy `Generator`s")
return make_numba_random_fn(node, np_random_func)
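
All of the registrations above funnel into the same helper; the `Op`'s `name` is assumed to match a `np.random` method, so, for example, `aer.NormalRV` resolves as:

import numpy as np

np_random_func = getattr(np.random, "normal")  # op.name == "normal"
print(np_random_func(0.0, 1.0))
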
@numba_funcify.register(aer.HalfNormalRV)
def numba_funcify_HalfNormalRV(op, node, **kwargs):
np_random_fn_name = f"aesara_random_{get_name_for_object(op.name)}"
unique_names = unique_name_generator(
[
np_random_fn_name,
"numba_vectorize",
"np_standard_norm",
"rng",
"size",
"dtype",
],
suffix_sep="_",
)
np_names = [unique_names(i, force_unique=True) for i in node.inputs[3:]]
np_input_names = ", ".join(np_names)
np_global_env = {
"np_standard_norm": np.random.standard_normal,
"numba_vectorize": numba.vectorize,
}
np_random_fn_src = f"""
@numba_vectorize
def {np_random_fn_name}({np_input_names}):
return {np_names[0]} + {np_names[1]} * abs(np_standard_norm())
"""
np_random_fn = compile_function_src(
np_random_fn_src, np_random_fn_name, np_global_env
)
return make_numba_random_fn(node, np_random_fn)
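
A quick standalone sanity check of the transform used above (not part of the commit): NumPy exposes no half-normal sampler, so it is composed as `loc + scale * abs(standard_normal())`, whose mean is `loc + scale * sqrt(2 / pi)`.

import numpy as np

rng = np.random.RandomState(0)
draws = 1.0 + 2.0 * np.abs(rng.standard_normal(100_000))  # loc=1, scale=2
print(draws.mean())                        # ~2.596
print(1.0 + 2.0 * np.sqrt(2.0 / np.pi))    # 1 + 2*sqrt(2/pi) ≈ 2.596
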
from numpy.random import RandomState
from aesara.link.basic import JITLinker
@@ -16,11 +18,20 @@ class NumbaLinker(JITLinker):
return jitted_fn
def create_thunk_inputs(self, storage_map):
from aesara.link.numba.dispatch import numba_typify
thunk_inputs = []
for n in self.fgraph.inputs:
sinput = storage_map[n]
# TODO: When `RandomVariable` conversion is implemented,
# do the `RandomState` typification here.
if isinstance(sinput[0], RandomState):
new_value = numba_typify(
sinput[0], dtype=getattr(sinput[0], "dtype", None)
)
# We need to remove the reference-based connection to the
# original `RandomState`/shared variable's storage, because
# subsequent attempts to use the same shared variable within
# other non-Numba-fied graphs will have problems.
sinput = [new_value]
thunk_inputs.append(sinput)
return thunk_inputs
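
The comment about the reference-based connection deserves a concrete illustration. Storage cells in `storage_map` are one-element lists shared by reference across graphs, which is why the code above rebinds `sinput` to a fresh list instead of writing into the shared cell. A plain-Python sketch (hypothetical values, no Aesara needed):

original = ["numpy RandomState"]       # storage shared with other graphs
thunk_input = original                 # the reference-based connection

# What the linker does: rebind a new cell, leaving `original` untouched
thunk_input = ["numba-typified state"]
assert original[0] == "numpy RandomState"

# What it avoids: an in-place write would leak into `original` and break
# later non-Numba graphs that still use the shared variable
thunk_input = original
thunk_input[0] = "numba-typified state"
assert original[0] == "numba-typified state"
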
@@ -14,6 +14,7 @@ import aesara.tensor.basic as aetb
import aesara.tensor.inplace as ati
import aesara.tensor.math as aem
import aesara.tensor.nnet.basic as nnetb
import aesara.tensor.random.basic as aer
from aesara import config, shared
from aesara.compile.function import function
from aesara.compile.mode import Mode
@@ -138,7 +139,10 @@ def eval_python_only(fn_inputs, fgraph, inputs):
return inner_vec
return wrap
if len(args) == 1 and callable(args[0]):
return wrap(args[0], **kwargs)
else:
return wrap
with mock.patch("aesara.link.numba.dispatch.numba.njit", njit_noop), mock.patch(
"aesara.link.numba.dispatch.numba.vectorize",
@@ -2469,3 +2473,383 @@ def test_shared():
numba_res = aesara_numba_fn()
np.testing.assert_allclose(numba_res, new_a_value * 2)
@pytest.mark.parametrize(
"rv_op, dist_args, size",
[
(
aer.normal,
[
set_test_value(
aet.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
aet.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
aet.as_tensor([3, 2]),
),
(
aer.uniform,
[
set_test_value(
aet.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
aet.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
aet.as_tensor([3, 2]),
),
(
aer.triangular,
[
set_test_value(
aet.dscalar(),
np.array(-5.0, dtype=np.float64),
),
set_test_value(
aet.dscalar(),
np.array(1.0, dtype=np.float64),
),
set_test_value(
aet.dscalar(),
np.array(5.0, dtype=np.float64),
),
],
aet.as_tensor([3, 2]),
),
pytest.param(
aer.beta,
[
set_test_value(
aet.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
aet.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
aet.as_tensor([3, 2]),
marks=pytest.mark.xfail(reason="Numba and NumPy rng states do not match"),
),
(
aer.lognormal,
[
set_test_value(
aet.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
aet.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
aet.as_tensor([3, 2]),
),
pytest.param(
aer.gamma,
[
set_test_value(
aet.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
aet.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
aet.as_tensor([3, 2]),
marks=pytest.mark.xfail(reason="Numba and NumPy rng states do not match"),
),
pytest.param(
aer.chisquare,
[
set_test_value(
aet.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
)
],
aet.as_tensor([3, 2]),
marks=pytest.mark.xfail(reason="Numba and NumPy rng states do not match"),
),
pytest.param(
aer.pareto,
[
set_test_value(
aet.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
],
aet.as_tensor([3, 2]),
marks=pytest.mark.xfail(reason="Not implemented"),
),
pytest.param(
aer.gumbel,
[
set_test_value(
aet.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
aet.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
aet.as_tensor([3, 2]),
marks=pytest.mark.xfail(reason="Numba and NumPy rng states do not match"),
),
(
aer.exponential,
[
set_test_value(
aet.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
],
aet.as_tensor([3, 2]),
),
(
aer.weibull,
[
set_test_value(
aet.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
],
aet.as_tensor([3, 2]),
),
(
aer.logistic,
[
set_test_value(
aet.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
aet.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
aet.as_tensor([3, 2]),
),
pytest.param(
aer.vonmises,
[
set_test_value(
aet.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
aet.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
aet.as_tensor([3, 2]),
marks=pytest.mark.xfail(reason="Numba and NumPy rng states do not match"),
),
(
aer.geometric,
[
set_test_value(
aet.dvector(),
np.array([0.3, 0.4], dtype=np.float64),
),
],
aet.as_tensor([3, 2]),
),
(
aer.hypergeometric,
[
set_test_value(
aet.lscalar(),
np.array(7, dtype=np.int64),
),
set_test_value(
aet.lscalar(),
np.array(8, dtype=np.int64),
),
set_test_value(
aet.lscalar(),
np.array(15, dtype=np.int64),
),
],
aet.as_tensor([3, 2]),
),
pytest.param(
aer.cauchy,
[
set_test_value(
aet.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
aet.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
aet.as_tensor([3, 2]),
marks=pytest.mark.xfail(reason="Not implemented"),
),
(
aer.wald,
[
set_test_value(
aet.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
aet.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
aet.as_tensor([3, 2]),
),
(
aer.laplace,
[
set_test_value(
aet.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
aet.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
aet.as_tensor([3, 2]),
),
(
aer.binomial,
[
set_test_value(
aet.lvector(),
np.array([1, 2], dtype=np.int64),
),
set_test_value(
aet.dscalar(),
np.array(0.9, dtype=np.float64),
),
],
aet.as_tensor([3, 2]),
),
# pytest.param(
# aer.negative_binomial,
# [
# set_test_value(
# aet.lvector(),
# np.array([1, 2], dtype=np.int64),
# ),
# set_test_value(
# aet.dscalar(),
# np.array(0.9, dtype=np.float64),
# ),
# ],
# aet.as_tensor([3, 2]),
# marks=pytest.mark.xfail(reason="Not implemented"),
# ),
(
aer.normal,
[
set_test_value(
aet.lvector(),
np.array([1, 2], dtype=np.int64),
),
set_test_value(
aet.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
aet.as_tensor(tuple(set_test_value(aet.lscalar(), v) for v in [3, 2])),
),
(
aer.poisson,
[
set_test_value(
aet.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
],
None,
),
(
aer.halfnormal,
[
set_test_value(
aet.lvector(),
np.array([1, 2], dtype=np.int64),
),
set_test_value(
aet.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
None,
),
(
aer.randint,
[
set_test_value(
aet.lscalar(),
np.array(0, dtype=np.int64),
),
set_test_value(
aet.lscalar(),
np.array(5, dtype=np.int64),
),
],
aet.as_tensor([3, 2]),
),
pytest.param(
aer.multivariate_normal,
[
set_test_value(
aet.dmatrix(),
np.array([[1, 2], [3, 4]], dtype=np.float64),
),
set_test_value(
aet.tensor("float64", [True, False, False]),
np.eye(2)[None, ...],
),
],
aet.as_tensor(tuple(set_test_value(aet.lscalar(), v) for v in [4, 3, 2])),
marks=pytest.mark.xfail(reason="Not implemented"),
),
],
ids=str,
)
def test_RandomVariable(rv_op, dist_args, size):
rng = shared(np.random.RandomState(29402))
g = rv_op(*dist_args, size=size, rng=rng)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
def test_random_Generator():
rng = shared(np.random.default_rng(29402))
g = aer.normal(rng=rng)
g_fg = FunctionGraph(outputs=[g])
with pytest.raises(TypeError):
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)