提交 9124b72c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba fallback cython missing dtype

上级 a5fb9113
...@@ -69,15 +69,19 @@ def numba_funcify_ScalarOp(op, node, **kwargs): ...@@ -69,15 +69,19 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
cython_func = getattr(scipy.special.cython_special, scalar_func_name, None) cython_func = getattr(scipy.special.cython_special, scalar_func_name, None)
if cython_func is not None: if cython_func is not None:
scalar_func_numba = wrap_cython_function( try:
cython_func, output_dtype, input_dtypes scalar_func_numba = wrap_cython_function(
) cython_func, output_dtype, input_dtypes
has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch )
input_inner_dtypes = scalar_func_numba.numpy_arg_dtypes() except NotImplementedError:
output_inner_dtype = scalar_func_numba.numpy_output_dtype() pass
else:
has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch
input_inner_dtypes = scalar_func_numba.numpy_arg_dtypes()
output_inner_dtype = scalar_func_numba.numpy_output_dtype()
if scalar_func_numba is None: if scalar_func_numba is None:
scalar_func_numba = generate_fallback_impl(op, node, **kwargs) return generate_fallback_impl(op, node, **kwargs), None
scalar_op_fn_name = get_name_for_object(scalar_func_numba) scalar_op_fn_name = get_name_for_object(scalar_func_numba)
prefix = "x" if scalar_func_name != "x" else "y" prefix = "x" if scalar_func_name != "x" else "y"
......
import numpy as np import numpy as np
import pytest import pytest
import scipy
import pytensor.scalar as ps import pytensor.scalar as ps
import pytensor.scalar.basic as psb import pytensor.scalar.basic as psb
import pytensor.scalar.math as psm import pytensor.scalar.math as psm
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config, function from pytensor import config, function
from pytensor.graph import Apply
from pytensor.scalar import UnaryScalarOp
from pytensor.scalar.basic import Composite from pytensor.scalar.basic import Composite
from pytensor.tensor import tensor from pytensor.tensor import tensor
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
...@@ -184,3 +187,32 @@ def test_Softplus(dtype): ...@@ -184,3 +187,32 @@ def test_Softplus(dtype):
strict=True, strict=True,
err_msg=f"Failed for value {value}", err_msg=f"Failed for value {value}",
) )
def test_cython_obj_mode_fallback():
"""Test that unsupported cython signatures fallback to obj-mode"""
# Create a ScalarOp with a non-standard dtype
class IntegerGamma(UnaryScalarOp):
# We'll try to check for scipy cython impl
nfunc_spec = ("scipy.special.gamma", 1, 1)
def make_node(self, x):
x = psb.as_scalar(x)
assert x.dtype == "int64"
out = x.type()
return Apply(self, [x], [out])
def impl(self, x):
return scipy.special.gamma(x).astype("int64")
x = pt.scalar("x", dtype="int64")
g = Elemwise(IntegerGamma())(x)
assert g.type.dtype == "int64"
with pytest.warns(UserWarning, match="Numba will use object mode"):
compare_numba_and_py(
[x],
[g],
[np.array(5, dtype="int64")],
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论