Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
5fbf81df
提交
5fbf81df
authored
10月 27, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
11月 16, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Systematic use of mockable numba_basic.numba_jit
Direct import is not properly mocked by tests when trying to run `compare_numba_and_py` with `eval_obj_mode=True`
上级
d39ad599
隐藏空白字符变更
内嵌
并排
正在显示
12 个修改的文件
包含
55 行增加
和
52 行删除
+55
-52
blockwise.py
pytensor/link/numba/dispatch/blockwise.py
+3
-2
elemwise.py
pytensor/link/numba/dispatch/elemwise.py
+8
-8
lu.py
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
+3
-3
tridiagonal.py
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
+4
-4
nlinalg.py
pytensor/link/numba/dispatch/nlinalg.py
+2
-2
random.py
pytensor/link/numba/dispatch/random.py
+1
-1
shape.py
pytensor/link/numba/dispatch/shape.py
+5
-4
conv.py
pytensor/link/numba/dispatch/signal/conv.py
+4
-4
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+11
-10
sort.py
pytensor/link/numba/dispatch/sort.py
+3
-3
subtensor.py
pytensor/link/numba/dispatch/subtensor.py
+10
-9
tensor_basic.py
pytensor/link/numba/dispatch/tensor_basic.py
+1
-2
没有找到文件。
pytensor/link/numba/dispatch/blockwise.py
浏览文件 @
5fbf81df
...
@@ -4,7 +4,8 @@ from typing import cast
...
@@ -4,7 +4,8 @@ from typing import cast
from
numba.core.extending
import
overload
from
numba.core.extending
import
overload
from
numba.np.unsafe.ndarray
import
to_fixed_tuple
from
numba.np.unsafe.ndarray
import
to_fixed_tuple
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
,
numba_njit
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
from
pytensor.link.numba.dispatch.vectorize_codegen
import
(
from
pytensor.link.numba.dispatch.vectorize_codegen
import
(
_jit_options
,
_jit_options
,
_vectorized
,
_vectorized
,
...
@@ -56,7 +57,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
...
@@ -56,7 +57,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
src
+=
f
"to_fixed_tuple(core_shapes[{i}], {core_shapes_len[i]}),"
src
+=
f
"to_fixed_tuple(core_shapes[{i}], {core_shapes_len[i]}),"
src
+=
")"
src
+=
")"
to_tuple
=
numba_njit
(
to_tuple
=
numba_
basic
.
numba_
njit
(
compile_function_src
(
compile_function_src
(
src
,
src
,
"to_tuple"
,
"to_tuple"
,
...
...
pytensor/link/numba/dispatch/elemwise.py
浏览文件 @
5fbf81df
...
@@ -359,13 +359,13 @@ def numba_funcify_Sum(op, node, **kwargs):
...
@@ -359,13 +359,13 @@ def numba_funcify_Sum(op, node, **kwargs):
if
ndim_input
==
len
(
axes
):
if
ndim_input
==
len
(
axes
):
# Slightly faster than `numba_funcify_CAReduce` for this case
# Slightly faster than `numba_funcify_CAReduce` for this case
@numba_njit
@numba_
basic.numba_
njit
def
impl_sum
(
array
):
def
impl_sum
(
array
):
return
np
.
asarray
(
array
.
sum
(),
dtype
=
np_acc_dtype
)
.
astype
(
out_dtype
)
return
np
.
asarray
(
array
.
sum
(),
dtype
=
np_acc_dtype
)
.
astype
(
out_dtype
)
elif
len
(
axes
)
==
0
:
elif
len
(
axes
)
==
0
:
# These cases should be removed by rewrites!
# These cases should be removed by rewrites!
@numba_njit
@numba_
basic.numba_
njit
def
impl_sum
(
array
):
def
impl_sum
(
array
):
return
np
.
asarray
(
array
,
dtype
=
out_dtype
)
return
np
.
asarray
(
array
,
dtype
=
out_dtype
)
...
@@ -615,25 +615,25 @@ def numba_funcify_Dot(op, node, **kwargs):
...
@@ -615,25 +615,25 @@ def numba_funcify_Dot(op, node, **kwargs):
if
x_dtype
==
dot_dtype
and
y_dtype
==
dot_dtype
:
if
x_dtype
==
dot_dtype
and
y_dtype
==
dot_dtype
:
@numba_njit
@numba_
basic.numba_
njit
def
dot
(
x
,
y
):
def
dot
(
x
,
y
):
return
np
.
asarray
(
np
.
dot
(
x
,
y
))
return
np
.
asarray
(
np
.
dot
(
x
,
y
))
elif
x_dtype
==
dot_dtype
and
y_dtype
!=
dot_dtype
:
elif
x_dtype
==
dot_dtype
and
y_dtype
!=
dot_dtype
:
@numba_njit
@numba_
basic.numba_
njit
def
dot
(
x
,
y
):
def
dot
(
x
,
y
):
return
np
.
asarray
(
np
.
dot
(
x
,
y
.
astype
(
dot_dtype
)))
return
np
.
asarray
(
np
.
dot
(
x
,
y
.
astype
(
dot_dtype
)))
elif
x_dtype
!=
dot_dtype
and
y_dtype
==
dot_dtype
:
elif
x_dtype
!=
dot_dtype
and
y_dtype
==
dot_dtype
:
@numba_njit
@numba_
basic.numba_
njit
def
dot
(
x
,
y
):
def
dot
(
x
,
y
):
return
np
.
asarray
(
np
.
dot
(
x
.
astype
(
dot_dtype
),
y
))
return
np
.
asarray
(
np
.
dot
(
x
.
astype
(
dot_dtype
),
y
))
else
:
else
:
@numba_
njit
()
@numba_
basic.numba_njit
def
dot
(
x
,
y
):
def
dot
(
x
,
y
):
return
np
.
asarray
(
np
.
dot
(
x
.
astype
(
dot_dtype
),
y
.
astype
(
dot_dtype
)))
return
np
.
asarray
(
np
.
dot
(
x
.
astype
(
dot_dtype
),
y
.
astype
(
dot_dtype
)))
...
@@ -642,7 +642,7 @@ def numba_funcify_Dot(op, node, **kwargs):
...
@@ -642,7 +642,7 @@ def numba_funcify_Dot(op, node, **kwargs):
else
:
else
:
@numba_njit
@numba_
basic.numba_
njit
def
dot_with_cast
(
x
,
y
):
def
dot_with_cast
(
x
,
y
):
return
dot
(
x
,
y
)
.
astype
(
out_dtype
)
return
dot
(
x
,
y
)
.
astype
(
out_dtype
)
...
@@ -653,7 +653,7 @@ def numba_funcify_Dot(op, node, **kwargs):
...
@@ -653,7 +653,7 @@ def numba_funcify_Dot(op, node, **kwargs):
def
numba_funcify_BatchedDot
(
op
,
node
,
**
kwargs
):
def
numba_funcify_BatchedDot
(
op
,
node
,
**
kwargs
):
dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
@numba_njit
@numba_
basic.numba_
njit
def
batched_dot
(
x
,
y
):
def
batched_dot
(
x
,
y
):
# Numba does not support 3D matmul
# Numba does not support 3D matmul
# https://github.com/numba/numba/issues/3804
# https://github.com/numba/numba/issues/3804
...
...
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
浏览文件 @
5fbf81df
...
@@ -2,16 +2,16 @@ from collections.abc import Callable
...
@@ -2,16 +2,16 @@ from collections.abc import Callable
from
typing
import
Literal
from
typing
import
Literal
import
numpy
as
np
import
numpy
as
np
from
numba
import
njit
as
numba_njit
from
numba.core.extending
import
overload
from
numba.core.extending
import
overload
from
numba.np.linalg
import
ensure_lapack
from
numba.np.linalg
import
ensure_lapack
from
scipy
import
linalg
from
scipy
import
linalg
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.linalg.decomposition.lu_factor
import
_getrf
from
pytensor.link.numba.dispatch.linalg.decomposition.lu_factor
import
_getrf
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_scipy_linalg_matrix
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_scipy_linalg_matrix
@numba_njit
@numba_
basic.numba_
njit
def
_pivot_to_permutation
(
p
,
dtype
):
def
_pivot_to_permutation
(
p
,
dtype
):
p_inv
=
np
.
arange
(
len
(
p
))
.
astype
(
dtype
)
p_inv
=
np
.
arange
(
len
(
p
))
.
astype
(
dtype
)
for
i
in
range
(
len
(
p
)):
for
i
in
range
(
len
(
p
)):
...
@@ -19,7 +19,7 @@ def _pivot_to_permutation(p, dtype):
...
@@ -19,7 +19,7 @@ def _pivot_to_permutation(p, dtype):
return
p_inv
return
p_inv
@numba_njit
@numba_
basic.numba_
njit
def
_lu_factor_to_lu
(
a
,
dtype
,
overwrite_a
):
def
_lu_factor_to_lu
(
a
,
dtype
,
overwrite_a
):
A_copy
,
IPIV
,
_INFO
=
_getrf
(
a
,
overwrite_a
=
overwrite_a
)
A_copy
,
IPIV
,
_INFO
=
_getrf
(
a
,
overwrite_a
=
overwrite_a
)
...
...
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
浏览文件 @
5fbf81df
...
@@ -6,8 +6,8 @@ from numba.np.linalg import ensure_lapack
...
@@ -6,8 +6,8 @@ from numba.np.linalg import ensure_lapack
from
numpy
import
ndarray
from
numpy
import
ndarray
from
scipy
import
linalg
from
scipy
import
linalg
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
numba_njit
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_LAPACK
,
_get_underlying_float
,
_get_underlying_float
,
...
@@ -27,7 +27,7 @@ from pytensor.tensor._linalg.solve.tridiagonal import (
...
@@ -27,7 +27,7 @@ from pytensor.tensor._linalg.solve.tridiagonal import (
)
)
@numba_njit
@numba_
basic.numba_
njit
def
tridiagonal_norm
(
du
,
d
,
dl
):
def
tridiagonal_norm
(
du
,
d
,
dl
):
# Adapted from scipy _matrix_norm_tridiagonal:
# Adapted from scipy _matrix_norm_tridiagonal:
# https://github.com/scipy/scipy/blob/0f1fd4a7268b813fa2b844ca6038e4dfdf90084a/scipy/linalg/_basic.py#L356-L367
# https://github.com/scipy/scipy/blob/0f1fd4a7268b813fa2b844ca6038e4dfdf90084a/scipy/linalg/_basic.py#L356-L367
...
@@ -346,7 +346,7 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
...
@@ -346,7 +346,7 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
overwrite_d
=
op
.
overwrite_d
overwrite_d
=
op
.
overwrite_d
overwrite_du
=
op
.
overwrite_du
overwrite_du
=
op
.
overwrite_du
@numba_njit
(
cache
=
False
)
@numba_
basic.numba_
njit
(
cache
=
False
)
def
lu_factor_tridiagonal
(
dl
,
d
,
du
):
def
lu_factor_tridiagonal
(
dl
,
d
,
du
):
dl
,
d
,
du
,
du2
,
ipiv
,
_
=
_gttrf
(
dl
,
d
,
du
,
du2
,
ipiv
,
_
=
_gttrf
(
dl
,
dl
,
...
@@ -368,7 +368,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
...
@@ -368,7 +368,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
overwrite_b
=
op
.
overwrite_b
overwrite_b
=
op
.
overwrite_b
transposed
=
op
.
transposed
transposed
=
op
.
transposed
@numba_njit
(
cache
=
False
)
@numba_
basic.numba_
njit
(
cache
=
False
)
def
solve_lu_factor_tridiagonal
(
dl
,
d
,
du
,
du2
,
ipiv
,
b
):
def
solve_lu_factor_tridiagonal
(
dl
,
d
,
du
,
du2
,
ipiv
,
b
):
x
,
_
=
_gttrs
(
x
,
_
=
_gttrs
(
dl
,
dl
,
...
...
pytensor/link/numba/dispatch/nlinalg.py
浏览文件 @
5fbf81df
...
@@ -30,14 +30,14 @@ def numba_funcify_SVD(op, node, **kwargs):
...
@@ -30,14 +30,14 @@ def numba_funcify_SVD(op, node, **kwargs):
if
not
compute_uv
:
if
not
compute_uv
:
@numba_basic.numba_njit
()
@numba_basic.numba_njit
def
svd
(
x
):
def
svd
(
x
):
_
,
ret
,
_
=
np
.
linalg
.
svd
(
inputs_cast
(
x
),
full_matrices
)
_
,
ret
,
_
=
np
.
linalg
.
svd
(
inputs_cast
(
x
),
full_matrices
)
return
ret
return
ret
else
:
else
:
@numba_basic.numba_njit
()
@numba_basic.numba_njit
def
svd
(
x
):
def
svd
(
x
):
return
np
.
linalg
.
svd
(
inputs_cast
(
x
),
full_matrices
)
return
np
.
linalg
.
svd
(
inputs_cast
(
x
),
full_matrices
)
...
...
pytensor/link/numba/dispatch/random.py
浏览文件 @
5fbf81df
...
@@ -91,7 +91,7 @@ def numba_core_rv_default(op, node):
...
@@ -91,7 +91,7 @@ def numba_core_rv_default(op, node):
def
numba_core_BernoulliRV
(
op
,
node
):
def
numba_core_BernoulliRV
(
op
,
node
):
out_dtype
=
node
.
outputs
[
1
]
.
type
.
numpy_dtype
out_dtype
=
node
.
outputs
[
1
]
.
type
.
numpy_dtype
@numba_basic.numba_njit
()
@numba_basic.numba_njit
def
random
(
rng
,
p
):
def
random
(
rng
,
p
):
return
(
return
(
direct_cast
(
0
,
out_dtype
)
direct_cast
(
0
,
out_dtype
)
...
...
pytensor/link/numba/dispatch/shape.py
浏览文件 @
5fbf81df
...
@@ -3,6 +3,7 @@ from textwrap import dedent
...
@@ -3,6 +3,7 @@ from textwrap import dedent
import
numpy
as
np
import
numpy
as
np
from
numba.np.unsafe
import
ndarray
as
numba_ndarray
from
numba.np.unsafe
import
ndarray
as
numba_ndarray
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
create_arg_string
,
numba_njit
from
pytensor.link.numba.dispatch.basic
import
create_arg_string
,
numba_njit
from
pytensor.link.utils
import
compile_function_src
from
pytensor.link.utils
import
compile_function_src
...
@@ -12,7 +13,7 @@ from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
...
@@ -12,7 +13,7 @@ from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@numba_funcify.register
(
Shape
)
@numba_funcify.register
(
Shape
)
def
numba_funcify_Shape
(
op
,
**
kwargs
):
def
numba_funcify_Shape
(
op
,
**
kwargs
):
@numba_njit
@numba_
basic.numba_
njit
def
shape
(
x
):
def
shape
(
x
):
return
np
.
asarray
(
np
.
shape
(
x
))
return
np
.
asarray
(
np
.
shape
(
x
))
...
@@ -23,7 +24,7 @@ def numba_funcify_Shape(op, **kwargs):
...
@@ -23,7 +24,7 @@ def numba_funcify_Shape(op, **kwargs):
def
numba_funcify_Shape_i
(
op
,
**
kwargs
):
def
numba_funcify_Shape_i
(
op
,
**
kwargs
):
i
=
op
.
i
i
=
op
.
i
@numba_njit
@numba_
basic.numba_
njit
def
shape_i
(
x
):
def
shape_i
(
x
):
return
np
.
asarray
(
np
.
shape
(
x
)[
i
])
return
np
.
asarray
(
np
.
shape
(
x
)[
i
])
...
@@ -61,13 +62,13 @@ def numba_funcify_Reshape(op, **kwargs):
...
@@ -61,13 +62,13 @@ def numba_funcify_Reshape(op, **kwargs):
if
ndim
==
0
:
if
ndim
==
0
:
@numba_njit
@numba_
basic.numba_
njit
def
reshape
(
x
,
shape
):
def
reshape
(
x
,
shape
):
return
np
.
asarray
(
x
.
item
())
return
np
.
asarray
(
x
.
item
())
else
:
else
:
@numba_njit
@numba_
basic.numba_
njit
def
reshape
(
x
,
shape
):
def
reshape
(
x
,
shape
):
# TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
# TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
return
np
.
reshape
(
return
np
.
reshape
(
...
...
pytensor/link/numba/dispatch/signal/conv.py
浏览文件 @
5fbf81df
import
numpy
as
np
import
numpy
as
np
from
numba.np.arraymath
import
_get_inner_prod
from
numba.np.arraymath
import
_get_inner_prod
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
numba_njit
from
pytensor.tensor.signal.conv
import
Convolve1d
from
pytensor.tensor.signal.conv
import
Convolve1d
...
@@ -13,7 +13,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
...
@@ -13,7 +13,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
out_dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
out_dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
innerprod
=
_get_inner_prod
(
a_dtype
,
b_dtype
)
innerprod
=
_get_inner_prod
(
a_dtype
,
b_dtype
)
@numba_njit
@numba_
basic.numba_
njit
def
valid_convolve1d
(
x
,
y
):
def
valid_convolve1d
(
x
,
y
):
nx
=
len
(
x
)
nx
=
len
(
x
)
ny
=
len
(
y
)
ny
=
len
(
y
)
...
@@ -30,7 +30,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
...
@@ -30,7 +30,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
return
ret
return
ret
@numba_njit
@numba_
basic.numba_
njit
def
full_convolve1d
(
x
,
y
):
def
full_convolve1d
(
x
,
y
):
nx
=
len
(
x
)
nx
=
len
(
x
)
ny
=
len
(
y
)
ny
=
len
(
y
)
...
@@ -59,7 +59,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
...
@@ -59,7 +59,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
return
ret
return
ret
@numba_njit
@numba_
basic.numba_
njit
def
convolve_1d
(
x
,
y
,
mode
):
def
convolve_1d
(
x
,
y
,
mode
):
if
mode
:
if
mode
:
return
full_convolve1d
(
x
,
y
)
return
full_convolve1d
(
x
,
y
)
...
...
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
5fbf81df
...
@@ -3,7 +3,8 @@ import warnings
...
@@ -3,7 +3,8 @@ import warnings
import
numpy
as
np
import
numpy
as
np
from
pytensor
import
config
from
pytensor
import
config
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
,
numba_njit
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
from
pytensor.link.numba.dispatch.linalg.decomposition.cholesky
import
_cholesky
from
pytensor.link.numba.dispatch.linalg.decomposition.cholesky
import
_cholesky
from
pytensor.link.numba.dispatch.linalg.decomposition.lu
import
(
from
pytensor.link.numba.dispatch.linalg.decomposition.lu
import
(
_lu_1
,
_lu_1
,
...
@@ -63,7 +64,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
...
@@ -63,7 +64,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
if
dtype
in
complex_dtypes
:
if
dtype
in
complex_dtypes
:
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
@numba_njit
@numba_
basic.numba_
njit
def
cholesky
(
a
):
def
cholesky
(
a
):
if
check_finite
:
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
...
@@ -95,7 +96,7 @@ def pivot_to_permutation(op, node, **kwargs):
...
@@ -95,7 +96,7 @@ def pivot_to_permutation(op, node, **kwargs):
inverse
=
op
.
inverse
inverse
=
op
.
inverse
dtype
=
node
.
outputs
[
0
]
.
dtype
dtype
=
node
.
outputs
[
0
]
.
dtype
@numba_njit
@numba_
basic.numba_
njit
def
numba_pivot_to_permutation
(
piv
):
def
numba_pivot_to_permutation
(
piv
):
p_inv
=
_pivot_to_permutation
(
piv
,
dtype
)
p_inv
=
_pivot_to_permutation
(
piv
,
dtype
)
...
@@ -118,7 +119,7 @@ def numba_funcify_LU(op, node, **kwargs):
...
@@ -118,7 +119,7 @@ def numba_funcify_LU(op, node, **kwargs):
if
dtype
in
complex_dtypes
:
if
dtype
in
complex_dtypes
:
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
@numba_njit
(
inline
=
"always"
)
@numba_
basic.numba_
njit
(
inline
=
"always"
)
def
lu
(
a
):
def
lu
(
a
):
if
check_finite
:
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
...
@@ -165,7 +166,7 @@ def numba_funcify_LUFactor(op, node, **kwargs):
...
@@ -165,7 +166,7 @@ def numba_funcify_LUFactor(op, node, **kwargs):
if
dtype
in
complex_dtypes
:
if
dtype
in
complex_dtypes
:
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
@numba_njit
@numba_
basic.numba_
njit
def
lu_factor
(
a
):
def
lu_factor
(
a
):
if
check_finite
:
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
...
@@ -185,7 +186,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
...
@@ -185,7 +186,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
dtype
=
node
.
outputs
[
0
]
.
dtype
dtype
=
node
.
outputs
[
0
]
.
dtype
# TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
# TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
@numba_njit
@numba_
basic.numba_
njit
def
block_diag
(
*
arrs
):
def
block_diag
(
*
arrs
):
shapes
=
np
.
array
([
a
.
shape
for
a
in
arrs
],
dtype
=
"int"
)
shapes
=
np
.
array
([
a
.
shape
for
a
in
arrs
],
dtype
=
"int"
)
out_shape
=
[
int
(
s
)
for
s
in
np
.
sum
(
shapes
,
axis
=
0
)]
out_shape
=
[
int
(
s
)
for
s
in
np
.
sum
(
shapes
,
axis
=
0
)]
...
@@ -235,7 +236,7 @@ def numba_funcify_Solve(op, node, **kwargs):
...
@@ -235,7 +236,7 @@ def numba_funcify_Solve(op, node, **kwargs):
)
)
solve_fn
=
_solve_gen
solve_fn
=
_solve_gen
@numba_njit
@numba_
basic.numba_
njit
def
solve
(
a
,
b
):
def
solve
(
a
,
b
):
if
check_finite
:
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
...
@@ -267,7 +268,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
...
@@ -267,7 +268,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
"Solve Triangular"
)
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
"Solve Triangular"
)
)
)
@numba_njit
@numba_
basic.numba_
njit
def
solve_triangular
(
a
,
b
):
def
solve_triangular
(
a
,
b
):
if
check_finite
:
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
...
@@ -304,7 +305,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
...
@@ -304,7 +305,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
if
dtype
in
complex_dtypes
:
if
dtype
in
complex_dtypes
:
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
@numba_njit
@numba_
basic.numba_
njit
def
cho_solve
(
c
,
b
):
def
cho_solve
(
c
,
b
):
if
check_finite
:
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
c
),
np
.
isnan
(
c
))):
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
c
),
np
.
isnan
(
c
))):
...
@@ -337,7 +338,7 @@ def numba_funcify_QR(op, node, **kwargs):
...
@@ -337,7 +338,7 @@ def numba_funcify_QR(op, node, **kwargs):
integer_input
=
dtype
in
integer_dtypes
integer_input
=
dtype
in
integer_dtypes
in_dtype
=
config
.
floatX
if
integer_input
else
dtype
in_dtype
=
config
.
floatX
if
integer_input
else
dtype
@numba_njit
(
cache
=
False
)
@numba_
basic.numba_
njit
(
cache
=
False
)
def
qr
(
a
):
def
qr
(
a
):
if
check_finite
:
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
...
...
pytensor/link/numba/dispatch/sort.py
浏览文件 @
5fbf81df
...
@@ -2,8 +2,8 @@ import warnings
...
@@ -2,8 +2,8 @@ import warnings
import
numpy
as
np
import
numpy
as
np
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
numba_njit
from
pytensor.tensor.sort
import
ArgSortOp
,
SortOp
from
pytensor.tensor.sort
import
ArgSortOp
,
SortOp
...
@@ -18,7 +18,7 @@ def numba_funcify_SortOp(op, node, **kwargs):
...
@@ -18,7 +18,7 @@ def numba_funcify_SortOp(op, node, **kwargs):
UserWarning
,
UserWarning
,
)
)
@numba_njit
@numba_
basic.numba_
njit
def
sort_f
(
a
,
axis
):
def
sort_f
(
a
,
axis
):
axis
=
axis
.
item
()
axis
=
axis
.
item
()
...
@@ -45,7 +45,7 @@ def numba_funcify_ArgSortOp(op, node, **kwargs):
...
@@ -45,7 +45,7 @@ def numba_funcify_ArgSortOp(op, node, **kwargs):
UserWarning
,
UserWarning
,
)
)
@numba_njit
@numba_
basic.numba_
njit
def
argort_f
(
X
,
axis
):
def
argort_f
(
X
,
axis
):
axis
=
axis
.
item
()
axis
=
axis
.
item
()
...
...
pytensor/link/numba/dispatch/subtensor.py
浏览文件 @
5fbf81df
...
@@ -8,6 +8,7 @@ from numba import types
...
@@ -8,6 +8,7 @@ from numba import types
from
numba.core.pythonapi
import
box
from
numba.core.pythonapi
import
box
from
pytensor.graph
import
Type
from
pytensor.graph
import
Type
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
generate_fallback_impl
,
numba_njit
from
pytensor.link.numba.dispatch.basic
import
generate_fallback_impl
,
numba_njit
from
pytensor.link.utils
import
compile_function_src
,
unique_name_generator
from
pytensor.link.utils
import
compile_function_src
,
unique_name_generator
...
@@ -99,7 +100,7 @@ enable_slice_boxing()
...
@@ -99,7 +100,7 @@ enable_slice_boxing()
@numba_funcify.register
(
MakeSlice
)
@numba_funcify.register
(
MakeSlice
)
def
numba_funcify_MakeSlice
(
op
,
**
kwargs
):
def
numba_funcify_MakeSlice
(
op
,
**
kwargs
):
@numba_njit
@numba_
basic.numba_
njit
def
makeslice
(
*
x
):
def
makeslice
(
*
x
):
return
slice
(
*
x
)
return
slice
(
*
x
)
...
@@ -297,7 +298,7 @@ def numba_funcify_multiple_integer_vector_indexing(
...
@@ -297,7 +298,7 @@ def numba_funcify_multiple_integer_vector_indexing(
if
isinstance
(
op
,
AdvancedSubtensor
):
if
isinstance
(
op
,
AdvancedSubtensor
):
@numba_njit
@numba_
basic.numba_
njit
def
advanced_subtensor_multiple_vector
(
x
,
*
idxs
):
def
advanced_subtensor_multiple_vector
(
x
,
*
idxs
):
none_slices
=
idxs
[:
first_axis
]
none_slices
=
idxs
[:
first_axis
]
vec_idxs
=
idxs
[
first_axis
:
after_last_axis
]
vec_idxs
=
idxs
[
first_axis
:
after_last_axis
]
...
@@ -328,7 +329,7 @@ def numba_funcify_multiple_integer_vector_indexing(
...
@@ -328,7 +329,7 @@ def numba_funcify_multiple_integer_vector_indexing(
if
op
.
set_instead_of_inc
:
if
op
.
set_instead_of_inc
:
@numba_njit
@numba_
basic.numba_
njit
def
advanced_set_subtensor_multiple_vector
(
x
,
y
,
*
idxs
):
def
advanced_set_subtensor_multiple_vector
(
x
,
y
,
*
idxs
):
vec_idxs
=
idxs
[
first_axis
:
after_last_axis
]
vec_idxs
=
idxs
[
first_axis
:
after_last_axis
]
x_shape
=
x
.
shape
x_shape
=
x
.
shape
...
@@ -350,7 +351,7 @@ def numba_funcify_multiple_integer_vector_indexing(
...
@@ -350,7 +351,7 @@ def numba_funcify_multiple_integer_vector_indexing(
else
:
else
:
@numba_njit
@numba_
basic.numba_
njit
def
advanced_inc_subtensor_multiple_vector
(
x
,
y
,
*
idxs
):
def
advanced_inc_subtensor_multiple_vector
(
x
,
y
,
*
idxs
):
vec_idxs
=
idxs
[
first_axis
:
after_last_axis
]
vec_idxs
=
idxs
[
first_axis
:
after_last_axis
]
x_shape
=
x
.
shape
x_shape
=
x
.
shape
...
@@ -382,7 +383,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
...
@@ -382,7 +383,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
if
set_instead_of_inc
:
if
set_instead_of_inc
:
if
broadcast_with_index
:
if
broadcast_with_index
:
@numba_njit
(
boundscheck
=
True
)
@numba_
basic.numba_
njit
(
boundscheck
=
True
)
def
advancedincsubtensor1_inplace
(
x
,
val
,
idxs
):
def
advancedincsubtensor1_inplace
(
x
,
val
,
idxs
):
if
val
.
ndim
==
x
.
ndim
:
if
val
.
ndim
==
x
.
ndim
:
core_val
=
val
[
0
]
core_val
=
val
[
0
]
...
@@ -398,7 +399,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
...
@@ -398,7 +399,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
else
:
else
:
@numba_njit
(
boundscheck
=
True
)
@numba_
basic.numba_
njit
(
boundscheck
=
True
)
def
advancedincsubtensor1_inplace
(
x
,
vals
,
idxs
):
def
advancedincsubtensor1_inplace
(
x
,
vals
,
idxs
):
if
not
len
(
idxs
)
==
len
(
vals
):
if
not
len
(
idxs
)
==
len
(
vals
):
raise
ValueError
(
"The number of indices and values must match."
)
raise
ValueError
(
"The number of indices and values must match."
)
...
@@ -409,7 +410,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
...
@@ -409,7 +410,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
else
:
else
:
if
broadcast_with_index
:
if
broadcast_with_index
:
@numba_njit
(
boundscheck
=
True
)
@numba_
basic.numba_
njit
(
boundscheck
=
True
)
def
advancedincsubtensor1_inplace
(
x
,
val
,
idxs
):
def
advancedincsubtensor1_inplace
(
x
,
val
,
idxs
):
if
val
.
ndim
==
x
.
ndim
:
if
val
.
ndim
==
x
.
ndim
:
core_val
=
val
[
0
]
core_val
=
val
[
0
]
...
@@ -425,7 +426,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
...
@@ -425,7 +426,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
else
:
else
:
@numba_njit
(
boundscheck
=
True
)
@numba_
basic.numba_
njit
(
boundscheck
=
True
)
def
advancedincsubtensor1_inplace
(
x
,
vals
,
idxs
):
def
advancedincsubtensor1_inplace
(
x
,
vals
,
idxs
):
if
not
len
(
idxs
)
==
len
(
vals
):
if
not
len
(
idxs
)
==
len
(
vals
):
raise
ValueError
(
"The number of indices and values must match."
)
raise
ValueError
(
"The number of indices and values must match."
)
...
@@ -440,7 +441,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
...
@@ -440,7 +441,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
else
:
else
:
@numba_njit
@numba_
basic.numba_
njit
def
advancedincsubtensor1
(
x
,
vals
,
idxs
):
def
advancedincsubtensor1
(
x
,
vals
,
idxs
):
x
=
x
.
copy
()
x
=
x
.
copy
()
return
advancedincsubtensor1_inplace
(
x
,
vals
,
idxs
)
return
advancedincsubtensor1_inplace
(
x
,
vals
,
idxs
)
...
...
pytensor/link/numba/dispatch/tensor_basic.py
浏览文件 @
5fbf81df
...
@@ -6,7 +6,6 @@ from pytensor.link.numba.dispatch import basic as numba_basic
...
@@ -6,7 +6,6 @@ from pytensor.link.numba.dispatch import basic as numba_basic
from
pytensor.link.numba.dispatch.basic
import
(
from
pytensor.link.numba.dispatch.basic
import
(
create_tuple_string
,
create_tuple_string
,
numba_funcify
,
numba_funcify
,
numba_njit
,
)
)
from
pytensor.link.utils
import
compile_function_src
,
unique_name_generator
from
pytensor.link.utils
import
compile_function_src
,
unique_name_generator
from
pytensor.tensor.basic
import
(
from
pytensor.tensor.basic
import
(
...
@@ -243,7 +242,7 @@ def numba_funcify_ScalarFromTensor(op, **kwargs):
...
@@ -243,7 +242,7 @@ def numba_funcify_ScalarFromTensor(op, **kwargs):
@numba_funcify.register
(
Nonzero
)
@numba_funcify.register
(
Nonzero
)
def
numba_funcify_Nonzero
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Nonzero
(
op
,
node
,
**
kwargs
):
@numba_njit
@numba_
basic.numba_
njit
def
nonzero
(
a
):
def
nonzero
(
a
):
result_tuple
=
np
.
nonzero
(
a
)
result_tuple
=
np
.
nonzero
(
a
)
if
a
.
ndim
==
1
:
if
a
.
ndim
==
1
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论