提交 98be9c5f authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Replace numba_scipy

上级 1c507090
......@@ -1252,12 +1252,6 @@ def add_numba_configvars():
BoolParam(True),
in_c_key=False,
)
config.add(
"numba_scipy",
("Enable usage of the numba_scipy package for special functions",),
BoolParam(True),
in_c_key=False,
)
def _default_compiledirname():
......
import ctypes
import importlib
import re
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast
import numba
import numpy as np
from numpy.typing import DTypeLike
from scipy import LowLevelCallable
# Maps the C type spellings found in Cython capsule signatures to the NumPy
# scalar type used to represent them.
_C_TO_NUMPY: Dict[str, DTypeLike] = {
    "bool": np.bool_,
    "signed char": np.byte,
    "unsigned char": np.ubyte,
    "short": np.short,
    "unsigned short": np.ushort,
    "int": np.intc,
    "unsigned int": np.uintc,
    # C ``long`` is looked up through its dtype character code rather than
    # ``np.int_`` / ``np.uint``: since NumPy 2.0 those names alias
    # ``intp`` / ``uintp``, which is wider than C ``long`` on 64-bit Windows.
    # The character codes ``l`` / ``L`` denote C ``long`` / ``unsigned long``.
    "long": np.dtype("l").type,
    "unsigned long": np.dtype("L").type,
    "long long": np.longlong,
    "float": np.single,
    "double": np.double,
    "long double": np.longdouble,
    "float complex": np.csingle,
    "double complex": np.cdouble,
}
@dataclass
class Signature:
    """Parsed C signature of a function exported through a Cython capsule.

    Instances are normally created with :meth:`from_c_types` from a string
    such as ``b"double(double, int __pyx_skip_dispatch)"``.
    """

    # NumPy dtype of the return value.
    res_dtype: DTypeLike
    # C spelling of the return type, e.g. ``"double"``.
    res_c_type: str
    # NumPy dtypes of the parameters, in declaration order.
    arg_dtypes: List[DTypeLike]
    # C spellings of the parameter types, aligned with ``arg_dtypes``.
    arg_c_types: List[str]
    # Parameter names aligned with ``arg_dtypes``; ``None`` for unnamed ones.
    arg_names: List[Optional[str]]

    # The patterns below are deliberately left unannotated so that the
    # dataclass machinery does not turn them into instance fields.

    # Matches strings like "double(int, double)" and extracts the return
    # type and the joined argument list.
    _SIGNATURE_RE = re.compile(
        rb"\s*(?P<restype>[\w ]*\w+)\s*\((?P<args>[\w\s,]*)\)"
    )
    # Matches one parameter declaration: an optional qualifier, a base type
    # and an optional parameter name. ``float `` must be listed as a
    # qualifier; otherwise "float complex" would be read as type "float"
    # with a parameter called "complex".
    _ARG_DECL_RE = re.compile(
        rb"\s*(?P<type>((long )|(unsigned )|(signed )|(double )|(float )|)"
        rb"((double)|(float)|(int)|(short)|(char)|(long)|(bool)|(complex)))"
        rb"(\s(?P<name>[\w_]*))?\s*"
    )

    @property
    def arg_numba_types(self) -> List[Any]:
        """Return the numba type matching each entry of ``arg_dtypes``."""
        return [numba.from_dtype(dtype) for dtype in self.arg_dtypes]

    def can_cast_args(self, args: List[DTypeLike]) -> bool:
        """Check whether ``args`` can be safely cast to the parameter dtypes.

        Parameters named ``__pyx_skip_dispatch`` are ignored; every other
        parameter is paired with the next entry of ``args``.

        Returns
        -------
        bool
            ``True`` if each dtype in ``args`` can be cast to the matching
            parameter dtype, ``False`` if a cast is not possible or if more
            arguments were given than the signature accepts.

        Raises
        ------
        ValueError
            If fewer arguments were given than the signature requires.
        """
        ok = True
        count = 0
        for name, dtype in zip(self.arg_names, self.arg_dtypes):
            if name == "__pyx_skip_dispatch":
                continue
            if len(args) <= count:
                raise ValueError("Incorrect number of arguments")
            ok &= np.can_cast(args[count], dtype)
            count += 1
        if count != len(args):
            return False
        return ok

    def provides(self, restype: DTypeLike, arg_dtypes: List[DTypeLike]) -> bool:
        """Tell whether this signature can serve a call with the given dtypes.

        Parameters
        ----------
        restype
            The result dtype the caller wants.
        arg_dtypes
            The dtypes of the arguments the caller will pass.

        Returns
        -------
        bool
            ``True`` if the arguments can be cast to the parameters and the
            result can be cast to ``restype``.

        Raises
        ------
        ValueError
            Propagated from :meth:`can_cast_args` when too few argument
            dtypes are supplied.
        """
        args_ok = self.can_cast_args(arg_dtypes)
        if np.issubdtype(restype, np.inexact):
            result_ok = np.can_cast(self.res_dtype, restype, casting="same_kind")
            # We do not want to provide less accuracy than advertised
            result_ok &= (
                np.dtype(self.res_dtype).itemsize >= np.dtype(restype).itemsize
            )
        else:
            result_ok = np.can_cast(self.res_dtype, restype)
        return args_ok and result_ok

    @staticmethod
    def from_c_types(signature: bytes) -> "Signature":
        """Parse a C signature such as ``b"double(int, double)"``.

        Raises
        ------
        ValueError
            If the signature or one of its parameter declarations does not
            have the expected form, or a parameter type is unknown.
        KeyError
            If the return type is not a known C type.
        """
        re_match = Signature._SIGNATURE_RE.fullmatch(signature)
        if re_match is None:
            raise ValueError(f"Invalid signature: {signature.decode()}")

        groups = re_match.groupdict()
        res_c_type = groups["restype"].decode()
        # An unknown return type surfaces as KeyError on purpose; callers use
        # it to recognise unsupported implementations.
        res_dtype: DTypeLike = _C_TO_NUMPY[res_c_type]

        arg_dtypes = []
        arg_names: List[Optional[str]] = []
        arg_c_types = []
        for raw_arg in groups["args"].split(b","):
            arg_match = Signature._ARG_DECL_RE.fullmatch(raw_arg)
            if arg_match is None:
                raise ValueError(f"Invalid signature: {signature.decode()}")

            arg_groups = arg_match.groupdict()
            arg_c_type = arg_groups["type"].decode()
            try:
                arg_dtype = _C_TO_NUMPY[arg_c_type]
            except KeyError as err:
                raise ValueError(f"Unknown C type: {arg_c_type}") from err

            arg_c_types.append(arg_c_type)
            arg_dtypes.append(arg_dtype)
            name = arg_groups["name"]
            arg_names.append(name.decode() if name else None)

        return Signature(res_dtype, res_c_type, arg_dtypes, arg_c_types, arg_names)
