提交 1c507090 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Use objmode in scipy.special without numba-scipy

上级 b2caa734
...@@ -1252,6 +1252,12 @@ def add_numba_configvars(): ...@@ -1252,6 +1252,12 @@ def add_numba_configvars():
BoolParam(True), BoolParam(True),
in_c_key=False, in_c_key=False,
) )
config.add(
"numba_scipy",
("Enable usage of the numba_scipy package for special functions",),
BoolParam(True),
in_c_key=False,
)
def _default_compiledirname(): def _default_compiledirname():
......
...@@ -323,9 +323,8 @@ def numba_typify(data, dtype=None, **kwargs): ...@@ -323,9 +323,8 @@ def numba_typify(data, dtype=None, **kwargs):
return data return data
@singledispatch def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
def numba_funcify(op, node=None, storage_map=None, **kwargs): """Create a Numba compatible function from an Aesara `Op`."""
"""Create a Numba compatible function from an PyTensor `Op`."""
warnings.warn( warnings.warn(
f"Numba will use object mode to run {op}'s perform method", f"Numba will use object mode to run {op}'s perform method",
...@@ -379,6 +378,12 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs): ...@@ -379,6 +378,12 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
return perform return perform
@singledispatch
def numba_funcify(op, node=None, storage_map=None, **kwargs):
    """Return a Numba-compatible implementation of `op`.

    Specializations are registered per-`Op` class via `singledispatch`;
    any `Op` without a registered handler is dispatched to
    `generate_fallback_impl` (which presumably runs the `Op`'s Python
    ``perform`` method in object mode — see the fallback's warning).
    """
    return generate_fallback_impl(op, node=node, storage_map=storage_map, **kwargs)
@numba_funcify.register(OpFromGraph) @numba_funcify.register(OpFromGraph)
def numba_funcify_OpFromGraph(op, node=None, **kwargs): def numba_funcify_OpFromGraph(op, node=None, **kwargs):
......
...@@ -27,6 +27,7 @@ from pytensor.scalar.basic import ( ...@@ -27,6 +27,7 @@ from pytensor.scalar.basic import (
OR, OR,
XOR, XOR,
Add, Add,
Composite,
IntDiv, IntDiv,
Mean, Mean,
Mul, Mul,
...@@ -40,6 +41,7 @@ from pytensor.scalar.basic import scalar_maximum ...@@ -40,6 +41,7 @@ from pytensor.scalar.basic import scalar_maximum
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from pytensor.tensor.type import scalar
@singledispatch @singledispatch
...@@ -424,8 +426,17 @@ def create_axis_apply_fn(fn, axis, ndim, dtype): ...@@ -424,8 +426,17 @@ def create_axis_apply_fn(fn, axis, ndim, dtype):
@numba_funcify.register(Elemwise) @numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs): def numba_funcify_Elemwise(op, node, **kwargs):
# Creating a new scalar node is more involved and unnecessary
scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs) # if the scalar_op is composite, as the fgraph already contains
# all the necessary information.
scalar_node = None
if not isinstance(op.scalar_op, Composite):
scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
scalar_node = op.scalar_op.make_node(*scalar_inputs)
scalar_op_fn = numba_funcify(
op.scalar_op, node=scalar_node, parent_node=node, inline="always", **kwargs
)
elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False) elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
elemwise_fn_name = elemwise_fn.__name__ elemwise_fn_name = elemwise_fn.__name__
......
import math import math
import warnings
from functools import reduce from functools import reduce
from typing import List from typing import List
...@@ -10,7 +11,11 @@ from pytensor import config ...@@ -10,7 +11,11 @@ from pytensor import config
from pytensor.compile.ops import ViewOp from pytensor.compile.ops import ViewOp
from pytensor.graph.basic import Variable from pytensor.graph.basic import Variable
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import create_numba_signature, numba_funcify from pytensor.link.numba.dispatch.basic import (
create_numba_signature,
generate_fallback_impl,
numba_funcify,
)
from pytensor.link.utils import ( from pytensor.link.utils import (
compile_function_src, compile_function_src,
get_name_for_object, get_name_for_object,
...@@ -37,14 +42,31 @@ def numba_funcify_ScalarOp(op, node, **kwargs): ...@@ -37,14 +42,31 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
# compiling the same Numba function over and over again? # compiling the same Numba function over and over again?
scalar_func_name = op.nfunc_spec[0] scalar_func_name = op.nfunc_spec[0]
scalar_func = None
if scalar_func_name.startswith("scipy."): if scalar_func_name.startswith("scipy."):
func_package = scipy func_package = scipy
scalar_func_name = scalar_func_name.split(".", 1)[-1] scalar_func_name = scalar_func_name.split(".", 1)[-1]
use_numba_scipy = config.numba_scipy
if use_numba_scipy:
try:
import numba_scipy # noqa: F401
except ImportError:
use_numba_scipy = False
if not use_numba_scipy:
warnings.warn(
"Native numba versions of scipy functions might be "
                "available if numba-scipy is installed.",
UserWarning,
)
scalar_func = generate_fallback_impl(op, node, **kwargs)
else: else:
func_package = np func_package = np
if "." in scalar_func_name: if scalar_func is not None:
pass
elif "." in scalar_func_name:
scalar_func = reduce(getattr, [scipy] + scalar_func_name.split(".")) scalar_func = reduce(getattr, [scipy] + scalar_func_name.split("."))
else: else:
scalar_func = getattr(func_package, scalar_func_name) scalar_func = getattr(func_package, scalar_func_name)
...@@ -220,7 +242,7 @@ def numba_funcify_Clip(op, **kwargs): ...@@ -220,7 +242,7 @@ def numba_funcify_Clip(op, **kwargs):
@numba_funcify.register(Composite) @numba_funcify.register(Composite)
def numba_funcify_Composite(op, node, **kwargs): def numba_funcify_Composite(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True) signature = create_numba_signature(op.fgraph, force_scalar=True)
_ = kwargs.pop("storage_map", None) _ = kwargs.pop("storage_map", None)
......
...@@ -57,6 +57,12 @@ rng = np.random.default_rng(42849) ...@@ -57,6 +57,12 @@ rng = np.random.default_rng(42849)
lambda x: at.erfc(x), lambda x: at.erfc(x),
None, None,
), ),
(
[at.vector()],
[rng.standard_normal(100).astype(config.floatX)],
lambda x: at.erfcx(x),
None,
),
( (
[at.vector() for i in range(4)], [at.vector() for i in range(4)],
[rng.standard_normal(100).astype(config.floatX) for i in range(4)], [rng.standard_normal(100).astype(config.floatX) for i in range(4)],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论