提交 4ccf184f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Convert SciPy scalar function inputs to acceptable dtypes in Numba implementations

上级 033bf332
...@@ -52,14 +52,67 @@ def numba_funcify_ScalarOp(op, node, **kwargs): ...@@ -52,14 +52,67 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
[scalar_op_fn_name, "scalar_func"], suffix_sep="_" [scalar_op_fn_name, "scalar_func"], suffix_sep="_"
) )
input_names = ", ".join([unique_names(v, force_unique=True) for v in node.inputs])
global_env = {"scalar_func": scalar_func} global_env = {"scalar_func": scalar_func}
scalar_op_src = f""" input_tmp_dtypes = None
if func_package == scipy and hasattr(scalar_func, "types"):
# The `numba-scipy` bindings don't provide implementations for all
# inputs types, so we need to convert the inputs to floats and back.
inp_dtype_kinds = tuple(np.dtype(inp.type.dtype).kind for inp in node.inputs)
accepted_inp_kinds = tuple(
sig_type.split("->")[0] for sig_type in scalar_func.types
)
if not any(
all(dk == ik for dk, ik in zip(inp_dtype_kinds, ok_kinds))
for ok_kinds in accepted_inp_kinds
):
# They're usually ordered from lower-to-higher precision, so
# we pick the last acceptable input types
#
# XXX: We should pick the first acceptable float/int types in
# reverse, excluding all the incompatible ones (e.g. `"0"`).
# The assumption is that this is only used by `numba-scipy`-exposed
# functions, although it's possible for this to be triggered by
# something else from the `scipy` package
input_tmp_dtypes = tuple(np.dtype(k) for k in accepted_inp_kinds[-1])
if input_tmp_dtypes is None:
unique_names = unique_name_generator(
[scalar_op_fn_name, "scalar_func"], suffix_sep="_"
)
input_names = ", ".join(
[unique_names(v, force_unique=True) for v in node.inputs]
)
scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}): def {scalar_op_fn_name}({input_names}):
return scalar_func({input_names}) return scalar_func({input_names})
""" """
else:
global_env["direct_cast"] = numba_basic.direct_cast
global_env["output_dtype"] = np.dtype(node.outputs[0].type.dtype)
input_tmp_dtype_names = {
f"inp_tmp_dtype_{i}": i_dtype for i, i_dtype in enumerate(input_tmp_dtypes)
}
global_env.update(input_tmp_dtype_names)
unique_names = unique_name_generator(
[scalar_op_fn_name, "scalar_func"] + list(global_env.keys()), suffix_sep="_"
)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
converted_call_args = ", ".join(
[
f"direct_cast({i_name}, {i_tmp_dtype_name})"
for i_name, i_tmp_dtype_name in zip(
input_names, input_tmp_dtype_names.keys()
)
]
)
scalar_op_src = f"""
def {scalar_op_fn_name}({', '.join(input_names)}):
return direct_cast(scalar_func({converted_call_args}), output_dtype)
"""
scalar_op_fn = compile_function_src( scalar_op_fn = compile_function_src(
scalar_op_src, scalar_op_fn_name, {**globals(), **global_env} scalar_op_src, scalar_op_fn_name, {**globals(), **global_env}
) )
......
...@@ -321,6 +321,12 @@ def test_numba_box_unbox(input, wrapper_fn, check_fn): ...@@ -321,6 +321,12 @@ def test_numba_box_unbox(input, wrapper_fn, check_fn):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"inputs, input_vals, output_fn, exc", "inputs, input_vals, output_fn, exc",
[ [
(
[at.lvector()],
[rng.poisson(10, size=100).astype(np.int64)],
lambda x: at.gammaln(x),
None,
),
( (
[at.vector()], [at.vector()],
[rng.standard_normal(100).astype(config.floatX)], [rng.standard_normal(100).astype(config.floatX)],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论