def _available_impls(func: Callable) -> List[Tuple[Signature, Any]]:
    """Find all available implementations for a fused cython function.

    Parameters
    ----------
    func
        A function from a Cython extension module. If it exposes
        ``__signatures__`` it is treated as a fused function and every
        specialization is inspected; otherwise only ``func`` itself is.

    Returns
    -------
    list of (Signature, capsule)
        One entry per implementation whose C signature could be parsed. The
        capsule is the entry of the module's ``__pyx_capi__`` with the same
        name as the implementation.
    """
    impls = []
    mod = importlib.import_module(func.__module__)

    signatures = getattr(func, "__signatures__", None)
    if signatures is not None:
        # Cython function with __signatures__ should be fused and thus
        # indexable
        func_map = cast(Mapping, func)
        candidates = [func_map[key] for key in signatures]
    else:
        candidates = [func]

    for candidate in candidates:
        name = candidate.__name__
        capsule = mod.__pyx_capi__[name]
        llc = LowLevelCallable(capsule)
        try:
            signature = Signature.from_c_types(llc.signature.encode())
        except (KeyError, ValueError):
            # KeyError: unknown return type. ValueError: malformed signature
            # or unknown argument type. Either way this implementation can
            # not be used, but the remaining ones still can.
            continue
        impls.append((signature, capsule))
    return impls
