提交 5213962b authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Added Numba support for RandomVariable Ops

上级 46c772da
......@@ -3,7 +3,7 @@ import operator
import warnings
from functools import reduce, singledispatch
from numbers import Number
from textwrap import indent
from textwrap import dedent, indent
from typing import List, Union
import numba
......@@ -11,13 +11,15 @@ import numpy as np
import scipy
import scipy.special
from llvmlite.llvmpy.core import Type as llvm_Type
from numba import types
from numba import _helperlib, types
from numba.core.errors import TypingError
from numba.cpython.unsafe.tuple import tuple_setitem
from numba.extending import box
from numba.np.unsafe.ndarray import to_fixed_tuple
from numpy.core.multiarray import normalize_axis_index
from numpy.random import RandomState
import aesara.tensor.random.basic as aer
from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
......@@ -53,6 +55,7 @@ from aesara.tensor.basic import (
Rebroadcast,
ScalarFromTensor,
TensorFromScalar,
get_vector_length,
)
from aesara.tensor.blas import BatchedDot
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
......@@ -80,6 +83,8 @@ from aesara.tensor.nlinalg import (
QRFull,
)
from aesara.tensor.nnet.basic import LogSoftmax, Softmax
from aesara.tensor.random.type import RandomStateType
from aesara.tensor.random.var import RandomStateSharedVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.slinalg import Cholesky, Solve
from aesara.tensor.subtensor import (
......@@ -301,6 +306,14 @@ def numba_typify(data, dtype=None, **kwargs):
return data
@numba_typify.register(RandomState)
def numba_typify_RandomState(state, **kwargs):
ints, index = state.get_state()[1:3]
ptr = _helperlib.rnd_get_np_state_ptr()
_helperlib.rnd_set_state(ptr, (index, [int(x) for x in ints]))
return ints
@singledispatch
def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`."""
......@@ -1934,3 +1947,164 @@ def numba_funcify_BatchedDot(op, node, **kwargs):
# NOTE: The remaining `aesara.tensor.blas` `Op`s appear unnecessary, because
# they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM
# optimizations are apparently already performed by Numba
def make_numba_random_fn(node, np_random_func):
"""Create Numba implementations for existing Numba-supported ``np.random`` functions.
The functions generated here add parameter broadcasting and the ``size``
argument to the Numba-supported scalar ``np.random`` functions.
"""
tuple_size = get_vector_length(node.inputs[1])
size_dims = tuple_size - max(i.ndim for i in node.inputs[3:])
# Make a broadcast-capable version of the Numba supported scalar sampling
# function
bcast_fn_name = f"aesara_random_{get_name_for_object(np_random_func)}"
sized_fn_name = "sized_random_variable"
unique_names = unique_name_generator(
[
bcast_fn_name,
sized_fn_name,
"np",
"np_random_func",
"numba_vectorize",
"to_fixed_tuple",
"tuple_size",
"size_dims",
"rng",
"size",
"dtype",
],
suffix_sep="_",
)
bcast_fn_input_names = ", ".join(
[unique_names(i, force_unique=True) for i in node.inputs[3:]]
)
bcast_fn_global_env = {
"np_random_func": np_random_func,
"numba_vectorize": numba.vectorize,
}
bcast_fn_src = f"""
@numba_vectorize
def {bcast_fn_name}({bcast_fn_input_names}):
return np_random_func({bcast_fn_input_names})
"""
bcast_fn = compile_function_src(bcast_fn_src, bcast_fn_name, bcast_fn_global_env)
random_fn_input_names = ", ".join(
["rng", "size", "dtype"] + [unique_names(i) for i in node.inputs[3:]]
)
# Now, create a Numba JITable function that implements the `size` parameter
random_fn_global_env = {
bcast_fn_name: bcast_fn,
}
if tuple_size > 0:
random_fn_body = dedent(
f"""
size = to_fixed_tuple(size, tuple_size)
data = np.empty(size)
for i in np.ndindex(size[:size_dims]):
data[i] = {bcast_fn_name}({bcast_fn_input_names})
"""
)
random_fn_global_env.update(
{
"np": np,
"to_fixed_tuple": to_fixed_tuple,
"tuple_size": tuple_size,
"size_dims": size_dims,
}
)
else:
random_fn_body = f"""data = {bcast_fn_name}({bcast_fn_input_names})"""
sized_fn_src = dedent(
f"""
def {sized_fn_name}({random_fn_input_names}):
{indent(random_fn_body, " " * 4)}
return (rng, data)
"""
)
random_fn = compile_function_src(sized_fn_src, sized_fn_name, random_fn_global_env)
random_fn = numba.njit(random_fn)
return random_fn
@numba_funcify.register(aer.UniformRV)
@numba_funcify.register(aer.TriangularRV)
@numba_funcify.register(aer.BetaRV)
@numba_funcify.register(aer.NormalRV)
@numba_funcify.register(aer.LogNormalRV)
@numba_funcify.register(aer.GammaRV)
@numba_funcify.register(aer.ChiSquareRV)
@numba_funcify.register(aer.ParetoRV)
@numba_funcify.register(aer.GumbelRV)
@numba_funcify.register(aer.ExponentialRV)
@numba_funcify.register(aer.WeibullRV)
@numba_funcify.register(aer.LogisticRV)
@numba_funcify.register(aer.VonMisesRV)
@numba_funcify.register(aer.PoissonRV)
@numba_funcify.register(aer.GeometricRV)
@numba_funcify.register(aer.HyperGeometricRV)
@numba_funcify.register(aer.CauchyRV)
@numba_funcify.register(aer.WaldRV)
@numba_funcify.register(aer.LaplaceRV)
@numba_funcify.register(aer.BinomialRV)
@numba_funcify.register(aer.NegBinomialRV)
@numba_funcify.register(aer.MultinomialRV)
@numba_funcify.register(aer.RandIntRV) # only the first two arguments are supported
@numba_funcify.register(aer.ChoiceRV) # the `p` argument is not supported
@numba_funcify.register(aer.PermutationRV)
def numba_funcify_RandomVariable(op, node, **kwargs):
name = op.name
np_random_func = getattr(np.random, name)
if not isinstance(node.inputs[0], (RandomStateType, RandomStateSharedVariable)):
raise TypeError("Numba does not support NumPy `Generator`s")
return make_numba_random_fn(node, np_random_func)
@numba_funcify.register(aer.HalfNormalRV)
def numba_funcify_HalfNormalRV(op, node, **kwargs):
np_random_fn_name = f"aesara_random_{get_name_for_object(op.name)}"
unique_names = unique_name_generator(
[
np_random_fn_name,
"numba_vectorize",
"np_standard_norm",
"rng",
"size",
"dtype",
],
suffix_sep="_",
)
np_names = [unique_names(i, force_unique=True) for i in node.inputs[3:]]
np_input_names = ", ".join(np_names)
np_global_env = {
"np_standard_norm": np.random.standard_normal,
"numba_vectorize": numba.vectorize,
}
np_random_fn_src = f"""
@numba_vectorize
def {np_random_fn_name}({np_input_names}):
return {np_names[0]} + {np_names[1]} * abs(np_standard_norm())
"""
np_random_fn = compile_function_src(
np_random_fn_src, np_random_fn_name, np_global_env
)
return make_numba_random_fn(node, np_random_fn)
from numpy.random import RandomState
from aesara.link.basic import JITLinker
......@@ -16,11 +18,20 @@ class NumbaLinker(JITLinker):
return jitted_fn
def create_thunk_inputs(self, storage_map):
from aesara.link.numba.dispatch import numba_typify
thunk_inputs = []
for n in self.fgraph.inputs:
sinput = storage_map[n]
# TODO:When RandomVariable conversion is implemented
# do RandomState typification over here.
if isinstance(sinput[0], RandomState):
new_value = numba_typify(
sinput[0], dtype=getattr(sinput[0], "dtype", None)
)
# We need to remove the reference-based connection to the
# original `RandomState`/shared variable's storage, because
# subsequent attempts to use the same shared variable within
# other non-Numba-fied graphs will have problems.
sinput = [new_value]
thunk_inputs.append(sinput)
return thunk_inputs
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论