提交 5fbf81df authored 作者: Ricardo Vieira 提交者: Ricardo Vieira

Systematic use of mockable numba_basic.numba_jit

Direct import is not properly mocked by tests when trying to run `compare_numba_and_py` with `eval_obj_mode=True`
上级 d39ad599
......@@ -4,7 +4,8 @@ from typing import cast
from numba.core.extending import overload
from numba.np.unsafe.ndarray import to_fixed_tuple
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options,
_vectorized,
......@@ -56,7 +57,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
src += f"to_fixed_tuple(core_shapes[{i}], {core_shapes_len[i]}),"
src += ")"
to_tuple = numba_njit(
to_tuple = numba_basic.numba_njit(
compile_function_src(
src,
"to_tuple",
......
......@@ -359,13 +359,13 @@ def numba_funcify_Sum(op, node, **kwargs):
if ndim_input == len(axes):
# Slightly faster than `numba_funcify_CAReduce` for this case
@numba_njit
@numba_basic.numba_njit
def impl_sum(array):
return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
elif len(axes) == 0:
# These cases should be removed by rewrites!
@numba_njit
@numba_basic.numba_njit
def impl_sum(array):
return np.asarray(array, dtype=out_dtype)
......@@ -615,25 +615,25 @@ def numba_funcify_Dot(op, node, **kwargs):
if x_dtype == dot_dtype and y_dtype == dot_dtype:
@numba_njit
@numba_basic.numba_njit
def dot(x, y):
return np.asarray(np.dot(x, y))
elif x_dtype == dot_dtype and y_dtype != dot_dtype:
@numba_njit
@numba_basic.numba_njit
def dot(x, y):
return np.asarray(np.dot(x, y.astype(dot_dtype)))
elif x_dtype != dot_dtype and y_dtype == dot_dtype:
@numba_njit
@numba_basic.numba_njit
def dot(x, y):
return np.asarray(np.dot(x.astype(dot_dtype), y))
else:
@numba_njit()
@numba_basic.numba_njit
def dot(x, y):
return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype)))
......@@ -642,7 +642,7 @@ def numba_funcify_Dot(op, node, **kwargs):
else:
@numba_njit
@numba_basic.numba_njit
def dot_with_cast(x, y):
return dot(x, y).astype(out_dtype)
......@@ -653,7 +653,7 @@ def numba_funcify_Dot(op, node, **kwargs):
def numba_funcify_BatchedDot(op, node, **kwargs):
dtype = node.outputs[0].type.numpy_dtype
@numba_njit
@numba_basic.numba_njit
def batched_dot(x, y):
# Numba does not support 3D matmul
# https://github.com/numba/numba/issues/3804
......
......@@ -2,16 +2,16 @@ from collections.abc import Callable
from typing import Literal
import numpy as np
from numba import njit as numba_njit
from numba.core.extending import overload
from numba.np.linalg import ensure_lapack
from scipy import linalg
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
@numba_njit
@numba_basic.numba_njit
def _pivot_to_permutation(p, dtype):
p_inv = np.arange(len(p)).astype(dtype)
for i in range(len(p)):
......@@ -19,7 +19,7 @@ def _pivot_to_permutation(p, dtype):
return p_inv
@numba_njit
@numba_basic.numba_njit
def _lu_factor_to_lu(a, dtype, overwrite_a):
A_copy, IPIV, _INFO = _getrf(a, overwrite_a=overwrite_a)
......
......@@ -6,8 +6,8 @@ from numba.np.linalg import ensure_lapack
from numpy import ndarray
from scipy import linalg
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
......@@ -27,7 +27,7 @@ from pytensor.tensor._linalg.solve.tridiagonal import (
)
@numba_njit
@numba_basic.numba_njit
def tridiagonal_norm(du, d, dl):
# Adapted from scipy _matrix_norm_tridiagonal:
# https://github.com/scipy/scipy/blob/0f1fd4a7268b813fa2b844ca6038e4dfdf90084a/scipy/linalg/_basic.py#L356-L367
......@@ -346,7 +346,7 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
overwrite_d = op.overwrite_d
overwrite_du = op.overwrite_du
@numba_njit(cache=False)
@numba_basic.numba_njit(cache=False)
def lu_factor_tridiagonal(dl, d, du):
dl, d, du, du2, ipiv, _ = _gttrf(
dl,
......@@ -368,7 +368,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
overwrite_b = op.overwrite_b
transposed = op.transposed
@numba_njit(cache=False)
@numba_basic.numba_njit(cache=False)
def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
x, _ = _gttrs(
dl,
......
......@@ -30,14 +30,14 @@ def numba_funcify_SVD(op, node, **kwargs):
if not compute_uv:
@numba_basic.numba_njit()
@numba_basic.numba_njit
def svd(x):
_, ret, _ = np.linalg.svd(inputs_cast(x), full_matrices)
return ret
else:
@numba_basic.numba_njit()
@numba_basic.numba_njit
def svd(x):
return np.linalg.svd(inputs_cast(x), full_matrices)
......
......@@ -91,7 +91,7 @@ def numba_core_rv_default(op, node):
def numba_core_BernoulliRV(op, node):
out_dtype = node.outputs[1].type.numpy_dtype
@numba_basic.numba_njit()
@numba_basic.numba_njit
def random(rng, p):
return (
direct_cast(0, out_dtype)
......
......@@ -3,6 +3,7 @@ from textwrap import dedent
import numpy as np
from numba.np.unsafe import ndarray as numba_ndarray
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import create_arg_string, numba_njit
from pytensor.link.utils import compile_function_src
......@@ -12,7 +13,7 @@ from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@numba_funcify.register(Shape)
def numba_funcify_Shape(op, **kwargs):
@numba_njit
@numba_basic.numba_njit
def shape(x):
return np.asarray(np.shape(x))
......@@ -23,7 +24,7 @@ def numba_funcify_Shape(op, **kwargs):
def numba_funcify_Shape_i(op, **kwargs):
i = op.i
@numba_njit
@numba_basic.numba_njit
def shape_i(x):
return np.asarray(np.shape(x)[i])
......@@ -61,13 +62,13 @@ def numba_funcify_Reshape(op, **kwargs):
if ndim == 0:
@numba_njit
@numba_basic.numba_njit
def reshape(x, shape):
return np.asarray(x.item())
else:
@numba_njit
@numba_basic.numba_njit
def reshape(x, shape):
# TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
return np.reshape(
......
import numpy as np
from numba.np.arraymath import _get_inner_prod
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.tensor.signal.conv import Convolve1d
......@@ -13,7 +13,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
out_dtype = node.outputs[0].type.dtype
innerprod = _get_inner_prod(a_dtype, b_dtype)
@numba_njit
@numba_basic.numba_njit
def valid_convolve1d(x, y):
nx = len(x)
ny = len(y)
......@@ -30,7 +30,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
return ret
@numba_njit
@numba_basic.numba_njit
def full_convolve1d(x, y):
nx = len(x)
ny = len(y)
......@@ -59,7 +59,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
return ret
@numba_njit
@numba_basic.numba_njit
def convolve_1d(x, y, mode):
if mode:
return full_convolve1d(x, y)
......
......@@ -3,7 +3,8 @@ import warnings
import numpy as np
from pytensor import config
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky
from pytensor.link.numba.dispatch.linalg.decomposition.lu import (
_lu_1,
......@@ -63,7 +64,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
if dtype in complex_dtypes:
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
@numba_njit
@numba_basic.numba_njit
def cholesky(a):
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
......@@ -95,7 +96,7 @@ def pivot_to_permutation(op, node, **kwargs):
inverse = op.inverse
dtype = node.outputs[0].dtype
@numba_njit
@numba_basic.numba_njit
def numba_pivot_to_permutation(piv):
p_inv = _pivot_to_permutation(piv, dtype)
......@@ -118,7 +119,7 @@ def numba_funcify_LU(op, node, **kwargs):
if dtype in complex_dtypes:
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
@numba_njit(inline="always")
@numba_basic.numba_njit(inline="always")
def lu(a):
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
......@@ -165,7 +166,7 @@ def numba_funcify_LUFactor(op, node, **kwargs):
if dtype in complex_dtypes:
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
@numba_njit
@numba_basic.numba_njit
def lu_factor(a):
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
......@@ -185,7 +186,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
dtype = node.outputs[0].dtype
# TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
@numba_njit
@numba_basic.numba_njit
def block_diag(*arrs):
shapes = np.array([a.shape for a in arrs], dtype="int")
out_shape = [int(s) for s in np.sum(shapes, axis=0)]
......@@ -235,7 +236,7 @@ def numba_funcify_Solve(op, node, **kwargs):
)
solve_fn = _solve_gen
@numba_njit
@numba_basic.numba_njit
def solve(a, b):
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
......@@ -267,7 +268,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular")
)
@numba_njit
@numba_basic.numba_njit
def solve_triangular(a, b):
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
......@@ -304,7 +305,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
if dtype in complex_dtypes:
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
@numba_njit
@numba_basic.numba_njit
def cho_solve(c, b):
if check_finite:
if np.any(np.bitwise_or(np.isinf(c), np.isnan(c))):
......@@ -337,7 +338,7 @@ def numba_funcify_QR(op, node, **kwargs):
integer_input = dtype in integer_dtypes
in_dtype = config.floatX if integer_input else dtype
@numba_njit(cache=False)
@numba_basic.numba_njit(cache=False)
def qr(a):
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
......
......@@ -2,8 +2,8 @@ import warnings
import numpy as np
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.tensor.sort import ArgSortOp, SortOp
......@@ -18,7 +18,7 @@ def numba_funcify_SortOp(op, node, **kwargs):
UserWarning,
)
@numba_njit
@numba_basic.numba_njit
def sort_f(a, axis):
axis = axis.item()
......@@ -45,7 +45,7 @@ def numba_funcify_ArgSortOp(op, node, **kwargs):
UserWarning,
)
@numba_njit
@numba_basic.numba_njit
def argort_f(X, axis):
axis = axis.item()
......
......@@ -8,6 +8,7 @@ from numba import types
from numba.core.pythonapi import box
from pytensor.graph import Type
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
from pytensor.link.utils import compile_function_src, unique_name_generator
......@@ -99,7 +100,7 @@ enable_slice_boxing()
@numba_funcify.register(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs):
@numba_njit
@numba_basic.numba_njit
def makeslice(*x):
return slice(*x)
......@@ -297,7 +298,7 @@ def numba_funcify_multiple_integer_vector_indexing(
if isinstance(op, AdvancedSubtensor):
@numba_njit
@numba_basic.numba_njit
def advanced_subtensor_multiple_vector(x, *idxs):
none_slices = idxs[:first_axis]
vec_idxs = idxs[first_axis:after_last_axis]
......@@ -328,7 +329,7 @@ def numba_funcify_multiple_integer_vector_indexing(
if op.set_instead_of_inc:
@numba_njit
@numba_basic.numba_njit
def advanced_set_subtensor_multiple_vector(x, y, *idxs):
vec_idxs = idxs[first_axis:after_last_axis]
x_shape = x.shape
......@@ -350,7 +351,7 @@ def numba_funcify_multiple_integer_vector_indexing(
else:
@numba_njit
@numba_basic.numba_njit
def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
vec_idxs = idxs[first_axis:after_last_axis]
x_shape = x.shape
......@@ -382,7 +383,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
if set_instead_of_inc:
if broadcast_with_index:
@numba_njit(boundscheck=True)
@numba_basic.numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
if val.ndim == x.ndim:
core_val = val[0]
......@@ -398,7 +399,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
else:
@numba_njit(boundscheck=True)
@numba_basic.numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs):
if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.")
......@@ -409,7 +410,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
else:
if broadcast_with_index:
@numba_njit(boundscheck=True)
@numba_basic.numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
if val.ndim == x.ndim:
core_val = val[0]
......@@ -425,7 +426,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
else:
@numba_njit(boundscheck=True)
@numba_basic.numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs):
if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.")
......@@ -440,7 +441,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
else:
@numba_njit
@numba_basic.numba_njit
def advancedincsubtensor1(x, vals, idxs):
x = x.copy()
return advancedincsubtensor1_inplace(x, vals, idxs)
......
......@@ -6,7 +6,6 @@ from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
create_tuple_string,
numba_funcify,
numba_njit,
)
from pytensor.link.utils import compile_function_src, unique_name_generator
from pytensor.tensor.basic import (
......@@ -243,7 +242,7 @@ def numba_funcify_ScalarFromTensor(op, **kwargs):
@numba_funcify.register(Nonzero)
def numba_funcify_Nonzero(op, node, **kwargs):
@numba_njit
@numba_basic.numba_njit
def nonzero(a):
result_tuple = np.nonzero(a)
if a.ndim == 1:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论