class _CythonWrapper(numba.types.WrapperAddressProtocol):
    """Expose a function stored in a capsule through numba's wrapper-address protocol.

    The object can also be called from regular Python, in which case the
    call is forwarded to the original Python-level function.
    """

    def __init__(self, pyfunc, signature, capsule):
        self._pyfunc = pyfunc
        self._signature = signature
        # Hold a reference to the capsule for as long as the wrapper exists.
        self._keep_alive = capsule

        capsule_name_of = ctypes.pythonapi.PyCapsule_GetName
        capsule_name_of.restype = ctypes.c_char_p
        capsule_name_of.argtypes = (ctypes.py_object,)
        capsule_name = capsule_name_of(capsule)

        # The pointer is requested under the name the capsule itself reports.
        capsule_pointer_of = ctypes.pythonapi.PyCapsule_GetPointer
        capsule_pointer_of.restype = ctypes.c_void_p
        capsule_pointer_of.argtypes = (ctypes.py_object, ctypes.c_char_p)
        self._func_ptr = capsule_pointer_of(capsule, capsule_name)

    def signature(self):
        """Return the numba signature built from the parsed C signature."""
        return_type = numba.from_dtype(self._signature.res_dtype)
        return return_type(*self._signature.arg_numba_types)

    def __wrapper_address__(self):
        """Return the address read from the capsule."""
        return self._func_ptr

    def __call__(self, *args, **kwargs):
        """Call the Python-level function with converted arguments.

        Each positional argument is converted with the dtype of the matching
        parameter. When the signature ends in ``__pyx_skip_dispatch``, the
        last converted argument is dropped before the call. The result is
        converted to the result dtype.
        """
        converted = [
            to_dtype(value)
            for value, to_dtype in zip(args, self._signature.arg_dtypes)
        ]
        if self.has_pyx_skip_dispatch():
            converted = converted[:-1]
        result = self._pyfunc(*converted, **kwargs)
        return self._signature.res_dtype(result)

    def has_pyx_skip_dispatch(self):
        """Tell whether the last parameter is named ``__pyx_skip_dispatch``.

        Raises ``ValueError`` if a parameter with that name appears anywhere
        but in the last position.
        """
        names = self._signature.arg_names
        if not names:
            return False
        *leading, last = names
        if "__pyx_skip_dispatch" in leading:
            raise ValueError("skip_dispatch parameter must be last")
        return last == "__pyx_skip_dispatch"

    def numpy_arg_dtypes(self):
        """Return the NumPy dtypes of the parameters."""
        return self._signature.arg_dtypes

    def numpy_output_dtype(self):
        """Return the NumPy dtype of the result."""
        return self._signature.res_dtype
def wrap_cython_function(func, restype, arg_types):
    """Wrap the best matching implementation of a cython function.

    Among the implementations of ``func`` that can provide ``restype`` for
    arguments of ``arg_types``, the one with the fewest inexact (floating
    point or complex) parameters is chosen; ties are broken by the smallest
    total parameter size in bytes, then by discovery order.

    Raises
    ------
    NotImplementedError
        If no implementation is compatible with the requested dtypes.
    """

    def preference(candidate):
        candidate_sig, _ = candidate
        # More exact (integer) parameters are preferred ...
        inexact_count = sum(
            np.issubdtype(dtype, np.inexact) for dtype in candidate_sig.arg_dtypes
        )
        # ... and, after that, fewer input bytes.
        total_bytes = sum(
            np.dtype(dtype).itemsize for dtype in candidate_sig.arg_dtypes
        )
        return (inexact_count, total_bytes)

    compatible = [
        (candidate_sig, candidate_capsule)
        for candidate_sig, candidate_capsule in _available_impls(func)
        if candidate_sig.provides(restype, arg_types)
    ]
    if not compatible:
        raise NotImplementedError(f"Could not find a compatible impl of {func}")

    # min() yields the first minimal entry, the same one a stable sort would
    # place at index 0.
    best_sig, best_capsule = min(compatible, key=preference)
    return _CythonWrapper(func, best_sig, best_capsule)
