提交 5213962b authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Added Numba support for RandomVariable Ops

上级 46c772da
...@@ -3,7 +3,7 @@ import operator ...@@ -3,7 +3,7 @@ import operator
import warnings import warnings
from functools import reduce, singledispatch from functools import reduce, singledispatch
from numbers import Number from numbers import Number
from textwrap import indent from textwrap import dedent, indent
from typing import List, Union from typing import List, Union
import numba import numba
...@@ -11,13 +11,15 @@ import numpy as np ...@@ -11,13 +11,15 @@ import numpy as np
import scipy import scipy
import scipy.special import scipy.special
from llvmlite.llvmpy.core import Type as llvm_Type from llvmlite.llvmpy.core import Type as llvm_Type
from numba import types from numba import _helperlib, types
from numba.core.errors import TypingError from numba.core.errors import TypingError
from numba.cpython.unsafe.tuple import tuple_setitem from numba.cpython.unsafe.tuple import tuple_setitem
from numba.extending import box from numba.extending import box
from numba.np.unsafe.ndarray import to_fixed_tuple from numba.np.unsafe.ndarray import to_fixed_tuple
from numpy.core.multiarray import normalize_axis_index from numpy.core.multiarray import normalize_axis_index
from numpy.random import RandomState
import aesara.tensor.random.basic as aer
from aesara.compile.ops import DeepCopyOp, ViewOp from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
...@@ -53,6 +55,7 @@ from aesara.tensor.basic import ( ...@@ -53,6 +55,7 @@ from aesara.tensor.basic import (
Rebroadcast, Rebroadcast,
ScalarFromTensor, ScalarFromTensor,
TensorFromScalar, TensorFromScalar,
get_vector_length,
) )
from aesara.tensor.blas import BatchedDot from aesara.tensor.blas import BatchedDot
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
...@@ -80,6 +83,8 @@ from aesara.tensor.nlinalg import ( ...@@ -80,6 +83,8 @@ from aesara.tensor.nlinalg import (
QRFull, QRFull,
) )
from aesara.tensor.nnet.basic import LogSoftmax, Softmax from aesara.tensor.nnet.basic import LogSoftmax, Softmax
from aesara.tensor.random.type import RandomStateType
from aesara.tensor.random.var import RandomStateSharedVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.slinalg import Cholesky, Solve from aesara.tensor.slinalg import Cholesky, Solve
from aesara.tensor.subtensor import ( from aesara.tensor.subtensor import (
...@@ -301,6 +306,14 @@ def numba_typify(data, dtype=None, **kwargs): ...@@ -301,6 +306,14 @@ def numba_typify(data, dtype=None, **kwargs):
return data return data
@numba_typify.register(RandomState)
def numba_typify_RandomState(state, **kwargs):
ints, index = state.get_state()[1:3]
ptr = _helperlib.rnd_get_np_state_ptr()
_helperlib.rnd_set_state(ptr, (index, [int(x) for x in ints]))
return ints
@singledispatch @singledispatch
def numba_funcify(op, node=None, storage_map=None, **kwargs): def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`.""" """Create a Numba compatible function from an Aesara `Op`."""
...@@ -1934,3 +1947,164 @@ def numba_funcify_BatchedDot(op, node, **kwargs): ...@@ -1934,3 +1947,164 @@ def numba_funcify_BatchedDot(op, node, **kwargs):
# NOTE: The remaining `aesara.tensor.blas` `Op`s appear unnecessary, because # NOTE: The remaining `aesara.tensor.blas` `Op`s appear unnecessary, because
# they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM # they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM
# optimizations are apparently already performed by Numba # optimizations are apparently already performed by Numba
def make_numba_random_fn(node, np_random_func):
"""Create Numba implementations for existing Numba-supported ``np.random`` functions.
The functions generated here add parameter broadcasting and the ``size``
argument to the Numba-supported scalar ``np.random`` functions.
"""
tuple_size = get_vector_length(node.inputs[1])
size_dims = tuple_size - max(i.ndim for i in node.inputs[3:])
# Make a broadcast-capable version of the Numba supported scalar sampling
# function
bcast_fn_name = f"aesara_random_{get_name_for_object(np_random_func)}"
sized_fn_name = "sized_random_variable"
unique_names = unique_name_generator(
[
bcast_fn_name,
sized_fn_name,
"np",
"np_random_func",
"numba_vectorize",
"to_fixed_tuple",
"tuple_size",
"size_dims",
"rng",
"size",
"dtype",
],
suffix_sep="_",
)
bcast_fn_input_names = ", ".join(
[unique_names(i, force_unique=True) for i in node.inputs[3:]]
)
bcast_fn_global_env = {
"np_random_func": np_random_func,
"numba_vectorize": numba.vectorize,
}
bcast_fn_src = f"""
@numba_vectorize
def {bcast_fn_name}({bcast_fn_input_names}):
return np_random_func({bcast_fn_input_names})
"""
bcast_fn = compile_function_src(bcast_fn_src, bcast_fn_name, bcast_fn_global_env)
random_fn_input_names = ", ".join(
["rng", "size", "dtype"] + [unique_names(i) for i in node.inputs[3:]]
)
# Now, create a Numba JITable function that implements the `size` parameter
random_fn_global_env = {
bcast_fn_name: bcast_fn,
}
if tuple_size > 0:
random_fn_body = dedent(
f"""
size = to_fixed_tuple(size, tuple_size)
data = np.empty(size)
for i in np.ndindex(size[:size_dims]):
data[i] = {bcast_fn_name}({bcast_fn_input_names})
"""
)
random_fn_global_env.update(
{
"np": np,
"to_fixed_tuple": to_fixed_tuple,
"tuple_size": tuple_size,
"size_dims": size_dims,
}
)
else:
random_fn_body = f"""data = {bcast_fn_name}({bcast_fn_input_names})"""
sized_fn_src = dedent(
f"""
def {sized_fn_name}({random_fn_input_names}):
{indent(random_fn_body, " " * 4)}
return (rng, data)
"""
)
random_fn = compile_function_src(sized_fn_src, sized_fn_name, random_fn_global_env)
random_fn = numba.njit(random_fn)
return random_fn
@numba_funcify.register(aer.UniformRV)
@numba_funcify.register(aer.TriangularRV)
@numba_funcify.register(aer.BetaRV)
@numba_funcify.register(aer.NormalRV)
@numba_funcify.register(aer.LogNormalRV)
@numba_funcify.register(aer.GammaRV)
@numba_funcify.register(aer.ChiSquareRV)
@numba_funcify.register(aer.ParetoRV)
@numba_funcify.register(aer.GumbelRV)
@numba_funcify.register(aer.ExponentialRV)
@numba_funcify.register(aer.WeibullRV)
@numba_funcify.register(aer.LogisticRV)
@numba_funcify.register(aer.VonMisesRV)
@numba_funcify.register(aer.PoissonRV)
@numba_funcify.register(aer.GeometricRV)
@numba_funcify.register(aer.HyperGeometricRV)
@numba_funcify.register(aer.CauchyRV)
@numba_funcify.register(aer.WaldRV)
@numba_funcify.register(aer.LaplaceRV)
@numba_funcify.register(aer.BinomialRV)
@numba_funcify.register(aer.NegBinomialRV)
@numba_funcify.register(aer.MultinomialRV)
@numba_funcify.register(aer.RandIntRV) # only the first two arguments are supported
@numba_funcify.register(aer.ChoiceRV) # the `p` argument is not supported
@numba_funcify.register(aer.PermutationRV)
def numba_funcify_RandomVariable(op, node, **kwargs):
name = op.name
np_random_func = getattr(np.random, name)
if not isinstance(node.inputs[0], (RandomStateType, RandomStateSharedVariable)):
raise TypeError("Numba does not support NumPy `Generator`s")
return make_numba_random_fn(node, np_random_func)
@numba_funcify.register(aer.HalfNormalRV)
def numba_funcify_HalfNormalRV(op, node, **kwargs):
np_random_fn_name = f"aesara_random_{get_name_for_object(op.name)}"
unique_names = unique_name_generator(
[
np_random_fn_name,
"numba_vectorize",
"np_standard_norm",
"rng",
"size",
"dtype",
],
suffix_sep="_",
)
np_names = [unique_names(i, force_unique=True) for i in node.inputs[3:]]
np_input_names = ", ".join(np_names)
np_global_env = {
"np_standard_norm": np.random.standard_normal,
"numba_vectorize": numba.vectorize,
}
np_random_fn_src = f"""
@numba_vectorize
def {np_random_fn_name}({np_input_names}):
return {np_names[0]} + {np_names[1]} * abs(np_standard_norm())
"""
np_random_fn = compile_function_src(
np_random_fn_src, np_random_fn_name, np_global_env
)
return make_numba_random_fn(node, np_random_fn)
from numpy.random import RandomState
from aesara.link.basic import JITLinker from aesara.link.basic import JITLinker
...@@ -16,11 +18,20 @@ class NumbaLinker(JITLinker): ...@@ -16,11 +18,20 @@ class NumbaLinker(JITLinker):
return jitted_fn return jitted_fn
def create_thunk_inputs(self, storage_map): def create_thunk_inputs(self, storage_map):
from aesara.link.numba.dispatch import numba_typify
thunk_inputs = [] thunk_inputs = []
for n in self.fgraph.inputs: for n in self.fgraph.inputs:
sinput = storage_map[n] sinput = storage_map[n]
# TODO:When RandomVariable conversion is implemented if isinstance(sinput[0], RandomState):
# do RandomState typification over here. new_value = numba_typify(
sinput[0], dtype=getattr(sinput[0], "dtype", None)
)
# We need to remove the reference-based connection to the
# original `RandomState`/shared variable's storage, because
# subsequent attempts to use the same shared variable within
# other non-Numba-fied graphs will have problems.
sinput = [new_value]
thunk_inputs.append(sinput) thunk_inputs.append(sinput)
return thunk_inputs return thunk_inputs
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论