Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
5fbf81df
提交
5fbf81df
authored
10月 27, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
11月 16, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Systematic use of mockable numba_basic.numba_jit
Direct import is not properly mocked by tests when trying to run `compare_numba_and_py` with `eval_obj_mode=True`
上级
d39ad599
隐藏空白字符变更
内嵌
并排
正在显示
12 个修改的文件
包含
55 行增加
和
52 行删除
+55
-52
blockwise.py
pytensor/link/numba/dispatch/blockwise.py
+3
-2
elemwise.py
pytensor/link/numba/dispatch/elemwise.py
+8
-8
lu.py
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
+3
-3
tridiagonal.py
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
+4
-4
nlinalg.py
pytensor/link/numba/dispatch/nlinalg.py
+2
-2
random.py
pytensor/link/numba/dispatch/random.py
+1
-1
shape.py
pytensor/link/numba/dispatch/shape.py
+5
-4
conv.py
pytensor/link/numba/dispatch/signal/conv.py
+4
-4
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+11
-10
sort.py
pytensor/link/numba/dispatch/sort.py
+3
-3
subtensor.py
pytensor/link/numba/dispatch/subtensor.py
+10
-9
tensor_basic.py
pytensor/link/numba/dispatch/tensor_basic.py
+1
-2
没有找到文件。
pytensor/link/numba/dispatch/blockwise.py
浏览文件 @
5fbf81df
...
...
@@ -4,7 +4,8 @@ from typing import cast
from
numba.core.extending
import
overload
from
numba.np.unsafe.ndarray
import
to_fixed_tuple
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
,
numba_njit
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
from
pytensor.link.numba.dispatch.vectorize_codegen
import
(
_jit_options
,
_vectorized
,
...
...
@@ -56,7 +57,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
src
+=
f
"to_fixed_tuple(core_shapes[{i}], {core_shapes_len[i]}),"
src
+=
")"
to_tuple
=
numba_njit
(
to_tuple
=
numba_
basic
.
numba_
njit
(
compile_function_src
(
src
,
"to_tuple"
,
...
...
pytensor/link/numba/dispatch/elemwise.py
浏览文件 @
5fbf81df
...
...
@@ -359,13 +359,13 @@ def numba_funcify_Sum(op, node, **kwargs):
if
ndim_input
==
len
(
axes
):
# Slightly faster than `numba_funcify_CAReduce` for this case
@numba_njit
@numba_
basic.numba_
njit
def
impl_sum
(
array
):
return
np
.
asarray
(
array
.
sum
(),
dtype
=
np_acc_dtype
)
.
astype
(
out_dtype
)
elif
len
(
axes
)
==
0
:
# These cases should be removed by rewrites!
@numba_njit
@numba_
basic.numba_
njit
def
impl_sum
(
array
):
return
np
.
asarray
(
array
,
dtype
=
out_dtype
)
...
...
@@ -615,25 +615,25 @@ def numba_funcify_Dot(op, node, **kwargs):
if
x_dtype
==
dot_dtype
and
y_dtype
==
dot_dtype
:
@numba_njit
@numba_
basic.numba_
njit
def
dot
(
x
,
y
):
return
np
.
asarray
(
np
.
dot
(
x
,
y
))
elif
x_dtype
==
dot_dtype
and
y_dtype
!=
dot_dtype
:
@numba_njit
@numba_
basic.numba_
njit
def
dot
(
x
,
y
):
return
np
.
asarray
(
np
.
dot
(
x
,
y
.
astype
(
dot_dtype
)))
elif
x_dtype
!=
dot_dtype
and
y_dtype
==
dot_dtype
:
@numba_njit
@numba_
basic.numba_
njit
def
dot
(
x
,
y
):
return
np
.
asarray
(
np
.
dot
(
x
.
astype
(
dot_dtype
),
y
))
else
:
@numba_
njit
()
@numba_
basic.numba_njit
def
dot
(
x
,
y
):
return
np
.
asarray
(
np
.
dot
(
x
.
astype
(
dot_dtype
),
y
.
astype
(
dot_dtype
)))
...
...
@@ -642,7 +642,7 @@ def numba_funcify_Dot(op, node, **kwargs):
else
:
@numba_njit
@numba_
basic.numba_
njit
def
dot_with_cast
(
x
,
y
):
return
dot
(
x
,
y
)
.
astype
(
out_dtype
)
...
...
@@ -653,7 +653,7 @@ def numba_funcify_Dot(op, node, **kwargs):
def
numba_funcify_BatchedDot
(
op
,
node
,
**
kwargs
):
dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
@numba_njit
@numba_
basic.numba_
njit
def
batched_dot
(
x
,
y
):
# Numba does not support 3D matmul
# https://github.com/numba/numba/issues/3804
...
...
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
浏览文件 @
5fbf81df
...
...
@@ -2,16 +2,16 @@ from collections.abc import Callable
from
typing
import
Literal
import
numpy
as
np
from
numba
import
njit
as
numba_njit
from
numba.core.extending
import
overload
from
numba.np.linalg
import
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.linalg.decomposition.lu_factor
import
_getrf
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_scipy_linalg_matrix
@numba_njit
@numba_
basic.numba_
njit
def
_pivot_to_permutation
(
p
,
dtype
):
p_inv
=
np
.
arange
(
len
(
p
))
.
astype
(
dtype
)
for
i
in
range
(
len
(
p
)):
...
...
@@ -19,7 +19,7 @@ def _pivot_to_permutation(p, dtype):
return
p_inv
@numba_njit
@numba_
basic.numba_
njit
def
_lu_factor_to_lu
(
a
,
dtype
,
overwrite_a
):
A_copy
,
IPIV
,
_INFO
=
_getrf
(
a
,
overwrite_a
=
overwrite_a
)
...
...
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
浏览文件 @
5fbf81df
...
...
@@ -6,8 +6,8 @@ from numba.np.linalg import ensure_lapack
from
numpy
import
ndarray
from
scipy
import
linalg
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
numba_njit
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
...
...
@@ -27,7 +27,7 @@ from pytensor.tensor._linalg.solve.tridiagonal import (
)
@numba_njit
@numba_
basic.numba_
njit
def
tridiagonal_norm
(
du
,
d
,
dl
):
# Adapted from scipy _matrix_norm_tridiagonal:
# https://github.com/scipy/scipy/blob/0f1fd4a7268b813fa2b844ca6038e4dfdf90084a/scipy/linalg/_basic.py#L356-L367
...
...
@@ -346,7 +346,7 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
overwrite_d
=
op
.
overwrite_d
overwrite_du
=
op
.
overwrite_du
@numba_njit
(
cache
=
False
)
@numba_
basic.numba_
njit
(
cache
=
False
)
def
lu_factor_tridiagonal
(
dl
,
d
,
du
):
dl
,
d
,
du
,
du2
,
ipiv
,
_
=
_gttrf
(
dl
,
...
...
@@ -368,7 +368,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
overwrite_b
=
op
.
overwrite_b
transposed
=
op
.
transposed
@numba_njit
(
cache
=
False
)
@numba_
basic.numba_
njit
(
cache
=
False
)
def
solve_lu_factor_tridiagonal
(
dl
,
d
,
du
,
du2
,
ipiv
,
b
):
x
,
_
=
_gttrs
(
dl
,
...
...
pytensor/link/numba/dispatch/nlinalg.py
浏览文件 @
5fbf81df
...
...
@@ -30,14 +30,14 @@ def numba_funcify_SVD(op, node, **kwargs):
if
not
compute_uv
:
@numba_basic.numba_njit
()
@numba_basic.numba_njit
def
svd
(
x
):
_
,
ret
,
_
=
np
.
linalg
.
svd
(
inputs_cast
(
x
),
full_matrices
)
return
ret
else
:
@numba_basic.numba_njit
()
@numba_basic.numba_njit
def
svd
(
x
):
return
np
.
linalg
.
svd
(
inputs_cast
(
x
),
full_matrices
)
...
...
pytensor/link/numba/dispatch/random.py
浏览文件 @
5fbf81df
...
...
@@ -91,7 +91,7 @@ def numba_core_rv_default(op, node):
def
numba_core_BernoulliRV
(
op
,
node
):
out_dtype
=
node
.
outputs
[
1
]
.
type
.
numpy_dtype
@numba_basic.numba_njit
()
@numba_basic.numba_njit
def
random
(
rng
,
p
):
return
(
direct_cast
(
0
,
out_dtype
)
...
...
pytensor/link/numba/dispatch/shape.py
浏览文件 @
5fbf81df
...
...
@@ -3,6 +3,7 @@ from textwrap import dedent
import
numpy
as
np
from
numba.np.unsafe
import
ndarray
as
numba_ndarray
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
create_arg_string
,
numba_njit
from
pytensor.link.utils
import
compile_function_src
...
...
@@ -12,7 +13,7 @@ from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@numba_funcify.register
(
Shape
)
def
numba_funcify_Shape
(
op
,
**
kwargs
):
@numba_njit
@numba_
basic.numba_
njit
def
shape
(
x
):
return
np
.
asarray
(
np
.
shape
(
x
))
...
...
@@ -23,7 +24,7 @@ def numba_funcify_Shape(op, **kwargs):
def
numba_funcify_Shape_i
(
op
,
**
kwargs
):
i
=
op
.
i
@numba_njit
@numba_
basic.numba_
njit
def
shape_i
(
x
):
return
np
.
asarray
(
np
.
shape
(
x
)[
i
])
...
...
@@ -61,13 +62,13 @@ def numba_funcify_Reshape(op, **kwargs):
if
ndim
==
0
:
@numba_njit
@numba_
basic.numba_
njit
def
reshape
(
x
,
shape
):
return
np
.
asarray
(
x
.
item
())
else
:
@numba_njit
@numba_
basic.numba_
njit
def
reshape
(
x
,
shape
):
# TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
return
np
.
reshape
(
...
...
pytensor/link/numba/dispatch/signal/conv.py
浏览文件 @
5fbf81df
import
numpy
as
np
from
numba.np.arraymath
import
_get_inner_prod
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
numba_njit
from
pytensor.tensor.signal.conv
import
Convolve1d
...
...
@@ -13,7 +13,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
out_dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
innerprod
=
_get_inner_prod
(
a_dtype
,
b_dtype
)
@numba_njit
@numba_
basic.numba_
njit
def
valid_convolve1d
(
x
,
y
):
nx
=
len
(
x
)
ny
=
len
(
y
)
...
...
@@ -30,7 +30,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
return
ret
@numba_njit
@numba_
basic.numba_
njit
def
full_convolve1d
(
x
,
y
):
nx
=
len
(
x
)
ny
=
len
(
y
)
...
...
@@ -59,7 +59,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
return
ret
@numba_njit
@numba_
basic.numba_
njit
def
convolve_1d
(
x
,
y
,
mode
):
if
mode
:
return
full_convolve1d
(
x
,
y
)
...
...
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
5fbf81df
...
...
@@ -3,7 +3,8 @@ import warnings
import
numpy
as
np
from
pytensor
import
config
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
,
numba_njit
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
from
pytensor.link.numba.dispatch.linalg.decomposition.cholesky
import
_cholesky
from
pytensor.link.numba.dispatch.linalg.decomposition.lu
import
(
_lu_1
,
...
...
@@ -63,7 +64,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
if
dtype
in
complex_dtypes
:
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
@numba_njit
@numba_
basic.numba_
njit
def
cholesky
(
a
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
...
...
@@ -95,7 +96,7 @@ def pivot_to_permutation(op, node, **kwargs):
inverse
=
op
.
inverse
dtype
=
node
.
outputs
[
0
]
.
dtype
@numba_njit
@numba_
basic.numba_
njit
def
numba_pivot_to_permutation
(
piv
):
p_inv
=
_pivot_to_permutation
(
piv
,
dtype
)
...
...
@@ -118,7 +119,7 @@ def numba_funcify_LU(op, node, **kwargs):
if
dtype
in
complex_dtypes
:
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
@numba_njit
(
inline
=
"always"
)
@numba_
basic.numba_
njit
(
inline
=
"always"
)
def
lu
(
a
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
...
...
@@ -165,7 +166,7 @@ def numba_funcify_LUFactor(op, node, **kwargs):
if
dtype
in
complex_dtypes
:
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
@numba_njit
@numba_
basic.numba_
njit
def
lu_factor
(
a
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
...
...
@@ -185,7 +186,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
dtype
=
node
.
outputs
[
0
]
.
dtype
# TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
@numba_njit
@numba_
basic.numba_
njit
def
block_diag
(
*
arrs
):
shapes
=
np
.
array
([
a
.
shape
for
a
in
arrs
],
dtype
=
"int"
)
out_shape
=
[
int
(
s
)
for
s
in
np
.
sum
(
shapes
,
axis
=
0
)]
...
...
@@ -235,7 +236,7 @@ def numba_funcify_Solve(op, node, **kwargs):
)
solve_fn
=
_solve_gen
@numba_njit
@numba_
basic.numba_
njit
def
solve
(
a
,
b
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
...
...
@@ -267,7 +268,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
"Solve Triangular"
)
)
@numba_njit
@numba_
basic.numba_
njit
def
solve_triangular
(
a
,
b
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
...
...
@@ -304,7 +305,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
if
dtype
in
complex_dtypes
:
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
@numba_njit
@numba_
basic.numba_
njit
def
cho_solve
(
c
,
b
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
c
),
np
.
isnan
(
c
))):
...
...
@@ -337,7 +338,7 @@ def numba_funcify_QR(op, node, **kwargs):
integer_input
=
dtype
in
integer_dtypes
in_dtype
=
config
.
floatX
if
integer_input
else
dtype
@numba_njit
(
cache
=
False
)
@numba_
basic.numba_
njit
(
cache
=
False
)
def
qr
(
a
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
...
...
pytensor/link/numba/dispatch/sort.py
浏览文件 @
5fbf81df
...
...
@@ -2,8 +2,8 @@ import warnings
import
numpy
as
np
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
numba_njit
from
pytensor.tensor.sort
import
ArgSortOp
,
SortOp
...
...
@@ -18,7 +18,7 @@ def numba_funcify_SortOp(op, node, **kwargs):
UserWarning
,
)
@numba_njit
@numba_
basic.numba_
njit
def
sort_f
(
a
,
axis
):
axis
=
axis
.
item
()
...
...
@@ -45,7 +45,7 @@ def numba_funcify_ArgSortOp(op, node, **kwargs):
UserWarning
,
)
@numba_njit
@numba_
basic.numba_
njit
def
argort_f
(
X
,
axis
):
axis
=
axis
.
item
()
...
...
pytensor/link/numba/dispatch/subtensor.py
浏览文件 @
5fbf81df
...
...
@@ -8,6 +8,7 @@ from numba import types
from
numba.core.pythonapi
import
box
from
pytensor.graph
import
Type
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
generate_fallback_impl
,
numba_njit
from
pytensor.link.utils
import
compile_function_src
,
unique_name_generator
...
...
@@ -99,7 +100,7 @@ enable_slice_boxing()
@numba_funcify.register
(
MakeSlice
)
def
numba_funcify_MakeSlice
(
op
,
**
kwargs
):
@numba_njit
@numba_
basic.numba_
njit
def
makeslice
(
*
x
):
return
slice
(
*
x
)
...
...
@@ -297,7 +298,7 @@ def numba_funcify_multiple_integer_vector_indexing(
if
isinstance
(
op
,
AdvancedSubtensor
):
@numba_njit
@numba_
basic.numba_
njit
def
advanced_subtensor_multiple_vector
(
x
,
*
idxs
):
none_slices
=
idxs
[:
first_axis
]
vec_idxs
=
idxs
[
first_axis
:
after_last_axis
]
...
...
@@ -328,7 +329,7 @@ def numba_funcify_multiple_integer_vector_indexing(
if
op
.
set_instead_of_inc
:
@numba_njit
@numba_
basic.numba_
njit
def
advanced_set_subtensor_multiple_vector
(
x
,
y
,
*
idxs
):
vec_idxs
=
idxs
[
first_axis
:
after_last_axis
]
x_shape
=
x
.
shape
...
...
@@ -350,7 +351,7 @@ def numba_funcify_multiple_integer_vector_indexing(
else
:
@numba_njit
@numba_
basic.numba_
njit
def
advanced_inc_subtensor_multiple_vector
(
x
,
y
,
*
idxs
):
vec_idxs
=
idxs
[
first_axis
:
after_last_axis
]
x_shape
=
x
.
shape
...
...
@@ -382,7 +383,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
if
set_instead_of_inc
:
if
broadcast_with_index
:
@numba_njit
(
boundscheck
=
True
)
@numba_
basic.numba_
njit
(
boundscheck
=
True
)
def
advancedincsubtensor1_inplace
(
x
,
val
,
idxs
):
if
val
.
ndim
==
x
.
ndim
:
core_val
=
val
[
0
]
...
...
@@ -398,7 +399,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
else
:
@numba_njit
(
boundscheck
=
True
)
@numba_
basic.numba_
njit
(
boundscheck
=
True
)
def
advancedincsubtensor1_inplace
(
x
,
vals
,
idxs
):
if
not
len
(
idxs
)
==
len
(
vals
):
raise
ValueError
(
"The number of indices and values must match."
)
...
...
@@ -409,7 +410,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
else
:
if
broadcast_with_index
:
@numba_njit
(
boundscheck
=
True
)
@numba_
basic.numba_
njit
(
boundscheck
=
True
)
def
advancedincsubtensor1_inplace
(
x
,
val
,
idxs
):
if
val
.
ndim
==
x
.
ndim
:
core_val
=
val
[
0
]
...
...
@@ -425,7 +426,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
else
:
@numba_njit
(
boundscheck
=
True
)
@numba_
basic.numba_
njit
(
boundscheck
=
True
)
def
advancedincsubtensor1_inplace
(
x
,
vals
,
idxs
):
if
not
len
(
idxs
)
==
len
(
vals
):
raise
ValueError
(
"The number of indices and values must match."
)
...
...
@@ -440,7 +441,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
else
:
@numba_njit
@numba_
basic.numba_
njit
def
advancedincsubtensor1
(
x
,
vals
,
idxs
):
x
=
x
.
copy
()
return
advancedincsubtensor1_inplace
(
x
,
vals
,
idxs
)
...
...
pytensor/link/numba/dispatch/tensor_basic.py
浏览文件 @
5fbf81df
...
...
@@ -6,7 +6,6 @@ from pytensor.link.numba.dispatch import basic as numba_basic
from
pytensor.link.numba.dispatch.basic
import
(
create_tuple_string
,
numba_funcify
,
numba_njit
,
)
from
pytensor.link.utils
import
compile_function_src
,
unique_name_generator
from
pytensor.tensor.basic
import
(
...
...
@@ -243,7 +242,7 @@ def numba_funcify_ScalarFromTensor(op, **kwargs):
@numba_funcify.register
(
Nonzero
)
def
numba_funcify_Nonzero
(
op
,
node
,
**
kwargs
):
@numba_njit
@numba_
basic.numba_
njit
def
nonzero
(
a
):
result_tuple
=
np
.
nonzero
(
a
)
if
a
.
ndim
==
1
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论