提交 1c507090 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Use objmode in scipy.special without numba-scipy

上级 b2caa734
......@@ -1252,6 +1252,12 @@ def add_numba_configvars():
BoolParam(True),
in_c_key=False,
)
config.add(
"numba_scipy",
("Enable usage of the numba_scipy package for special functions",),
BoolParam(True),
in_c_key=False,
)
def _default_compiledirname():
......
......@@ -323,9 +323,8 @@ def numba_typify(data, dtype=None, **kwargs):
return data
@singledispatch
def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an PyTensor `Op`."""
def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`."""
warnings.warn(
f"Numba will use object mode to run {op}'s perform method",
......@@ -379,6 +378,12 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
return perform
@singledispatch
def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Generate a numba function for a given op and apply node."""
return generate_fallback_impl(op, node, storage_map, **kwargs)
@numba_funcify.register(OpFromGraph)
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
......
......@@ -27,6 +27,7 @@ from pytensor.scalar.basic import (
OR,
XOR,
Add,
Composite,
IntDiv,
Mean,
Mul,
......@@ -40,6 +41,7 @@ from pytensor.scalar.basic import scalar_maximum
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from pytensor.tensor.type import scalar
@singledispatch
......@@ -424,8 +426,17 @@ def create_axis_apply_fn(fn, axis, ndim, dtype):
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):
scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)
# Creating a new scalar node is more involved and unnecessary
# if the scalar_op is composite, as the fgraph already contains
# all the necessary information.
scalar_node = None
if not isinstance(op.scalar_op, Composite):
scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
scalar_node = op.scalar_op.make_node(*scalar_inputs)
scalar_op_fn = numba_funcify(
op.scalar_op, node=scalar_node, parent_node=node, inline="always", **kwargs
)
elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
elemwise_fn_name = elemwise_fn.__name__
......
import math
import warnings
from functools import reduce
from typing import List
......@@ -10,7 +11,11 @@ from pytensor import config
from pytensor.compile.ops import ViewOp
from pytensor.graph.basic import Variable
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import create_numba_signature, numba_funcify
from pytensor.link.numba.dispatch.basic import (
create_numba_signature,
generate_fallback_impl,
numba_funcify,
)
from pytensor.link.utils import (
compile_function_src,
get_name_for_object,
......@@ -37,14 +42,31 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
# compiling the same Numba function over and over again?
scalar_func_name = op.nfunc_spec[0]
scalar_func = None
if scalar_func_name.startswith("scipy."):
func_package = scipy
scalar_func_name = scalar_func_name.split(".", 1)[-1]
use_numba_scipy = config.numba_scipy
if use_numba_scipy:
try:
import numba_scipy # noqa: F401
except ImportError:
use_numba_scipy = False
if not use_numba_scipy:
warnings.warn(
"Native numba versions of scipy functions might be "
"avalable if numba-scipy is installed.",
UserWarning,
)
scalar_func = generate_fallback_impl(op, node, **kwargs)
else:
func_package = np
if "." in scalar_func_name:
if scalar_func is not None:
pass
elif "." in scalar_func_name:
scalar_func = reduce(getattr, [scipy] + scalar_func_name.split("."))
else:
scalar_func = getattr(func_package, scalar_func_name)
......@@ -220,7 +242,7 @@ def numba_funcify_Clip(op, **kwargs):
@numba_funcify.register(Composite)
def numba_funcify_Composite(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
signature = create_numba_signature(op.fgraph, force_scalar=True)
_ = kwargs.pop("storage_map", None)
......
......@@ -57,6 +57,12 @@ rng = np.random.default_rng(42849)
lambda x: at.erfc(x),
None,
),
(
[at.vector()],
[rng.standard_normal(100).astype(config.floatX)],
lambda x: at.erfcx(x),
None,
),
(
[at.vector() for i in range(4)],
[rng.standard_normal(100).astype(config.floatX) for i in range(4)],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论