Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ffd999c8
提交
ffd999c8
authored
12月 01, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
12月 08, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Numba linalg: handle dtypes more strictly
上级
edb1b205
显示空白字符变更
内嵌
并排
正在显示
13 个修改的文件
包含
284 行增加
和
165 行删除
+284
-165
cholesky.py
...nsor/link/numba/dispatch/linalg/decomposition/cholesky.py
+5
-5
lu.py
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
+5
-4
lu_factor.py
...sor/link/numba/dispatch/linalg/decomposition/lu_factor.py
+5
-8
cholesky.py
pytensor/link/numba/dispatch/linalg/solve/cholesky.py
+8
-7
general.py
pytensor/link/numba/dispatch/linalg/solve/general.py
+11
-10
lu_solve.py
pytensor/link/numba/dispatch/linalg/solve/lu_solve.py
+13
-9
norm.py
pytensor/link/numba/dispatch/linalg/solve/norm.py
+4
-7
posdef.py
pytensor/link/numba/dispatch/linalg/solve/posdef.py
+16
-15
symmetric.py
pytensor/link/numba/dispatch/linalg/solve/symmetric.py
+20
-19
triangular.py
pytensor/link/numba/dispatch/linalg/solve/triangular.py
+8
-7
tridiagonal.py
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
+78
-35
utils.py
pytensor/link/numba/dispatch/linalg/utils.py
+24
-14
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+87
-25
没有找到文件。
pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py
浏览文件 @
ffd999c8
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
numba.types
import
Float
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_
scipy_
linalg_matrix
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_linalg_matrix
def
_cholesky
(
a
,
lower
=
False
,
overwrite_a
=
False
,
check_finite
=
True
):
...
...
@@ -24,9 +24,9 @@ def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
@overload
(
_cholesky
)
def
cholesky_impl
(
A
,
lower
=
0
,
overwrite_a
=
False
,
check_finite
=
True
):
ensure_lapack
()
_check_
scipy_linalg_matrix
(
A
,
"cholesky"
)
_check_
linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"cholesky"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_potrf
=
_LAPACK
()
.
numba_xpotrf
(
dtype
)
def
impl
(
A
,
lower
=
0
,
overwrite_a
=
False
,
check_finite
=
True
):
...
...
@@ -47,7 +47,7 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
numba_potrf
(
UPLO
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
A_copy
.
ctypes
,
LDA
,
INFO
,
)
...
...
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
浏览文件 @
ffd999c8
...
...
@@ -3,12 +3,13 @@ from typing import Literal
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
from
numba.np.linalg
import
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.linalg.decomposition.lu_factor
import
_getrf
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_
scipy_
linalg_matrix
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_linalg_matrix
@numba_basic.numba_njit
...
...
@@ -116,7 +117,7 @@ def lu_impl_1(
False. Returns a tuple of (perm, L, U), where perm an integer array of row swaps, such that L[perm] @ U = A.
"""
ensure_lapack
()
_check_
scipy_linalg_matrix
(
a
,
"lu"
)
_check_
linalg_matrix
(
a
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"lu"
)
dtype
=
a
.
dtype
def
impl
(
...
...
@@ -146,7 +147,7 @@ def lu_impl_2(
"""
ensure_lapack
()
_check_
scipy_linalg_matrix
(
a
,
"lu"
)
_check_
linalg_matrix
(
a
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"lu"
)
dtype
=
a
.
dtype
def
impl
(
...
...
@@ -179,7 +180,7 @@ def lu_impl_3(
False. Returns a tuple of (P, L, U), such that P @ L @ U = A.
"""
ensure_lapack
()
_check_
scipy_linalg_matrix
(
a
,
"lu"
)
_check_
linalg_matrix
(
a
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"lu"
)
dtype
=
a
.
dtype
def
impl
(
...
...
pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py
浏览文件 @
ffd999c8
...
...
@@ -3,18 +3,16 @@ from typing import cast as typing_cast
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_scipy_linalg_matrix
,
)
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_linalg_matrix
def
_getrf
(
A
,
overwrite_a
=
False
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
...
...
@@ -38,9 +36,8 @@ def getrf_impl(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
)
->
Callable
[[
np
.
ndarray
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_
scipy_linalg_matrix
(
A
,
"getrf"
)
_check_
linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"getrf"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_getrf
=
_LAPACK
()
.
numba_xgetrf
(
dtype
)
def
impl
(
...
...
@@ -59,7 +56,7 @@ def getrf_impl(
IPIV
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
# type: ignore
INFO
=
val_to_int_ptr
(
0
)
numba_getrf
(
M
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
INFO
)
numba_getrf
(
M
,
N
,
A_copy
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
INFO
)
return
A_copy
,
IPIV
,
int_ptr_to_val
(
INFO
)
...
...
@@ -79,7 +76,7 @@ def lu_factor_impl(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
)
->
Callable
[[
np
.
ndarray
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
]]:
ensure_lapack
()
_check_
scipy_linalg_matrix
(
A
,
"lu_factor"
)
_check_
linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"lu_factor"
)
def
impl
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
A_copy
,
IPIV
,
INFO
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
...
...
pytensor/link/numba/dispatch/linalg/solve/cholesky.py
浏览文件 @
ffd999c8
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
from
numba.np.linalg
import
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_scipy_linalg_matrix
,
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
...
...
@@ -31,10 +32,10 @@ def _cho_solve(
@overload
(
_cho_solve
)
def
cho_solve_impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
,
check_finite
=
True
):
ensure_lapack
()
_check_scipy_linalg_matrix
(
C
,
"cho_solve"
)
_check_scipy_linalg_matrix
(
B
,
"cho_solve"
)
_check_linalg_matrix
(
C
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"cho_solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"cho_solve"
)
_check_dtypes_match
((
C
,
B
),
func_name
=
"cho_solve"
)
dtype
=
C
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_potrs
=
_LAPACK
()
.
numba_xpotrs
(
dtype
)
def
impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
,
check_finite
=
True
):
...
...
@@ -71,9 +72,9 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
UPLO
,
N
,
NRHS
,
C_f
.
view
(
w_type
)
.
ctypes
,
C_f
.
ctypes
,
LDA
,
B_copy
.
view
(
w_type
)
.
ctypes
,
B_copy
.
ctypes
,
LDB
,
INFO
,
)
...
...
pytensor/link/numba/dispatch/linalg/solve/general.py
浏览文件 @
ffd999c8
...
...
@@ -2,12 +2,12 @@ from collections.abc import Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
...
...
@@ -16,7 +16,8 @@ from pytensor.link.numba.dispatch.linalg.solve.lu_solve import _getrs
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_scipy_linalg_matrix
,
_check_dtypes_match
,
_check_linalg_matrix
,
_solve_check
,
)
...
...
@@ -37,9 +38,8 @@ def xgecon_impl(
Compute the condition number of a matrix A.
"""
ensure_lapack
()
_check_
scipy_linalg_matrix
(
A
,
"gecon"
)
_check_
linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"gecon"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_gecon
=
_LAPACK
()
.
numba_xgecon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
tuple
[
np
.
ndarray
,
int
]:
...
...
@@ -58,11 +58,11 @@ def xgecon_impl(
numba_gecon
(
NORM
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
A_copy
.
ctypes
,
LDA
,
A_NORM
.
view
(
w_type
)
.
ctypes
,
RCOND
.
view
(
w_type
)
.
ctypes
,
WORK
.
view
(
w_type
)
.
ctypes
,
A_NORM
.
ctypes
,
RCOND
.
ctypes
,
WORK
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
...
...
@@ -106,8 +106,9 @@ def solve_gen_impl(
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve"
)
_check_scipy_linalg_matrix
(
B
,
"solve"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
_check_dtypes_match
((
A
,
B
),
"solve"
)
def
impl
(
A
:
np
.
ndarray
,
...
...
pytensor/link/numba/dispatch/linalg/solve/lu_solve.py
浏览文件 @
ffd999c8
...
...
@@ -3,18 +3,19 @@ from typing import Literal, TypeAlias
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
,
int32
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_scipy_linalg_matrix
,
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
...
...
@@ -44,10 +45,11 @@ def getrs_impl(
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
_Trans
,
bool
],
tuple
[
np
.
ndarray
,
int
]
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
LU
,
"getrs"
)
_check_scipy_linalg_matrix
(
B
,
"getrs"
)
_check_linalg_matrix
(
LU
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"getrs"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"getrs"
)
_check_dtypes_match
((
LU
,
B
),
func_name
=
"getrs"
)
_check_linalg_matrix
(
IPIV
,
ndim
=
1
,
dtype
=
int32
,
func_name
=
"getrs"
)
dtype
=
LU
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_getrs
=
_LAPACK
()
.
numba_xgetrs
(
dtype
)
def
impl
(
...
...
@@ -84,10 +86,10 @@ def getrs_impl(
TRANS
,
N
,
NRHS
,
LU
.
view
(
w_type
)
.
ctypes
,
LU
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
B_copy
.
view
(
w_type
)
.
ctypes
,
B_copy
.
ctypes
,
LDB
,
INFO
,
)
...
...
@@ -124,8 +126,10 @@ def lu_solve_impl(
check_finite
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
_Trans
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
lu_and_piv
[
0
],
"lu_solve"
)
_check_scipy_linalg_matrix
(
b
,
"lu_solve"
)
lu
,
_piv
=
lu_and_piv
_check_linalg_matrix
(
lu
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"lu_solve"
)
_check_linalg_matrix
(
b
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"lu_solve"
)
_check_dtypes_match
((
lu
,
b
),
func_name
=
"lu_solve"
)
def
impl
(
lu
:
np
.
ndarray
,
...
...
pytensor/link/numba/dispatch/linalg/solve/norm.py
浏览文件 @
ffd999c8
...
...
@@ -2,14 +2,14 @@ from collections.abc import Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_
scipy_
linalg_matrix
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_linalg_matrix
def
_xlange
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
)
->
float
:
...
...
@@ -28,9 +28,8 @@ def xlange_impl(
largest absolute value of a matrix A.
"""
ensure_lapack
()
_check_
scipy_linalg_matrix
(
A
,
"norm"
)
_check_
linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"norm"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_lange
=
_LAPACK
()
.
numba_xlange
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
):
...
...
@@ -49,9 +48,7 @@ def xlange_impl(
)
WORK
=
np
.
empty
(
_M
,
dtype
=
dtype
)
# type: ignore
result
=
numba_lange
(
NORM
,
M
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
WORK
.
view
(
w_type
)
.
ctypes
)
result
=
numba_lange
(
NORM
,
M
,
N
,
A_copy
.
ctypes
,
LDA
,
WORK
.
ctypes
)
return
result
...
...
pytensor/link/numba/dispatch/linalg/solve/posdef.py
浏览文件 @
ffd999c8
...
...
@@ -2,19 +2,20 @@ from collections.abc import Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_scipy_linalg_matrix
,
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
...
...
@@ -49,10 +50,10 @@ def posv_impl(
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
],
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve"
)
_check_scipy_linalg_matrix
(
B
,
"solve"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
_check_dtypes_match
((
A
,
B
),
func_name
=
"solve"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_posv
=
_LAPACK
()
.
numba_xposv
(
dtype
)
def
impl
(
...
...
@@ -99,9 +100,9 @@ def posv_impl(
UPLO
,
N
,
NRHS
,
A_copy
.
view
(
w_type
)
.
ctypes
,
A_copy
.
ctypes
,
LDA
,
B_copy
.
view
(
w_type
)
.
ctypes
,
B_copy
.
ctypes
,
LDB
,
INFO
,
)
...
...
@@ -127,9 +128,8 @@ def pocon_impl(
A
:
np
.
ndarray
,
anorm
:
float
)
->
Callable
[[
np
.
ndarray
,
float
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_
scipy_linalg_matrix
(
A
,
"pocon"
)
_check_
linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"pocon"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_pocon
=
_LAPACK
()
.
numba_xpocon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
anorm
:
float
):
...
...
@@ -148,11 +148,11 @@ def pocon_impl(
numba_pocon
(
UPLO
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
A_copy
.
ctypes
,
LDA
,
ANORM
.
view
(
w_type
)
.
ctypes
,
RCOND
.
view
(
w_type
)
.
ctypes
,
WORK
.
view
(
w_type
)
.
ctypes
,
ANORM
.
ctypes
,
RCOND
.
ctypes
,
WORK
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
...
...
@@ -196,8 +196,9 @@ def solve_psd_impl(
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve"
)
_check_scipy_linalg_matrix
(
B
,
"solve"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
_check_dtypes_match
((
A
,
B
),
func_name
=
"solve"
)
def
impl
(
A
:
np
.
ndarray
,
...
...
pytensor/link/numba/dispatch/linalg/solve/symmetric.py
浏览文件 @
ffd999c8
...
...
@@ -2,19 +2,20 @@ from collections.abc import Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_scipy_linalg_matrix
,
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
...
...
@@ -37,10 +38,10 @@ def sysv_impl(
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
int
],
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"sysv"
)
_check_scipy_linalg_matrix
(
B
,
"sysv"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"sysv"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"sysv"
)
_check_dtypes_match
((
A
,
B
),
func_name
=
"sysv"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_sysv
=
_LAPACK
()
.
numba_xsysv
(
dtype
)
def
impl
(
...
...
@@ -84,12 +85,12 @@ def sysv_impl(
UPLO
,
N
,
NRHS
,
A_copy
.
view
(
w_type
)
.
ctypes
,
A_copy
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
B_copy
.
view
(
w_type
)
.
ctypes
,
B_copy
.
ctypes
,
LDB
,
WORK
.
view
(
w_type
)
.
ctypes
,
WORK
.
ctypes
,
LWORK
,
INFO
,
)
...
...
@@ -103,12 +104,12 @@ def sysv_impl(
UPLO
,
N
,
NRHS
,
A_copy
.
view
(
w_type
)
.
ctypes
,
A_copy
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
B_copy
.
view
(
w_type
)
.
ctypes
,
B_copy
.
ctypes
,
LDB
,
WORK
.
view
(
w_type
)
.
ctypes
,
WORK
.
ctypes
,
LWORK
,
INFO
,
)
...
...
@@ -133,9 +134,8 @@ def sycon_impl(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
float
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_
scipy_linalg_matrix
(
A
,
"sycon"
)
_check_
linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"sycon"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_sycon
=
_LAPACK
()
.
numba_xsycon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
...
...
@@ -154,12 +154,12 @@ def sycon_impl(
numba_sycon
(
UPLO
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
A_copy
.
ctypes
,
LDA
,
ipiv
.
ctypes
,
ANORM
.
view
(
w_type
)
.
ctypes
,
RCOND
.
view
(
w_type
)
.
ctypes
,
WORK
.
view
(
w_type
)
.
ctypes
,
ANORM
.
ctypes
,
RCOND
.
ctypes
,
WORK
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
...
...
@@ -203,8 +203,9 @@ def solve_symmetric_impl(
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve"
)
_check_scipy_linalg_matrix
(
B
,
"solve"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
_check_dtypes_match
((
A
,
B
),
func_name
=
"solve"
)
def
impl
(
A
:
np
.
ndarray
,
...
...
pytensor/link/numba/dispatch/linalg/solve/triangular.py
浏览文件 @
ffd999c8
import
numpy
as
np
from
numba.core
import
types
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
from
numba.np.linalg
import
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_scipy_linalg_matrix
,
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
...
...
@@ -45,10 +46,10 @@ def _solve_triangular(
def
solve_triangular_impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
b_ndim
,
overwrite_b
):
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve_triangular"
)
_check_scipy_linalg_matrix
(
B
,
"solve_triangular"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve_triangular"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve_triangular"
)
_check_dtypes_match
((
A
,
B
),
func_name
=
"solve_triangular"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_trtrs
=
_LAPACK
()
.
numba_xtrtrs
(
dtype
)
if
isinstance
(
dtype
,
types
.
Complex
):
# If you want to make this work with complex numbers make sure you handle the c_contiguous trick correctly
...
...
@@ -99,9 +100,9 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
DIAG
,
N
,
NRHS
,
A_f
.
view
(
w_type
)
.
ctypes
,
A_f
.
ctypes
,
LDA
,
B_copy
.
view
(
w_type
)
.
ctypes
,
B_copy
.
ctypes
,
LDB
,
INFO
,
)
...
...
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
浏览文件 @
ffd999c8
...
...
@@ -2,21 +2,24 @@ from collections.abc import Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
,
int32
from
numba.np.linalg
import
ensure_lapack
from
numpy
import
ndarray
from
scipy
import
linalg
from
pytensor
import
config
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
generate_fallback_impl
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_scipy_linalg_matrix
,
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
...
...
@@ -63,11 +66,11 @@ def gttrf_impl(
tuple
[
ndarray
,
ndarray
,
ndarray
,
ndarray
,
ndarray
,
int
],
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
dl
,
"gttrf"
)
_check_scipy_linalg_matrix
(
d
,
"gttrf"
)
_check_scipy_linalg_matrix
(
du
,
"gttrf"
)
_check_linalg_matrix
(
dl
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gttrf"
)
_check_linalg_matrix
(
d
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gttrf"
)
_check_linalg_matrix
(
du
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gttrf"
)
_check_dtypes_match
((
dl
,
d
,
du
),
func_name
=
"gttrf"
)
dtype
=
d
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_gttrf
=
_LAPACK
()
.
numba_xgttrf
(
dtype
)
def
impl
(
...
...
@@ -94,10 +97,10 @@ def gttrf_impl(
numba_gttrf
(
val_to_int_ptr
(
n
),
dl
.
view
(
w_type
)
.
ctypes
,
d
.
view
(
w_type
)
.
ctypes
,
du
.
view
(
w_type
)
.
ctypes
,
du2
.
view
(
w_type
)
.
ctypes
,
dl
.
ctypes
,
d
.
ctypes
,
du
.
ctypes
,
du2
.
ctypes
,
ipiv
.
ctypes
,
info
,
)
...
...
@@ -136,13 +139,14 @@ def gttrs_impl(
tuple
[
ndarray
,
int
],
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
dl
,
"gttrs"
)
_check_scipy_linalg_matrix
(
d
,
"gttrs"
)
_check_scipy_linalg_matrix
(
du
,
"gttrs"
)
_check_scipy_linalg_matrix
(
du2
,
"gttrs"
)
_check_scipy_linalg_matrix
(
b
,
"gttrs"
)
_check_linalg_matrix
(
dl
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gttrs"
)
_check_linalg_matrix
(
d
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gttrs"
)
_check_linalg_matrix
(
du
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gttrs"
)
_check_linalg_matrix
(
du2
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gttrs"
)
_check_linalg_matrix
(
b
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"gttrs"
)
_check_dtypes_match
((
dl
,
d
,
du
,
du2
,
b
),
func_name
=
"gttrs"
)
_check_linalg_matrix
(
ipiv
,
ndim
=
1
,
dtype
=
int32
,
func_name
=
"gttrs"
)
dtype
=
d
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_gttrs
=
_LAPACK
()
.
numba_xgttrs
(
dtype
)
def
impl
(
...
...
@@ -181,12 +185,12 @@ def gttrs_impl(
val_to_int_ptr
(
_trans_char_to_int
(
trans
)),
val_to_int_ptr
(
n
),
val_to_int_ptr
(
nrhs
),
dl
.
view
(
w_type
)
.
ctypes
,
d
.
view
(
w_type
)
.
ctypes
,
du
.
view
(
w_type
)
.
ctypes
,
du2
.
view
(
w_type
)
.
ctypes
,
dl
.
ctypes
,
d
.
ctypes
,
du
.
ctypes
,
du2
.
ctypes
,
ipiv
.
ctypes
,
b
.
view
(
w_type
)
.
ctypes
,
b
.
ctypes
,
val_to_int_ptr
(
n
),
info
,
)
...
...
@@ -222,12 +226,13 @@ def gtcon_impl(
[
ndarray
,
ndarray
,
ndarray
,
ndarray
,
ndarray
,
float
,
str
],
tuple
[
ndarray
,
int
]
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
dl
,
"gtcon"
)
_check_scipy_linalg_matrix
(
d
,
"gtcon"
)
_check_scipy_linalg_matrix
(
du
,
"gtcon"
)
_check_scipy_linalg_matrix
(
du2
,
"gtcon"
)
_check_linalg_matrix
(
dl
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_linalg_matrix
(
d
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_linalg_matrix
(
du
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_linalg_matrix
(
du2
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_dtypes_match
((
dl
,
d
,
du
,
du2
),
func_name
=
"gtcon"
)
_check_linalg_matrix
(
ipiv
,
ndim
=
1
,
dtype
=
int32
,
func_name
=
"gtcon"
)
dtype
=
d
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_gtcon
=
_LAPACK
()
.
numba_xgtcon
(
dtype
)
def
impl
(
...
...
@@ -248,14 +253,14 @@ def gtcon_impl(
numba_gtcon
(
val_to_int_ptr
(
ord
(
norm
)),
val_to_int_ptr
(
n
),
dl
.
view
(
w_type
)
.
ctypes
,
d
.
view
(
w_type
)
.
ctypes
,
du
.
view
(
w_type
)
.
ctypes
,
du2
.
view
(
w_type
)
.
ctypes
,
dl
.
ctypes
,
d
.
ctypes
,
du
.
ctypes
,
du2
.
ctypes
,
ipiv
.
ctypes
,
np
.
array
(
anorm
,
dtype
=
dtype
)
.
view
(
w_type
)
.
ctypes
,
rcond
.
view
(
w_type
)
.
ctypes
,
work
.
view
(
w_type
)
.
ctypes
,
np
.
array
(
anorm
,
dtype
=
dtype
)
.
ctypes
,
rcond
.
ctypes
,
work
.
ctypes
,
iwork
.
ctypes
,
info
,
)
...
...
@@ -300,8 +305,9 @@ def _tridiagonal_solve_impl(
transposed
:
bool
,
)
->
Callable
[[
ndarray
,
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
ndarray
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve"
)
_check_scipy_linalg_matrix
(
B
,
"solve"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
_check_dtypes_match
((
A
,
B
),
func_name
=
"solve"
)
def
impl
(
A
:
ndarray
,
...
...
@@ -342,12 +348,26 @@ def _tridiagonal_solve_impl(
@numba_funcify.register
(
LUFactorTridiagonal
)
def
numba_funcify_LUFactorTridiagonal
(
op
:
LUFactorTridiagonal
,
node
,
**
kwargs
):
if
any
(
i
.
type
.
numpy_dtype
.
kind
==
"c"
for
i
in
node
.
inputs
):
return
generate_fallback_impl
(
op
,
node
=
node
)
overwrite_dl
=
op
.
overwrite_dl
overwrite_d
=
op
.
overwrite_d
overwrite_du
=
op
.
overwrite_du
out_dtype
=
node
.
outputs
[
1
]
.
type
.
numpy_dtype
must_cast_inputs
=
tuple
(
inp
.
type
.
numpy_dtype
!=
out_dtype
for
inp
in
node
.
inputs
)
if
any
(
must_cast_inputs
)
and
config
.
compiler_verbose
:
print
(
"LUFactorTridiagonal requires casting at least one input"
)
# noqa: T201
@numba_basic.numba_njit
(
cache
=
False
)
def
lu_factor_tridiagonal
(
dl
,
d
,
du
):
if
must_cast_inputs
[
0
]:
d
=
d
.
astype
(
out_dtype
)
if
must_cast_inputs
[
1
]:
dl
=
dl
.
astype
(
out_dtype
)
if
must_cast_inputs
[
2
]:
du
=
du
.
astype
(
out_dtype
)
dl
,
d
,
du
,
du2
,
ipiv
,
_
=
_gttrf
(
dl
,
d
,
...
...
@@ -365,11 +385,34 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
def
numba_funcify_SolveLUFactorTridiagonal
(
op
:
SolveLUFactorTridiagonal
,
node
,
**
kwargs
):
if
any
(
i
.
type
.
numpy_dtype
.
kind
==
"c"
for
i
in
node
.
inputs
):
return
generate_fallback_impl
(
op
,
node
=
node
)
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
overwrite_b
=
op
.
overwrite_b
transposed
=
op
.
transposed
must_cast_inputs
=
tuple
(
inp
.
type
.
numpy_dtype
!=
(
np
.
int32
if
i
==
4
else
out_dtype
)
for
i
,
inp
in
enumerate
(
node
.
inputs
)
)
if
any
(
must_cast_inputs
)
and
config
.
compiler_verbose
:
print
(
"SolveLUFactorTridiagonal requires casting at least one input"
)
# noqa: T201
@numba_basic.numba_njit
(
cache
=
False
)
def
solve_lu_factor_tridiagonal
(
dl
,
d
,
du
,
du2
,
ipiv
,
b
):
if
must_cast_inputs
[
0
]:
dl
=
dl
.
astype
(
out_dtype
)
if
must_cast_inputs
[
1
]:
d
=
d
.
astype
(
out_dtype
)
if
must_cast_inputs
[
2
]:
du
=
du
.
astype
(
out_dtype
)
if
must_cast_inputs
[
3
]:
du2
=
du2
.
astype
(
out_dtype
)
if
must_cast_inputs
[
4
]:
ipiv
=
ipiv
.
astype
(
"int32"
)
if
must_cast_inputs
[
5
]:
b
=
b
.
astype
(
out_dtype
)
x
,
_
=
_gttrs
(
dl
,
d
,
...
...
pytensor/link/numba/dispatch/linalg/utils.py
浏览文件 @
ffd999c8
from
collections.abc
import
Callable
from
collections.abc
import
Callable
,
Sequence
import
numba
from
numba.core
import
types
...
...
@@ -32,24 +32,34 @@ def _trans_char_to_int(trans):
return
ord
(
"C"
)
def
_check_
scipy_linalg_matrix
(
a
,
func_name
):
def
_check_
linalg_matrix
(
a
,
*
,
ndim
:
int
|
Sequence
[
int
],
dtype
,
func_name
):
"""
Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
"""
prefix
=
"scipy.linalg"
# Unpack optional type
if
isinstance
(
a
,
types
.
Optional
):
a
=
a
.
type
if
not
isinstance
(
a
,
types
.
Array
):
msg
=
f
"{
prefix}.{func_name}()
only supported for array types"
msg
=
f
"{
func_name}
only supported for array types"
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
if
a
.
ndim
not
in
[
1
,
2
]:
msg
=
(
f
"{prefix}.{func_name}() only supported on 1d or 2d arrays, found {a.ndim}."
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
if
not
isinstance
(
a
.
dtype
,
types
.
Float
|
types
.
Complex
):
msg
=
f
"{prefix}.{func_name}() only supported on float and complex arrays."
ndim_msg
=
f
"{func_name} only supported on {ndim}d arrays, got {a.ndim}."
if
isinstance
(
ndim
,
int
):
if
a
.
ndim
!=
ndim
:
raise
numba
.
TypingError
(
ndim_msg
,
highlighting
=
False
)
elif
a
.
ndim
not
in
ndim
:
raise
numba
.
TypingError
(
ndim_msg
,
highlighting
=
False
)
dtype_msg
=
f
"{func_name} only supported for {dtype}, got {a.dtype}."
if
isinstance
(
dtype
,
type
|
tuple
):
if
not
isinstance
(
a
.
dtype
,
dtype
):
raise
numba
.
TypingError
(
dtype_msg
,
highlighting
=
False
)
elif
a
.
dtype
!=
dtype
:
raise
numba
.
TypingError
(
dtype_msg
,
highlighting
=
False
)
def
_check_dtypes_match
(
arrays
:
Sequence
,
func_name
=
"cho_solve"
):
dtypes
=
[
a
.
dtype
for
a
in
arrays
]
first_dtype
=
dtypes
[
0
]
for
other_dtype
in
dtypes
[
1
:]:
if
first_dtype
!=
other_dtype
:
msg
=
f
"{func_name} only supported for matching dtypes, got {dtypes}"
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
...
...
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
ffd999c8
...
...
@@ -63,13 +63,20 @@ def numba_funcify_Cholesky(op, node, **kwargs):
check_finite
=
op
.
check_finite
on_error
=
op
.
on_error
dtype
=
node
.
inputs
[
0
]
.
dtype
if
dtype
in
complex_dtypes
:
inp_dtype
=
node
.
inputs
[
0
]
.
type
.
numpy_
dtype
if
inp_dtype
.
kind
==
"c"
:
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
discrete_inp
=
inp_dtype
.
kind
in
"ibu"
if
discrete_inp
and
config
.
compiler_verbose
:
print
(
"Cholesky requires casting discrete input to float"
)
# noqa: T201
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
@numba_basic.numba_njit
def
cholesky
(
a
):
if
check_finite
:
if
discrete_inp
:
a
=
a
.
astype
(
out_dtype
)
elif
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) found in input to cholesky"
...
...
@@ -112,18 +119,24 @@ def pivot_to_permutation(op, node, **kwargs):
@numba_funcify.register
(
LU
)
def
numba_funcify_LU
(
op
,
node
,
**
kwargs
):
inp_dtype
=
node
.
inputs
[
0
]
.
type
.
numpy_dtype
if
inp_dtype
.
kind
==
"c"
:
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
discrete_inp
=
inp_dtype
.
kind
in
"ibu"
if
discrete_inp
and
config
.
compiler_verbose
:
print
(
"LU requires casting discrete input to float"
)
# noqa: T201
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
permute_l
=
op
.
permute_l
check_finite
=
op
.
check_finite
p_indices
=
op
.
p_indices
overwrite_a
=
op
.
overwrite_a
dtype
=
node
.
inputs
[
0
]
.
dtype
if
dtype
in
complex_dtypes
:
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
@numba_basic.numba_njit
def
lu
(
a
):
if
check_finite
:
if
discrete_inp
:
a
=
a
.
astype
(
out_dtype
)
elif
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) found in input to lu"
...
...
@@ -161,16 +174,22 @@ def numba_funcify_LU(op, node, **kwargs):
@numba_funcify.register
(
LUFactor
)
def
numba_funcify_LUFactor
(
op
,
node
,
**
kwargs
):
dtype
=
node
.
inputs
[
0
]
.
dtype
inp_dtype
=
node
.
inputs
[
0
]
.
type
.
numpy_dtype
if
inp_dtype
.
kind
==
"c"
:
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
discrete_inp
=
inp_dtype
.
kind
in
"ibu"
if
discrete_inp
and
config
.
compiler_verbose
:
print
(
"LUFactor requires casting discrete input to float"
)
# noqa: T201
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
check_finite
=
op
.
check_finite
overwrite_a
=
op
.
overwrite_a
if
dtype
in
complex_dtypes
:
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
@numba_basic.numba_njit
def
lu_factor
(
a
):
if
check_finite
:
if
discrete_inp
:
a
=
a
.
astype
(
out_dtype
)
elif
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) found in input to cholesky"
...
...
@@ -207,6 +226,21 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
@numba_funcify.register
(
Solve
)
def
numba_funcify_Solve
(
op
,
node
,
**
kwargs
):
A_dtype
,
b_dtype
=
(
i
.
numpy_dtype
for
i
in
node
.
inputs
)
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
if
A_dtype
.
kind
==
"c"
or
b_dtype
.
kind
==
"c"
:
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
must_cast_A
=
A_dtype
!=
out_dtype
if
must_cast_A
and
config
.
compiler_verbose
:
print
(
"Solve requires casting first input `A`"
)
# noqa: T201
must_cast_B
=
b_dtype
!=
out_dtype
if
must_cast_B
and
config
.
compiler_verbose
:
print
(
"Solve requires casting second input `b`"
)
# noqa: T201
check_finite
=
op
.
check_finite
overwrite_a
=
op
.
overwrite_a
assume_a
=
op
.
assume_a
lower
=
op
.
lower
check_finite
=
op
.
check_finite
...
...
@@ -214,10 +248,6 @@ def numba_funcify_Solve(op, node, **kwargs):
overwrite_b
=
op
.
overwrite_b
transposed
=
False
# TODO: Solve doesnt currently allow the transposed argument
dtype
=
node
.
inputs
[
0
]
.
dtype
if
dtype
in
complex_dtypes
:
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
if
assume_a
==
"gen"
:
solve_fn
=
_solve_gen
elif
assume_a
==
"sym"
:
...
...
@@ -239,6 +269,10 @@ def numba_funcify_Solve(op, node, **kwargs):
@numba_basic.numba_njit
def
solve
(
a
,
b
):
if
must_cast_A
:
a
=
a
.
astype
(
out_dtype
)
if
must_cast_B
:
b
=
b
.
astype
(
out_dtype
)
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
...
...
@@ -263,14 +297,24 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
overwrite_b
=
op
.
overwrite_b
b_ndim
=
op
.
b_ndim
dtype
=
node
.
inputs
[
0
]
.
dtype
if
dtype
in
complex_dtypes
:
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
"Solve Triangular"
)
)
A_dtype
,
b_dtype
=
(
i
.
numpy_dtype
for
i
in
node
.
inputs
)
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
if
A_dtype
.
kind
==
"c"
or
b_dtype
.
kind
==
"c"
:
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
must_cast_A
=
A_dtype
!=
out_dtype
if
must_cast_A
and
config
.
compiler_verbose
:
print
(
"SolveTriangular requires casting first input `A`"
)
# noqa: T201
must_cast_B
=
b_dtype
!=
out_dtype
if
must_cast_B
and
config
.
compiler_verbose
:
print
(
"SolveTriangular requires casting second input `b`"
)
# noqa: T201
@numba_basic.numba_njit
def
solve_triangular
(
a
,
b
):
if
must_cast_A
:
a
=
a
.
astype
(
out_dtype
)
if
must_cast_B
:
b
=
b
.
astype
(
out_dtype
)
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
...
...
@@ -302,24 +346,42 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
overwrite_b
=
op
.
overwrite_b
check_finite
=
op
.
check_finite
dtype
=
node
.
inputs
[
0
]
.
dtype
if
dtype
in
complex_dtypes
:
c_dtype
,
b_dtype
=
(
i
.
type
.
numpy_dtype
for
i
in
node
.
inputs
)
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
if
c_dtype
.
kind
==
"c"
or
b_dtype
.
kind
==
"c"
:
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
must_cast_c
=
c_dtype
!=
out_dtype
if
must_cast_c
and
config
.
compiler_verbose
:
print
(
"CholeskySolve requires casting first input `c`"
)
# noqa: T201
must_cast_b
=
b_dtype
!=
out_dtype
if
must_cast_b
and
config
.
compiler_verbose
:
print
(
"CholeskySolve requires casting second input `b`"
)
# noqa: T201
@numba_basic.numba_njit
def
cho_solve
(
c
,
b
):
if
must_cast_c
:
c
=
c
.
astype
(
out_dtype
)
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
c
),
np
.
isnan
(
c
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input A to cho_solve"
)
if
must_cast_b
:
b
=
b
.
astype
(
out_dtype
)
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
b
),
np
.
isnan
(
b
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input b to cho_solve"
)
return
_cho_solve
(
c
,
b
,
lower
=
lower
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
c
,
b
,
lower
=
lower
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
,
)
return
cho_solve
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论