提交 5fbf81df authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Systematic use of mockable numba_basic.numba_jit

Direct import is not properly mocked by tests when trying to run `compare_numba_and_py` with `eval_obj_mode=True`
上级 d39ad599
...@@ -4,7 +4,8 @@ from typing import cast ...@@ -4,7 +4,8 @@ from typing import cast
from numba.core.extending import overload from numba.core.extending import overload
from numba.np.unsafe.ndarray import to_fixed_tuple from numba.np.unsafe.ndarray import to_fixed_tuple
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.link.numba.dispatch.vectorize_codegen import ( from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options, _jit_options,
_vectorized, _vectorized,
...@@ -56,7 +57,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs): ...@@ -56,7 +57,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
src += f"to_fixed_tuple(core_shapes[{i}], {core_shapes_len[i]})," src += f"to_fixed_tuple(core_shapes[{i}], {core_shapes_len[i]}),"
src += ")" src += ")"
to_tuple = numba_njit( to_tuple = numba_basic.numba_njit(
compile_function_src( compile_function_src(
src, src,
"to_tuple", "to_tuple",
......
...@@ -359,13 +359,13 @@ def numba_funcify_Sum(op, node, **kwargs): ...@@ -359,13 +359,13 @@ def numba_funcify_Sum(op, node, **kwargs):
if ndim_input == len(axes): if ndim_input == len(axes):
# Slightly faster than `numba_funcify_CAReduce` for this case # Slightly faster than `numba_funcify_CAReduce` for this case
@numba_njit @numba_basic.numba_njit
def impl_sum(array): def impl_sum(array):
return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype) return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
elif len(axes) == 0: elif len(axes) == 0:
# These cases should be removed by rewrites! # These cases should be removed by rewrites!
@numba_njit @numba_basic.numba_njit
def impl_sum(array): def impl_sum(array):
return np.asarray(array, dtype=out_dtype) return np.asarray(array, dtype=out_dtype)
...@@ -615,25 +615,25 @@ def numba_funcify_Dot(op, node, **kwargs): ...@@ -615,25 +615,25 @@ def numba_funcify_Dot(op, node, **kwargs):
if x_dtype == dot_dtype and y_dtype == dot_dtype: if x_dtype == dot_dtype and y_dtype == dot_dtype:
@numba_njit @numba_basic.numba_njit
def dot(x, y): def dot(x, y):
return np.asarray(np.dot(x, y)) return np.asarray(np.dot(x, y))
elif x_dtype == dot_dtype and y_dtype != dot_dtype: elif x_dtype == dot_dtype and y_dtype != dot_dtype:
@numba_njit @numba_basic.numba_njit
def dot(x, y): def dot(x, y):
return np.asarray(np.dot(x, y.astype(dot_dtype))) return np.asarray(np.dot(x, y.astype(dot_dtype)))
elif x_dtype != dot_dtype and y_dtype == dot_dtype: elif x_dtype != dot_dtype and y_dtype == dot_dtype:
@numba_njit @numba_basic.numba_njit
def dot(x, y): def dot(x, y):
return np.asarray(np.dot(x.astype(dot_dtype), y)) return np.asarray(np.dot(x.astype(dot_dtype), y))
else: else:
@numba_njit() @numba_basic.numba_njit
def dot(x, y): def dot(x, y):
return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype))) return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype)))
...@@ -642,7 +642,7 @@ def numba_funcify_Dot(op, node, **kwargs): ...@@ -642,7 +642,7 @@ def numba_funcify_Dot(op, node, **kwargs):
else: else:
@numba_njit @numba_basic.numba_njit
def dot_with_cast(x, y): def dot_with_cast(x, y):
return dot(x, y).astype(out_dtype) return dot(x, y).astype(out_dtype)
...@@ -653,7 +653,7 @@ def numba_funcify_Dot(op, node, **kwargs): ...@@ -653,7 +653,7 @@ def numba_funcify_Dot(op, node, **kwargs):
def numba_funcify_BatchedDot(op, node, **kwargs): def numba_funcify_BatchedDot(op, node, **kwargs):
dtype = node.outputs[0].type.numpy_dtype dtype = node.outputs[0].type.numpy_dtype
@numba_njit @numba_basic.numba_njit
def batched_dot(x, y): def batched_dot(x, y):
# Numba does not support 3D matmul # Numba does not support 3D matmul
# https://github.com/numba/numba/issues/3804 # https://github.com/numba/numba/issues/3804
......
...@@ -2,16 +2,16 @@ from collections.abc import Callable ...@@ -2,16 +2,16 @@ from collections.abc import Callable
from typing import Literal from typing import Literal
import numpy as np import numpy as np
from numba import njit as numba_njit
from numba.core.extending import overload from numba.core.extending import overload
from numba.np.linalg import ensure_lapack from numba.np.linalg import ensure_lapack
from scipy import linalg from scipy import linalg
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
@numba_njit @numba_basic.numba_njit
def _pivot_to_permutation(p, dtype): def _pivot_to_permutation(p, dtype):
p_inv = np.arange(len(p)).astype(dtype) p_inv = np.arange(len(p)).astype(dtype)
for i in range(len(p)): for i in range(len(p)):
...@@ -19,7 +19,7 @@ def _pivot_to_permutation(p, dtype): ...@@ -19,7 +19,7 @@ def _pivot_to_permutation(p, dtype):
return p_inv return p_inv
@numba_njit @numba_basic.numba_njit
def _lu_factor_to_lu(a, dtype, overwrite_a): def _lu_factor_to_lu(a, dtype, overwrite_a):
A_copy, IPIV, _INFO = _getrf(a, overwrite_a=overwrite_a) A_copy, IPIV, _INFO = _getrf(a, overwrite_a=overwrite_a)
......
...@@ -6,8 +6,8 @@ from numba.np.linalg import ensure_lapack ...@@ -6,8 +6,8 @@ from numba.np.linalg import ensure_lapack
from numpy import ndarray from numpy import ndarray
from scipy import linalg from scipy import linalg
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.link.numba.dispatch.linalg._LAPACK import ( from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK, _LAPACK,
_get_underlying_float, _get_underlying_float,
...@@ -27,7 +27,7 @@ from pytensor.tensor._linalg.solve.tridiagonal import ( ...@@ -27,7 +27,7 @@ from pytensor.tensor._linalg.solve.tridiagonal import (
) )
@numba_njit @numba_basic.numba_njit
def tridiagonal_norm(du, d, dl): def tridiagonal_norm(du, d, dl):
# Adapted from scipy _matrix_norm_tridiagonal: # Adapted from scipy _matrix_norm_tridiagonal:
# https://github.com/scipy/scipy/blob/0f1fd4a7268b813fa2b844ca6038e4dfdf90084a/scipy/linalg/_basic.py#L356-L367 # https://github.com/scipy/scipy/blob/0f1fd4a7268b813fa2b844ca6038e4dfdf90084a/scipy/linalg/_basic.py#L356-L367
...@@ -346,7 +346,7 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs): ...@@ -346,7 +346,7 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
overwrite_d = op.overwrite_d overwrite_d = op.overwrite_d
overwrite_du = op.overwrite_du overwrite_du = op.overwrite_du
@numba_njit(cache=False) @numba_basic.numba_njit(cache=False)
def lu_factor_tridiagonal(dl, d, du): def lu_factor_tridiagonal(dl, d, du):
dl, d, du, du2, ipiv, _ = _gttrf( dl, d, du, du2, ipiv, _ = _gttrf(
dl, dl,
...@@ -368,7 +368,7 @@ def numba_funcify_SolveLUFactorTridiagonal( ...@@ -368,7 +368,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
overwrite_b = op.overwrite_b overwrite_b = op.overwrite_b
transposed = op.transposed transposed = op.transposed
@numba_njit(cache=False) @numba_basic.numba_njit(cache=False)
def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b): def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
x, _ = _gttrs( x, _ = _gttrs(
dl, dl,
......
...@@ -30,14 +30,14 @@ def numba_funcify_SVD(op, node, **kwargs): ...@@ -30,14 +30,14 @@ def numba_funcify_SVD(op, node, **kwargs):
if not compute_uv: if not compute_uv:
@numba_basic.numba_njit() @numba_basic.numba_njit
def svd(x): def svd(x):
_, ret, _ = np.linalg.svd(inputs_cast(x), full_matrices) _, ret, _ = np.linalg.svd(inputs_cast(x), full_matrices)
return ret return ret
else: else:
@numba_basic.numba_njit() @numba_basic.numba_njit
def svd(x): def svd(x):
return np.linalg.svd(inputs_cast(x), full_matrices) return np.linalg.svd(inputs_cast(x), full_matrices)
......
...@@ -91,7 +91,7 @@ def numba_core_rv_default(op, node): ...@@ -91,7 +91,7 @@ def numba_core_rv_default(op, node):
def numba_core_BernoulliRV(op, node): def numba_core_BernoulliRV(op, node):
out_dtype = node.outputs[1].type.numpy_dtype out_dtype = node.outputs[1].type.numpy_dtype
@numba_basic.numba_njit() @numba_basic.numba_njit
def random(rng, p): def random(rng, p):
return ( return (
direct_cast(0, out_dtype) direct_cast(0, out_dtype)
......
...@@ -3,6 +3,7 @@ from textwrap import dedent ...@@ -3,6 +3,7 @@ from textwrap import dedent
import numpy as np import numpy as np
from numba.np.unsafe import ndarray as numba_ndarray from numba.np.unsafe import ndarray as numba_ndarray
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import create_arg_string, numba_njit from pytensor.link.numba.dispatch.basic import create_arg_string, numba_njit
from pytensor.link.utils import compile_function_src from pytensor.link.utils import compile_function_src
...@@ -12,7 +13,7 @@ from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape ...@@ -12,7 +13,7 @@ from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@numba_funcify.register(Shape) @numba_funcify.register(Shape)
def numba_funcify_Shape(op, **kwargs): def numba_funcify_Shape(op, **kwargs):
@numba_njit @numba_basic.numba_njit
def shape(x): def shape(x):
return np.asarray(np.shape(x)) return np.asarray(np.shape(x))
...@@ -23,7 +24,7 @@ def numba_funcify_Shape(op, **kwargs): ...@@ -23,7 +24,7 @@ def numba_funcify_Shape(op, **kwargs):
def numba_funcify_Shape_i(op, **kwargs): def numba_funcify_Shape_i(op, **kwargs):
i = op.i i = op.i
@numba_njit @numba_basic.numba_njit
def shape_i(x): def shape_i(x):
return np.asarray(np.shape(x)[i]) return np.asarray(np.shape(x)[i])
...@@ -61,13 +62,13 @@ def numba_funcify_Reshape(op, **kwargs): ...@@ -61,13 +62,13 @@ def numba_funcify_Reshape(op, **kwargs):
if ndim == 0: if ndim == 0:
@numba_njit @numba_basic.numba_njit
def reshape(x, shape): def reshape(x, shape):
return np.asarray(x.item()) return np.asarray(x.item())
else: else:
@numba_njit @numba_basic.numba_njit
def reshape(x, shape): def reshape(x, shape):
# TODO: Use this until https://github.com/numba/numba/issues/7353 is closed. # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
return np.reshape( return np.reshape(
......
import numpy as np import numpy as np
from numba.np.arraymath import _get_inner_prod from numba.np.arraymath import _get_inner_prod
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.tensor.signal.conv import Convolve1d from pytensor.tensor.signal.conv import Convolve1d
...@@ -13,7 +13,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs): ...@@ -13,7 +13,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
out_dtype = node.outputs[0].type.dtype out_dtype = node.outputs[0].type.dtype
innerprod = _get_inner_prod(a_dtype, b_dtype) innerprod = _get_inner_prod(a_dtype, b_dtype)
@numba_njit @numba_basic.numba_njit
def valid_convolve1d(x, y): def valid_convolve1d(x, y):
nx = len(x) nx = len(x)
ny = len(y) ny = len(y)
...@@ -30,7 +30,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs): ...@@ -30,7 +30,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
return ret return ret
@numba_njit @numba_basic.numba_njit
def full_convolve1d(x, y): def full_convolve1d(x, y):
nx = len(x) nx = len(x)
ny = len(y) ny = len(y)
...@@ -59,7 +59,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs): ...@@ -59,7 +59,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
return ret return ret
@numba_njit @numba_basic.numba_njit
def convolve_1d(x, y, mode): def convolve_1d(x, y, mode):
if mode: if mode:
return full_convolve1d(x, y) return full_convolve1d(x, y)
......
...@@ -3,7 +3,8 @@ import warnings ...@@ -3,7 +3,8 @@ import warnings
import numpy as np import numpy as np
from pytensor import config from pytensor import config
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky
from pytensor.link.numba.dispatch.linalg.decomposition.lu import ( from pytensor.link.numba.dispatch.linalg.decomposition.lu import (
_lu_1, _lu_1,
...@@ -63,7 +64,7 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -63,7 +64,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
if dtype in complex_dtypes: if dtype in complex_dtypes:
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
@numba_njit @numba_basic.numba_njit
def cholesky(a): def cholesky(a):
if check_finite: if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
...@@ -95,7 +96,7 @@ def pivot_to_permutation(op, node, **kwargs): ...@@ -95,7 +96,7 @@ def pivot_to_permutation(op, node, **kwargs):
inverse = op.inverse inverse = op.inverse
dtype = node.outputs[0].dtype dtype = node.outputs[0].dtype
@numba_njit @numba_basic.numba_njit
def numba_pivot_to_permutation(piv): def numba_pivot_to_permutation(piv):
p_inv = _pivot_to_permutation(piv, dtype) p_inv = _pivot_to_permutation(piv, dtype)
...@@ -118,7 +119,7 @@ def numba_funcify_LU(op, node, **kwargs): ...@@ -118,7 +119,7 @@ def numba_funcify_LU(op, node, **kwargs):
if dtype in complex_dtypes: if dtype in complex_dtypes:
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
@numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
def lu(a): def lu(a):
if check_finite: if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
...@@ -165,7 +166,7 @@ def numba_funcify_LUFactor(op, node, **kwargs): ...@@ -165,7 +166,7 @@ def numba_funcify_LUFactor(op, node, **kwargs):
if dtype in complex_dtypes: if dtype in complex_dtypes:
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
@numba_njit @numba_basic.numba_njit
def lu_factor(a): def lu_factor(a):
if check_finite: if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
...@@ -185,7 +186,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs): ...@@ -185,7 +186,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
dtype = node.outputs[0].dtype dtype = node.outputs[0].dtype
# TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case. # TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
@numba_njit @numba_basic.numba_njit
def block_diag(*arrs): def block_diag(*arrs):
shapes = np.array([a.shape for a in arrs], dtype="int") shapes = np.array([a.shape for a in arrs], dtype="int")
out_shape = [int(s) for s in np.sum(shapes, axis=0)] out_shape = [int(s) for s in np.sum(shapes, axis=0)]
...@@ -235,7 +236,7 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -235,7 +236,7 @@ def numba_funcify_Solve(op, node, **kwargs):
) )
solve_fn = _solve_gen solve_fn = _solve_gen
@numba_njit @numba_basic.numba_njit
def solve(a, b): def solve(a, b):
if check_finite: if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
...@@ -267,7 +268,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): ...@@ -267,7 +268,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular") _COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular")
) )
@numba_njit @numba_basic.numba_njit
def solve_triangular(a, b): def solve_triangular(a, b):
if check_finite: if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
...@@ -304,7 +305,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs): ...@@ -304,7 +305,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
if dtype in complex_dtypes: if dtype in complex_dtypes:
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
@numba_njit @numba_basic.numba_njit
def cho_solve(c, b): def cho_solve(c, b):
if check_finite: if check_finite:
if np.any(np.bitwise_or(np.isinf(c), np.isnan(c))): if np.any(np.bitwise_or(np.isinf(c), np.isnan(c))):
...@@ -337,7 +338,7 @@ def numba_funcify_QR(op, node, **kwargs): ...@@ -337,7 +338,7 @@ def numba_funcify_QR(op, node, **kwargs):
integer_input = dtype in integer_dtypes integer_input = dtype in integer_dtypes
in_dtype = config.floatX if integer_input else dtype in_dtype = config.floatX if integer_input else dtype
@numba_njit(cache=False) @numba_basic.numba_njit(cache=False)
def qr(a): def qr(a):
if check_finite: if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
......
...@@ -2,8 +2,8 @@ import warnings ...@@ -2,8 +2,8 @@ import warnings
import numpy as np import numpy as np
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.tensor.sort import ArgSortOp, SortOp from pytensor.tensor.sort import ArgSortOp, SortOp
...@@ -18,7 +18,7 @@ def numba_funcify_SortOp(op, node, **kwargs): ...@@ -18,7 +18,7 @@ def numba_funcify_SortOp(op, node, **kwargs):
UserWarning, UserWarning,
) )
@numba_njit @numba_basic.numba_njit
def sort_f(a, axis): def sort_f(a, axis):
axis = axis.item() axis = axis.item()
...@@ -45,7 +45,7 @@ def numba_funcify_ArgSortOp(op, node, **kwargs): ...@@ -45,7 +45,7 @@ def numba_funcify_ArgSortOp(op, node, **kwargs):
UserWarning, UserWarning,
) )
@numba_njit @numba_basic.numba_njit
def argort_f(X, axis): def argort_f(X, axis):
axis = axis.item() axis = axis.item()
......
...@@ -8,6 +8,7 @@ from numba import types ...@@ -8,6 +8,7 @@ from numba import types
from numba.core.pythonapi import box from numba.core.pythonapi import box
from pytensor.graph import Type from pytensor.graph import Type
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
from pytensor.link.utils import compile_function_src, unique_name_generator from pytensor.link.utils import compile_function_src, unique_name_generator
...@@ -99,7 +100,7 @@ enable_slice_boxing() ...@@ -99,7 +100,7 @@ enable_slice_boxing()
@numba_funcify.register(MakeSlice) @numba_funcify.register(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs): def numba_funcify_MakeSlice(op, **kwargs):
@numba_njit @numba_basic.numba_njit
def makeslice(*x): def makeslice(*x):
return slice(*x) return slice(*x)
...@@ -297,7 +298,7 @@ def numba_funcify_multiple_integer_vector_indexing( ...@@ -297,7 +298,7 @@ def numba_funcify_multiple_integer_vector_indexing(
if isinstance(op, AdvancedSubtensor): if isinstance(op, AdvancedSubtensor):
@numba_njit @numba_basic.numba_njit
def advanced_subtensor_multiple_vector(x, *idxs): def advanced_subtensor_multiple_vector(x, *idxs):
none_slices = idxs[:first_axis] none_slices = idxs[:first_axis]
vec_idxs = idxs[first_axis:after_last_axis] vec_idxs = idxs[first_axis:after_last_axis]
...@@ -328,7 +329,7 @@ def numba_funcify_multiple_integer_vector_indexing( ...@@ -328,7 +329,7 @@ def numba_funcify_multiple_integer_vector_indexing(
if op.set_instead_of_inc: if op.set_instead_of_inc:
@numba_njit @numba_basic.numba_njit
def advanced_set_subtensor_multiple_vector(x, y, *idxs): def advanced_set_subtensor_multiple_vector(x, y, *idxs):
vec_idxs = idxs[first_axis:after_last_axis] vec_idxs = idxs[first_axis:after_last_axis]
x_shape = x.shape x_shape = x.shape
...@@ -350,7 +351,7 @@ def numba_funcify_multiple_integer_vector_indexing( ...@@ -350,7 +351,7 @@ def numba_funcify_multiple_integer_vector_indexing(
else: else:
@numba_njit @numba_basic.numba_njit
def advanced_inc_subtensor_multiple_vector(x, y, *idxs): def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
vec_idxs = idxs[first_axis:after_last_axis] vec_idxs = idxs[first_axis:after_last_axis]
x_shape = x.shape x_shape = x.shape
...@@ -382,7 +383,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): ...@@ -382,7 +383,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
if set_instead_of_inc: if set_instead_of_inc:
if broadcast_with_index: if broadcast_with_index:
@numba_njit(boundscheck=True) @numba_basic.numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs): def advancedincsubtensor1_inplace(x, val, idxs):
if val.ndim == x.ndim: if val.ndim == x.ndim:
core_val = val[0] core_val = val[0]
...@@ -398,7 +399,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): ...@@ -398,7 +399,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
else: else:
@numba_njit(boundscheck=True) @numba_basic.numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs): def advancedincsubtensor1_inplace(x, vals, idxs):
if not len(idxs) == len(vals): if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.") raise ValueError("The number of indices and values must match.")
...@@ -409,7 +410,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): ...@@ -409,7 +410,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
else: else:
if broadcast_with_index: if broadcast_with_index:
@numba_njit(boundscheck=True) @numba_basic.numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs): def advancedincsubtensor1_inplace(x, val, idxs):
if val.ndim == x.ndim: if val.ndim == x.ndim:
core_val = val[0] core_val = val[0]
...@@ -425,7 +426,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): ...@@ -425,7 +426,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
else: else:
@numba_njit(boundscheck=True) @numba_basic.numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, vals, idxs): def advancedincsubtensor1_inplace(x, vals, idxs):
if not len(idxs) == len(vals): if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.") raise ValueError("The number of indices and values must match.")
...@@ -440,7 +441,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): ...@@ -440,7 +441,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
else: else:
@numba_njit @numba_basic.numba_njit
def advancedincsubtensor1(x, vals, idxs): def advancedincsubtensor1(x, vals, idxs):
x = x.copy() x = x.copy()
return advancedincsubtensor1_inplace(x, vals, idxs) return advancedincsubtensor1_inplace(x, vals, idxs)
......
...@@ -6,7 +6,6 @@ from pytensor.link.numba.dispatch import basic as numba_basic ...@@ -6,7 +6,6 @@ from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
create_tuple_string, create_tuple_string,
numba_funcify, numba_funcify,
numba_njit,
) )
from pytensor.link.utils import compile_function_src, unique_name_generator from pytensor.link.utils import compile_function_src, unique_name_generator
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
...@@ -243,7 +242,7 @@ def numba_funcify_ScalarFromTensor(op, **kwargs): ...@@ -243,7 +242,7 @@ def numba_funcify_ScalarFromTensor(op, **kwargs):
@numba_funcify.register(Nonzero) @numba_funcify.register(Nonzero)
def numba_funcify_Nonzero(op, node, **kwargs): def numba_funcify_Nonzero(op, node, **kwargs):
@numba_njit @numba_basic.numba_njit
def nonzero(a): def nonzero(a):
result_tuple = np.nonzero(a) result_tuple = np.nonzero(a)
if a.ndim == 1: if a.ndim == 1:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论