import math
import warnings
from functools import reduce
from typing import List
import numpy as np
import scipy
import scipy.special
from pytensor import config
from pytensor.compile.ops import ViewOp
......@@ -16,6 +12,7 @@ from pytensor.link.numba.dispatch.basic import (
generate_fallback_impl,
numba_funcify,
)
from pytensor.link.numba.dispatch.cython_support import wrap_cython_function
from pytensor.link.utils import (
compile_function_src,
get_name_for_object,
......@@ -41,86 +38,83 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
# TODO: Do we need to cache these functions so that we don't end up
# compiling the same Numba function over and over again?
scalar_func_name = op.nfunc_spec[0]
scalar_func = None
if scalar_func_name.startswith("scipy."):
func_package = scipy
scalar_func_name = scalar_func_name.split(".", 1)[-1]
use_numba_scipy = config.numba_scipy
if use_numba_scipy:
try:
import numba_scipy # noqa: F401
except ImportError:
use_numba_scipy = False
if not use_numba_scipy:
warnings.warn(
"Native numba versions of scipy functions might be "
"avalable if numba-scipy is installed.",
UserWarning,
scalar_func_path = op.nfunc_spec[0]
scalar_func_numba = None
*module_path, scalar_func_name = scalar_func_path.split(".")
if not module_path:
# Assume it is numpy, and numba has an implementation
scalar_func_numba = getattr(np, scalar_func_name)
input_dtypes = [np.dtype(input.type.dtype) for input in node.inputs]
output_dtypes = [np.dtype(output.type.dtype) for output in node.outputs]
if len(output_dtypes) != 1:
raise ValueError("ScalarOps with more than one output are not supported")
output_dtype = output_dtypes[0]
input_inner_dtypes = None
output_inner_dtype = None
# Cython functions might have an additional argument
has_pyx_skip_dispatch = False
if scalar_func_path.startswith("scipy.special"):
import scipy.special.cython_special
cython_func = getattr(scipy.special.cython_special, scalar_func_name, None)
if cython_func is not None:
# try:
scalar_func_numba = wrap_cython_function(
cython_func, output_dtype, input_dtypes
)
scalar_func = generate_fallback_impl(op, node, **kwargs)
else:
func_package = np
has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch
input_inner_dtypes = scalar_func_numba.numpy_arg_dtypes()
output_inner_dtype = scalar_func_numba.numpy_output_dtype()
# except NotImplementedError:
# pass
if scalar_func is not None:
pass
elif "." in scalar_func_name:
scalar_func = reduce(getattr, [scipy] + scalar_func_name.split("."))
else:
scalar_func = getattr(func_package, scalar_func_name)
if scalar_func_numba is None:
scalar_func_numba = generate_fallback_impl(op, node, **kwargs)
scalar_op_fn_name = get_name_for_object(scalar_func)
scalar_op_fn_name = get_name_for_object(scalar_func_numba)
unique_names = unique_name_generator(
[scalar_op_fn_name, "scalar_func"], suffix_sep="_"
[scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_"
)
global_env = {"scalar_func": scalar_func}
global_env = {"scalar_func_numba": scalar_func_numba}
input_tmp_dtypes = None
if func_package == scipy and hasattr(scalar_func, "types"):
# The `numba-scipy` bindings don't provide implementations for all
# inputs types, so we need to convert the inputs to floats and back.
inp_dtype_kinds = tuple(np.dtype(inp.type.dtype).kind for inp in node.inputs)
accepted_inp_kinds = tuple(
sig_type.split("->")[0] for sig_type in scalar_func.types
)
if not any(
all(dk == ik for dk, ik in zip(inp_dtype_kinds, ok_kinds))
for ok_kinds in accepted_inp_kinds
):
# They're usually ordered from lower-to-higher precision, so
# we pick the last acceptable input types
#
# XXX: We should pick the first acceptable float/int types in
# reverse, excluding all the incompatible ones (e.g. `"0"`).
# The assumption is that this is only used by `numba-scipy`-exposed
# functions, although it's possible for this to be triggered by
# something else from the `scipy` package
input_tmp_dtypes = tuple(np.dtype(k) for k in accepted_inp_kinds[-1])
if input_tmp_dtypes is None:
if input_inner_dtypes is None and output_inner_dtype is None:
unique_names = unique_name_generator(
[scalar_op_fn_name, "scalar_func"], suffix_sep="_"
[scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_"
)
input_names = ", ".join(
[unique_names(v, force_unique=True) for v in node.inputs]
)
scalar_op_src = f"""
if not has_pyx_skip_dispatch:
scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}):
return scalar_func_numba({input_names})
"""
else:
scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}):
return scalar_func({input_names})
"""
return scalar_func_numba({input_names}, np.intc(1))
"""
else:
global_env["direct_cast"] = numba_basic.direct_cast
global_env["output_dtype"] = np.dtype(node.outputs[0].type.dtype)
global_env["output_dtype"] = np.dtype(output_inner_dtype)
input_tmp_dtype_names = {
f"inp_tmp_dtype_{i}": i_dtype for i, i_dtype in enumerate(input_tmp_dtypes)
f"inp_tmp_dtype_{i}": i_dtype
for i, i_dtype in enumerate(input_inner_dtypes)
}
global_env.update(input_tmp_dtype_names)
unique_names = unique_name_generator(
[scalar_op_fn_name, "scalar_func"] + list(global_env.keys()), suffix_sep="_"
[scalar_op_fn_name, "scalar_func_numba"] + list(global_env.keys()),
suffix_sep="_",
)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
......@@ -132,10 +126,16 @@ def {scalar_op_fn_name}({input_names}):
)
]
)
scalar_op_src = f"""
if not has_pyx_skip_dispatch:
scalar_op_src = f"""
def {scalar_op_fn_name}({', '.join(input_names)}):
return direct_cast(scalar_func_numba({converted_call_args}), output_dtype)
"""
else:
scalar_op_src = f"""
def {scalar_op_fn_name}({', '.join(input_names)}):
return direct_cast(scalar_func({converted_call_args}), output_dtype)
"""
return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype)
"""
scalar_op_fn = compile_function_src(
scalar_op_src, scalar_op_fn_name, {**globals(), **global_env}
......
import numpy as np
import pytest
import scipy.special.cython_special
from numba.types import float32, float64, int32, int64
from aesara.link.numba.dispatch.cython_support import Signature, wrap_cython_function
@pytest.mark.parametrize(
    "sig, expected_result, expected_args",
    [
        (b"double(double)", np.float64, [np.float64]),
        (b"float(unsigned int)", np.float32, [np.uintc]),
        (b"unsigned char(unsigned short foo)", np.ubyte, [np.ushort]),
        (
            b"unsigned char(unsigned short foo, double bar)",
            np.ubyte,
            [np.ushort, np.float64],
        ),
    ],
)
def test_parse_signature(sig, expected_result, expected_args):
    """Parsing a C signature yields the expected result and argument dtypes."""
    parsed = Signature.from_c_types(sig)

    assert parsed.res_dtype == expected_result
    assert parsed.arg_dtypes == expected_args
@pytest.mark.parametrize(
    "have, want, should_provide",
    [
        (b"double(int)", b"float(int)", True),
        (b"float(int)", b"double(int)", False),
        (b"double(unsigned short)", b"double(unsigned char)", True),
        (b"double(unsigned char)", b"double(short)", False),
        (b"short(double)", b"int(double)", True),
        (b"int(double)", b"short(double)", False),
        (b"float(double, int)", b"float(double, short)", True),
    ],
)
def test_signature_provides(have, want, should_provide):
    """An available signature serves a wanted one only when the casts allow it."""
    available = Signature.from_c_types(have)
    requested = Signature.from_c_types(want)

    outcome = available.provides(requested.res_dtype, requested.arg_dtypes)

    assert outcome == should_provide
@pytest.mark.parametrize(
    "func, output, inputs, expected",
    [
        (
            scipy.special.cython_special.agm,
            np.float64,
            [np.float64, np.float64],
            float64(float64, float64, int32),
        ),
        (
            scipy.special.cython_special.erfc,
            np.float64,
            [np.float64],
            float64(float64, int32),
        ),
        (
            scipy.special.cython_special.expit,
            np.float32,
            [np.float32],
            float32(float32, int32),
        ),
        (
            scipy.special.cython_special.expit,
            np.float64,
            [np.float64],
            float64(float64, int32),
        ),
        (
            # expn doesn't have a float32 implementation
            scipy.special.cython_special.expn,
            np.float32,
            [np.float32, np.float32],
            float64(float64, float64, int32),
        ),
        (
            # We choose the integer implementation if possible
            scipy.special.cython_special.expn,
            np.float32,
            [np.int64, np.float32],
            float64(int64, float64, int32),
        ),
    ],
)
def test_choose_signature(func, output, inputs, expected):
    """The wrapper reports the numba signature of the selected implementation."""
    chosen = wrap_cython_function(func, output, inputs)

    assert chosen.signature() == expected
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论