Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
19023545
提交
19023545
authored
3月 20, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
3月 27, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor numba lapack codegen
上级
2774599e
显示空白字符变更
内嵌
并排
正在显示
17 个修改的文件
包含
1276 行增加
和
1073 行删除
+1276
-1073
basic.py
pytensor/link/numba/dispatch/basic.py
+1
-1
_LAPACK.py
pytensor/link/numba/dispatch/linalg/_LAPACK.py
+67
-0
__init__.py
pytensor/link/numba/dispatch/linalg/__init__.py
+0
-0
__init__.py
...nsor/link/numba/dispatch/linalg/decomposition/__init__.py
+0
-0
cholesky.py
...nsor/link/numba/dispatch/linalg/decomposition/cholesky.py
+66
-0
__init__.py
pytensor/link/numba/dispatch/linalg/solve/__init__.py
+0
-0
cholesky.py
pytensor/link/numba/dispatch/linalg/solve/cholesky.py
+87
-0
general.py
pytensor/link/numba/dispatch/linalg/solve/general.py
+256
-0
norm.py
pytensor/link/numba/dispatch/linalg/solve/norm.py
+58
-0
posdef.py
pytensor/link/numba/dispatch/linalg/solve/posdef.py
+223
-0
symmetric.py
pytensor/link/numba/dispatch/linalg/solve/symmetric.py
+228
-0
triangular.py
pytensor/link/numba/dispatch/linalg/solve/triangular.py
+116
-0
utils.py
pytensor/link/numba/dispatch/linalg/solve/utils.py
+11
-0
utils.py
pytensor/link/numba/dispatch/linalg/utils.py
+108
-0
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+47
-1068
slinalg.py
pytensor/tensor/slinalg.py
+2
-1
test_slinalg.py
tests/link/numba/test_slinalg.py
+6
-3
没有找到文件。
pytensor/link/numba/dispatch/basic.py
浏览文件 @
19023545
...
...
@@ -75,7 +75,7 @@ def numba_njit(*args, fastmath=None, **kwargs):
message
=
(
"(
\x1b\\
[1m)*"
# ansi escape code for bold text
"Cannot cache compiled function "
'"(numba_funcified_fgraph|store_core_outputs)" '
'"(numba_funcified_fgraph|store_core_outputs
|cholesky|solve|solve_triangular|cho_solve
)" '
"as it uses dynamic globals"
),
category
=
NumbaWarning
,
...
...
pytensor/link/numba/dispatch/_LAPACK.py
→
pytensor/link/numba/dispatch/
linalg/
_LAPACK.py
浏览文件 @
19023545
...
...
@@ -390,3 +390,70 @@ class _LAPACK:
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xgttrf
(
cls
,
dtype
):
"""
Compute the LU factorization of a tridiagonal matrix A using row interchanges.
Called by scipy.linalg.lu_factor
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"gttrf"
)
functype
=
ctypes
.
CFUNCTYPE
(
None
,
_ptr_int
,
# N
float_pointer
,
# DL
float_pointer
,
# D
float_pointer
,
# DU
float_pointer
,
# DU2
_ptr_int
,
# IPIV
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xgttrs
(
cls
,
dtype
):
"""
Solve a system of linear equations A @ X = B with a tridiagonal matrix A using the LU factorization computed by numba_gttrf.
Called by scipy.linalg.lu_solve
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"gttrs"
)
functype
=
ctypes
.
CFUNCTYPE
(
None
,
_ptr_int
,
# TRANS
_ptr_int
,
# N
_ptr_int
,
# NRHS
float_pointer
,
# DL
float_pointer
,
# D
float_pointer
,
# DU
float_pointer
,
# DU2
_ptr_int
,
# IPIV
float_pointer
,
# B
_ptr_int
,
# LDB
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xgtcon
(
cls
,
dtype
):
"""
Estimate the reciprocal of the condition number of a tridiagonal matrix A using the LU factorization computed by numba_gttrf.
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"gtcon"
)
functype
=
ctypes
.
CFUNCTYPE
(
None
,
_ptr_int
,
# NORM
_ptr_int
,
# N
float_pointer
,
# DL
float_pointer
,
# D
float_pointer
,
# DU
float_pointer
,
# DU2
_ptr_int
,
# IPIV
float_pointer
,
# ANORM
float_pointer
,
# RCOND
float_pointer
,
# WORK
_ptr_int
,
# IWORK
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
pytensor/link/numba/dispatch/linalg/__init__.py
0 → 100644
浏览文件 @
19023545
pytensor/link/numba/dispatch/linalg/decomposition/__init__.py
0 → 100644
浏览文件 @
19023545
pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py
0 → 100644
浏览文件 @
19023545
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_scipy_linalg_matrix
def
_cholesky
(
a
,
lower
=
False
,
overwrite_a
=
False
,
check_finite
=
True
):
return
(
linalg
.
cholesky
(
a
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
),
0
,
)
@overload
(
_cholesky
)
def
cholesky_impl
(
A
,
lower
=
0
,
overwrite_a
=
False
,
check_finite
=
True
):
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"cholesky"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_potrf
=
_LAPACK
()
.
numba_xpotrf
(
dtype
)
def
impl
(
A
,
lower
=
0
,
overwrite_a
=
False
,
check_finite
=
True
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
if
A
.
shape
[
-
2
]
!=
_N
:
raise
linalg
.
LinAlgError
(
"Last 2 dimensions of A must be square"
)
UPLO
=
val_to_int_ptr
(
ord
(
"L"
)
if
lower
else
ord
(
"U"
))
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
INFO
=
val_to_int_ptr
(
0
)
if
overwrite_a
and
A
.
flags
.
f_contiguous
:
A_copy
=
A
else
:
A_copy
=
_copy_to_fortran_order
(
A
)
numba_potrf
(
UPLO
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
INFO
,
)
if
lower
:
for
j
in
range
(
1
,
_N
):
for
i
in
range
(
j
):
A_copy
[
i
,
j
]
=
0.0
else
:
for
j
in
range
(
_N
):
for
i
in
range
(
j
+
1
,
_N
):
A_copy
[
i
,
j
]
=
0.0
return
A_copy
,
int_ptr_to_val
(
INFO
)
return
impl
pytensor/link/numba/dispatch/linalg/solve/__init__.py
0 → 100644
浏览文件 @
19023545
pytensor/link/numba/dispatch/linalg/solve/cholesky.py
0 → 100644
浏览文件 @
19023545
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.np.linalg
import
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_scipy_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
def
_cho_solve
(
C
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
):
"""
Solve a positive-definite linear system using the Cholesky decomposition.
"""
return
linalg
.
cho_solve
(
(
C
,
lower
),
b
=
B
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
)
@overload
(
_cho_solve
)
def
cho_solve_impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
,
check_finite
=
True
):
ensure_lapack
()
_check_scipy_linalg_matrix
(
C
,
"cho_solve"
)
_check_scipy_linalg_matrix
(
B
,
"cho_solve"
)
dtype
=
C
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_potrs
=
_LAPACK
()
.
numba_xpotrs
(
dtype
)
def
impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
,
check_finite
=
True
):
_solve_check_input_shapes
(
C
,
B
)
_N
=
np
.
int32
(
C
.
shape
[
-
1
])
if
C
.
flags
.
f_contiguous
or
C
.
flags
.
c_contiguous
:
C_f
=
C
if
C
.
flags
.
c_contiguous
:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower
=
not
lower
else
:
C_f
=
np
.
asfortranarray
(
C
)
if
overwrite_b
and
B
.
flags
.
f_contiguous
:
B_copy
=
B
else
:
B_copy
=
_copy_to_fortran_order_even_if_1d
(
B
)
B_is_1d
=
B
.
ndim
==
1
if
B_is_1d
:
B_copy
=
np
.
expand_dims
(
B_copy
,
-
1
)
NRHS
=
1
if
B_is_1d
else
int
(
B
.
shape
[
-
1
])
UPLO
=
val_to_int_ptr
(
ord
(
"L"
)
if
lower
else
ord
(
"U"
))
N
=
val_to_int_ptr
(
_N
)
NRHS
=
val_to_int_ptr
(
NRHS
)
LDA
=
val_to_int_ptr
(
_N
)
LDB
=
val_to_int_ptr
(
_N
)
INFO
=
val_to_int_ptr
(
0
)
numba_potrs
(
UPLO
,
N
,
NRHS
,
C_f
.
view
(
w_type
)
.
ctypes
,
LDA
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
INFO
,
)
_solve_check
(
_N
,
int_ptr_to_val
(
INFO
))
if
B_is_1d
:
return
B_copy
[
...
,
0
]
return
B_copy
return
impl
pytensor/link/numba/dispatch/linalg/solve/general.py
0 → 100644
浏览文件 @
19023545
from
collections.abc
import
Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_scipy_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
)
def
_xgecon
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify
graphs.
"""
return
# type: ignore
@overload
(
_xgecon
)
def
xgecon_impl
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
Callable
[[
np
.
ndarray
,
float
,
str
],
tuple
[
np
.
ndarray
,
int
]]:
"""
Compute the condition number of a matrix A.
"""
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"gecon"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_gecon
=
_LAPACK
()
.
numba_xgecon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
A_NORM
=
np
.
array
(
A_norm
,
dtype
=
dtype
)
NORM
=
val_to_int_ptr
(
ord
(
norm
))
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
4
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
1
)
numba_gecon
(
NORM
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
A_NORM
.
view
(
w_type
)
.
ctypes
,
RCOND
.
view
(
w_type
)
.
ctypes
,
WORK
.
view
(
w_type
)
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_getrf
(
A
,
overwrite_a
=
False
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
"""
Placeholder for LU factorization; used by linalg.solve.
# TODO: Implement an LU_factor Op, then dispatch to this function in numba mode.
"""
return
# type: ignore
@overload
(
_getrf
)
def
getrf_impl
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
)
->
Callable
[[
np
.
ndarray
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"getrf"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_getrf
=
_LAPACK
()
.
numba_xgetrf
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
_M
,
_N
=
np
.
int32
(
A
.
shape
[
-
2
:])
# type: ignore
if
overwrite_a
and
A
.
flags
.
f_contiguous
:
A_copy
=
A
else
:
A_copy
=
_copy_to_fortran_order
(
A
)
M
=
val_to_int_ptr
(
_M
)
# type: ignore
N
=
val_to_int_ptr
(
_N
)
# type: ignore
LDA
=
val_to_int_ptr
(
_M
)
# type: ignore
IPIV
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
# type: ignore
INFO
=
val_to_int_ptr
(
0
)
numba_getrf
(
M
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
INFO
)
return
A_copy
,
IPIV
,
int_ptr_to_val
(
INFO
)
return
impl
def
_getrs
(
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
int
,
overwrite_b
:
bool
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for solving a linear system with a matrix that has been LU-factored; used by linalg.solve.
# TODO: Implement an LU_solve Op, then dispatch to this function in numba mode.
"""
return
# type: ignore
@overload
(
_getrs
)
def
getrs_impl
(
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
int
,
overwrite_b
:
bool
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
int
,
bool
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
LU
,
"getrs"
)
_check_scipy_linalg_matrix
(
B
,
"getrs"
)
dtype
=
LU
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_getrs
=
_LAPACK
()
.
numba_xgetrs
(
dtype
)
def
impl
(
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
int
,
overwrite_b
:
bool
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
LU
.
shape
[
-
1
])
_solve_check_input_shapes
(
LU
,
B
)
B_is_1d
=
B
.
ndim
==
1
if
overwrite_b
and
B
.
flags
.
f_contiguous
:
B_copy
=
B
else
:
B_copy
=
_copy_to_fortran_order_even_if_1d
(
B
)
if
B_is_1d
:
B_copy
=
np
.
expand_dims
(
B_copy
,
-
1
)
NRHS
=
1
if
B_is_1d
else
int
(
B_copy
.
shape
[
-
1
])
TRANS
=
val_to_int_ptr
(
_trans_char_to_int
(
trans
))
N
=
val_to_int_ptr
(
_N
)
NRHS
=
val_to_int_ptr
(
NRHS
)
LDA
=
val_to_int_ptr
(
_N
)
LDB
=
val_to_int_ptr
(
_N
)
IPIV
=
_copy_to_fortran_order
(
IPIV
)
INFO
=
val_to_int_ptr
(
0
)
numba_getrs
(
TRANS
,
N
,
NRHS
,
LU
.
view
(
w_type
)
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
INFO
,
)
if
B_is_1d
:
B_copy
=
B_copy
[
...
,
0
]
return
B_copy
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_gen
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
):
"""Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects
for users who import pytensor."""
return
linalg
.
solve
(
A
,
B
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
,
assume_a
=
"gen"
,
transposed
=
transposed
,
)
@overload
(
_solve_gen
)
def
solve_gen_impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve"
)
_check_scipy_linalg_matrix
(
B
,
"solve"
)
def
impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
_solve_check_input_shapes
(
A
,
B
)
if
overwrite_a
and
A
.
flags
.
c_contiguous
:
# Work with the transposed system to avoid copying A
A
=
A
.
T
transposed
=
not
transposed
order
=
"I"
if
transposed
else
"1"
norm
=
_xlange
(
A
,
order
=
order
)
N
=
A
.
shape
[
1
]
LU
,
IPIV
,
INFO
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
_solve_check
(
N
,
INFO
)
X
,
INFO
=
_getrs
(
LU
=
LU
,
B
=
B
,
IPIV
=
IPIV
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
)
_solve_check
(
N
,
INFO
)
RCOND
,
INFO
=
_xgecon
(
LU
,
norm
,
"1"
)
_solve_check
(
N
,
INFO
,
True
,
RCOND
)
return
X
return
impl
pytensor/link/numba/dispatch/linalg/solve/norm.py
0 → 100644
浏览文件 @
19023545
from
collections.abc
import
Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_scipy_linalg_matrix
def
_xlange
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
)
->
float
:
"""
Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode.
"""
return
# type: ignore
@overload
(
_xlange
)
def
xlange_impl
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
)
->
Callable
[[
np
.
ndarray
,
str
],
float
]:
"""
xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity norm, or the element of
largest absolute value of a matrix A.
"""
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"norm"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_lange
=
_LAPACK
()
.
numba_xlange
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
):
_M
,
_N
=
np
.
int32
(
A
.
shape
[
-
2
:])
# type: ignore
A_copy
=
_copy_to_fortran_order
(
A
)
M
=
val_to_int_ptr
(
_M
)
# type: ignore
N
=
val_to_int_ptr
(
_N
)
# type: ignore
LDA
=
val_to_int_ptr
(
_M
)
# type: ignore
NORM
=
(
val_to_int_ptr
(
ord
(
order
))
if
order
is
not
None
else
val_to_int_ptr
(
ord
(
"1"
))
)
WORK
=
np
.
empty
(
_M
,
dtype
=
dtype
)
# type: ignore
result
=
numba_lange
(
NORM
,
M
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
WORK
.
view
(
w_type
)
.
ctypes
)
return
result
return
impl
pytensor/link/numba/dispatch/linalg/solve/posdef.py
0 → 100644
浏览文件 @
19023545
from
collections.abc
import
Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_scipy_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
def
_posv
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
"""
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
"""
return
# type: ignore
@overload
(
_posv
)
def
posv_impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[
[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
],
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve"
)
_check_scipy_linalg_matrix
(
B
,
"solve"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_posv
=
_LAPACK
()
.
numba_xposv
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
_solve_check_input_shapes
(
A
,
B
)
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
if
overwrite_a
and
(
A
.
flags
.
f_contiguous
or
A
.
flags
.
c_contiguous
):
A_copy
=
A
if
A
.
flags
.
c_contiguous
:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower
=
not
lower
else
:
A_copy
=
_copy_to_fortran_order
(
A
)
B_is_1d
=
B
.
ndim
==
1
if
overwrite_b
and
B
.
flags
.
f_contiguous
:
B_copy
=
B
else
:
B_copy
=
_copy_to_fortran_order_even_if_1d
(
B
)
if
B_is_1d
:
B_copy
=
np
.
expand_dims
(
B_copy
,
-
1
)
UPLO
=
val_to_int_ptr
(
ord
(
"L"
)
if
lower
else
ord
(
"U"
))
NRHS
=
1
if
B_is_1d
else
int
(
B
.
shape
[
-
1
])
N
=
val_to_int_ptr
(
_N
)
NRHS
=
val_to_int_ptr
(
NRHS
)
LDA
=
val_to_int_ptr
(
_N
)
LDB
=
val_to_int_ptr
(
_N
)
INFO
=
val_to_int_ptr
(
0
)
numba_posv
(
UPLO
,
N
,
NRHS
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
INFO
,
)
if
B_is_1d
:
B_copy
=
B_copy
[
...
,
0
]
return
A_copy
,
B_copy
,
int_ptr_to_val
(
INFO
)
return
impl
def
_pocon
(
A
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by
linalg.solve when assume_a = "pos".
"""
return
# type: ignore
@overload
(
_pocon
)
def
pocon_impl
(
A
:
np
.
ndarray
,
anorm
:
float
)
->
Callable
[[
np
.
ndarray
,
float
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"pocon"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_pocon
=
_LAPACK
()
.
numba_xpocon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
anorm
:
float
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
UPLO
=
val_to_int_ptr
(
ord
(
"L"
))
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
ANORM
=
np
.
array
(
anorm
,
dtype
=
dtype
)
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
3
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
0
)
numba_pocon
(
UPLO
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
ANORM
.
view
(
w_type
)
.
ctypes
,
RCOND
.
view
(
w_type
)
.
ctypes
,
WORK
.
view
(
w_type
)
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_psd
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
):
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
avoid unexpected side-effects when users import pytensor."""
return
linalg
.
solve
(
A
,
B
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
,
transposed
=
transposed
,
assume_a
=
"pos"
,
)
@overload
(
_solve_psd
)
def
solve_psd_impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve"
)
_check_scipy_linalg_matrix
(
B
,
"solve"
)
def
impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
_solve_check_input_shapes
(
A
,
B
)
C
,
x
,
info
=
_posv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
,
check_finite
,
transposed
)
_solve_check
(
A
.
shape
[
-
1
],
info
)
rcond
,
info
=
_pocon
(
C
,
_xlange
(
A
))
_solve_check
(
A
.
shape
[
-
1
],
info
=
info
,
lamch
=
True
,
rcond
=
rcond
)
return
x
return
impl
pytensor/link/numba/dispatch/linalg/solve/symmetric.py
0 → 100644
浏览文件 @
19023545
from
collections.abc
import
Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_scipy_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
def
_sysv
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
int
]:
"""
Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve.
"""
return
# type: ignore
@overload
(
_sysv
)
def
sysv_impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
)
->
Callable
[
[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
int
],
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"sysv"
)
_check_scipy_linalg_matrix
(
B
,
"sysv"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_sysv
=
_LAPACK
()
.
numba_xsysv
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
):
_LDA
,
_N
=
np
.
int32
(
A
.
shape
[
-
2
:])
# type: ignore
_solve_check_input_shapes
(
A
,
B
)
if
overwrite_a
and
(
A
.
flags
.
f_contiguous
or
A
.
flags
.
c_contiguous
):
A_copy
=
A
if
A
.
flags
.
c_contiguous
:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower
=
not
lower
else
:
A_copy
=
_copy_to_fortran_order
(
A
)
B_is_1d
=
B
.
ndim
==
1
if
overwrite_b
and
B
.
flags
.
f_contiguous
:
B_copy
=
B
else
:
B_copy
=
_copy_to_fortran_order_even_if_1d
(
B
)
if
B_is_1d
:
B_copy
=
np
.
expand_dims
(
B_copy
,
-
1
)
NRHS
=
1
if
B_is_1d
else
int
(
B
.
shape
[
-
1
])
UPLO
=
val_to_int_ptr
(
ord
(
"L"
)
if
lower
else
ord
(
"U"
))
N
=
val_to_int_ptr
(
_N
)
# type: ignore
NRHS
=
val_to_int_ptr
(
NRHS
)
LDA
=
val_to_int_ptr
(
_LDA
)
# type: ignore
IPIV
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
# type: ignore
LDB
=
val_to_int_ptr
(
_N
)
# type: ignore
WORK
=
np
.
empty
(
1
,
dtype
=
dtype
)
LWORK
=
val_to_int_ptr
(
-
1
)
INFO
=
val_to_int_ptr
(
0
)
# Workspace query
numba_sysv
(
UPLO
,
N
,
NRHS
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
WORK
.
view
(
w_type
)
.
ctypes
,
LWORK
,
INFO
,
)
WS_SIZE
=
np
.
int32
(
WORK
[
0
]
.
real
)
LWORK
=
val_to_int_ptr
(
WS_SIZE
)
WORK
=
np
.
empty
(
WS_SIZE
,
dtype
=
dtype
)
# Actual solve
numba_sysv
(
UPLO
,
N
,
NRHS
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
WORK
.
view
(
w_type
)
.
ctypes
,
LWORK
,
INFO
,
)
if
B_is_1d
:
B_copy
=
B_copy
[
...
,
0
]
return
A_copy
,
B_copy
,
IPIV
,
int_ptr_to_val
(
INFO
)
return
impl
def
_sycon
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in
python mode.
"""
return
# type: ignore
@overload
(
_sycon
)
def
sycon_impl
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
float
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"sycon"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_sycon
=
_LAPACK
()
.
numba_xsycon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
UPLO
=
val_to_int_ptr
(
ord
(
"U"
))
ANORM
=
np
.
array
(
anorm
,
dtype
=
dtype
)
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
2
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
0
)
numba_sycon
(
UPLO
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
ipiv
.
ctypes
,
ANORM
.
view
(
w_type
)
.
ctypes
,
RCOND
.
view
(
w_type
)
.
ctypes
,
WORK
.
view
(
w_type
)
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_symmetric
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
):
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
unexpected side-effects when users import pytensor."""
return
linalg
.
solve
(
A
,
B
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
,
assume_a
=
"sym"
,
transposed
=
transposed
,
)
@overload
(
_solve_symmetric
)
def
solve_symmetric_impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve"
)
_check_scipy_linalg_matrix
(
B
,
"solve"
)
def
impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
_solve_check_input_shapes
(
A
,
B
)
lu
,
x
,
ipiv
,
info
=
_sysv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
)
_solve_check
(
A
.
shape
[
-
1
],
info
)
rcond
,
info
=
_sycon
(
lu
,
ipiv
,
_xlange
(
A
,
order
=
"I"
))
_solve_check
(
A
.
shape
[
-
1
],
info
,
True
,
rcond
)
return
x
return
impl
pytensor/link/numba/dispatch/linalg/solve/triangular.py
0 → 100644
浏览文件 @
19023545
import
numpy
as
np
from
numba.core
import
types
from
numba.core.extending
import
overload
from
numba.np.linalg
import
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_scipy_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
)
def
_solve_triangular
(
A
,
B
,
trans
=
0
,
lower
=
False
,
unit_diagonal
=
False
,
b_ndim
=
1
,
overwrite_b
=
False
):
"""
Thin wrapper around scipy.linalg.solve_triangular.
This function is overloaded instead of the original scipy function to avoid unexpected side-effects to users who
import pytensor.
The signature must be the same as solve_triangular_impl, so b_ndim is included, although this argument is not
used by scipy.linalg.solve_triangular.
"""
return
linalg
.
solve_triangular
(
A
,
B
,
trans
=
trans
,
lower
=
lower
,
unit_diagonal
=
unit_diagonal
,
overwrite_b
=
overwrite_b
,
)
@overload
(
_solve_triangular
)
def
solve_triangular_impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
b_ndim
,
overwrite_b
):
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve_triangular"
)
_check_scipy_linalg_matrix
(
B
,
"solve_triangular"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_trtrs
=
_LAPACK
()
.
numba_xtrtrs
(
dtype
)
if
isinstance
(
dtype
,
types
.
Complex
):
# If you want to make this work with complex numbers make sure you handle the c_contiguous trick correctly
raise
TypeError
(
"This function is not expected to work with complex numbers yet"
)
def
impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
b_ndim
,
overwrite_b
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
_solve_check_input_shapes
(
A
,
B
)
# Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type
# could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
B_is_1d
=
B
.
ndim
==
1
if
A
.
flags
.
f_contiguous
or
(
A
.
flags
.
c_contiguous
and
trans
in
(
0
,
1
)):
A_f
=
A
if
A
.
flags
.
c_contiguous
:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
# Is this valid for complex matrices that were .conj().mT by PyTensor?
lower
=
not
lower
trans
=
1
-
trans
else
:
A_f
=
np
.
asfortranarray
(
A
)
if
overwrite_b
and
B
.
flags
.
f_contiguous
:
B_copy
=
B
else
:
B_copy
=
_copy_to_fortran_order_even_if_1d
(
B
)
if
B_is_1d
:
B_copy
=
np
.
expand_dims
(
B_copy
,
-
1
)
NRHS
=
1
if
B_is_1d
else
int
(
B_copy
.
shape
[
-
1
])
UPLO
=
val_to_int_ptr
(
ord
(
"L"
)
if
lower
else
ord
(
"U"
))
TRANS
=
val_to_int_ptr
(
_trans_char_to_int
(
trans
))
DIAG
=
val_to_int_ptr
(
ord
(
"U"
)
if
unit_diagonal
else
ord
(
"N"
))
N
=
val_to_int_ptr
(
_N
)
NRHS
=
val_to_int_ptr
(
NRHS
)
LDA
=
val_to_int_ptr
(
_N
)
LDB
=
val_to_int_ptr
(
_N
)
INFO
=
val_to_int_ptr
(
0
)
numba_trtrs
(
UPLO
,
TRANS
,
DIAG
,
N
,
NRHS
,
A_f
.
view
(
w_type
)
.
ctypes
,
LDA
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
INFO
,
)
_solve_check
(
int_ptr_to_val
(
LDA
),
int_ptr_to_val
(
INFO
))
if
B_is_1d
:
return
B_copy
[
...
,
0
]
return
B_copy
return
impl
pytensor/link/numba/dispatch/linalg/solve/utils.py
0 → 100644
浏览文件 @
19023545
from
scipy
import
linalg
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
@numba_basic.numba_njit
(
inline
=
"always"
)
def
_solve_check_input_shapes
(
A
,
B
):
if
A
.
shape
[
0
]
!=
B
.
shape
[
0
]:
raise
linalg
.
LinAlgError
(
"Dimensions of A and B do not conform"
)
if
A
.
shape
[
-
2
]
!=
A
.
shape
[
-
1
]:
raise
linalg
.
LinAlgError
(
"Last 2 dimensions of A must be square"
)
pytensor/link/numba/dispatch/linalg/utils.py
0 → 100644
浏览文件 @
19023545
from
collections.abc
import
Callable
import
numba
from
numba.core
import
types
from
numba.core.extending
import
overload
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
numpy.linalg
import
LinAlgError
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
val_to_int_ptr
,
)
@numba_basic.numba_njit
(
inline
=
"always"
)
def
_copy_to_fortran_order_even_if_1d
(
x
):
# Numba's _copy_to_fortran_order doesn't do anything for vectors
return
x
.
copy
()
if
x
.
ndim
==
1
else
_copy_to_fortran_order
(
x
)
@numba_basic.numba_njit
(
inline
=
"always"
)
def
_trans_char_to_int
(
trans
):
if
trans
not
in
[
0
,
1
,
2
]:
raise
ValueError
(
'Parameter "trans" should be one of 0, 1, 2'
)
if
trans
==
0
:
return
ord
(
"N"
)
elif
trans
==
1
:
return
ord
(
"T"
)
else
:
return
ord
(
"C"
)
def
_check_scipy_linalg_matrix
(
a
,
func_name
):
"""
Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
"""
prefix
=
"scipy.linalg"
# Unpack optional type
if
isinstance
(
a
,
types
.
Optional
):
a
=
a
.
type
if
not
isinstance
(
a
,
types
.
Array
):
msg
=
f
"{prefix}.{func_name}() only supported for array types"
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
if
a
.
ndim
not
in
[
1
,
2
]:
msg
=
(
f
"{prefix}.{func_name}() only supported on 1d or 2d arrays, found {a.ndim}."
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
if
not
isinstance
(
a
.
dtype
,
types
.
Float
|
types
.
Complex
):
msg
=
f
"{prefix}.{func_name}() only supported on float and complex arrays."
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
@numba_basic.numba_njit
(
inline
=
"always"
)
def
_solve_check
(
n
,
info
,
lamch
=
False
,
rcond
=
None
):
"""
Check arguments during the different steps of the solution phase
Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38
"""
if
info
<
0
:
# TODO: figure out how to do an fstring here
msg
=
"LAPACK reported an illegal value in input"
raise
ValueError
(
msg
)
elif
0
<
info
:
raise
LinAlgError
(
"Matrix is singular."
)
if
lamch
:
E
=
_xlamch
(
"E"
)
if
rcond
<
E
:
# TODO: This should be a warning, but we can't raise warnings in numba mode
print
(
# noqa: T201
"Ill-conditioned matrix, rcond="
,
rcond
,
", result may not be accurate."
)
def
_xlamch
(
kind
:
str
=
"E"
):
"""
Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs.
"""
pass
@overload
(
_xlamch
)
def
xlamch_impl
(
kind
:
str
=
"E"
)
->
Callable
[[
str
],
float
]:
"""
Compute the machine precision for a given floating point type.
"""
from
pytensor
import
config
ensure_lapack
()
w_type
=
_get_underlying_float
(
config
.
floatX
)
if
w_type
==
"float32"
:
dtype
=
types
.
float32
elif
w_type
==
"float64"
:
dtype
=
types
.
float64
else
:
raise
NotImplementedError
(
"Unsupported dtype"
)
numba_lamch
=
_LAPACK
()
.
numba_xlamch
(
dtype
)
def
impl
(
kind
:
str
=
"E"
)
->
float
:
KIND
=
val_to_int_ptr
(
ord
(
kind
))
return
numba_lamch
(
KIND
)
# type: ignore
return
impl
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
19023545
import
warnings
from
collections.abc
import
Callable
import
numba
import
numpy
as
np
from
numba.core
import
types
from
numba.extending
import
overload
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
numpy.linalg
import
LinAlgError
from
scipy
import
linalg
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
,
numba_njit
from
pytensor.link.numba.dispatch.linalg.decomposition.cholesky
import
_cholesky
from
pytensor.link.numba.dispatch.linalg.solve.cholesky
import
_cho_solve
from
pytensor.link.numba.dispatch.linalg.solve.general
import
_solve_gen
from
pytensor.link.numba.dispatch.linalg.solve.posdef
import
_solve_psd
from
pytensor.link.numba.dispatch.linalg.solve.symmetric
import
_solve_symmetric
from
pytensor.link.numba.dispatch.linalg.solve.triangular
import
_solve_triangular
from
pytensor.tensor.slinalg
import
(
BlockDiagonal
,
Cholesky
,
...
...
@@ -33,265 +25,6 @@ _COMPLEX_DTYPE_NOT_SUPPORTED_MSG = (
)
@numba_basic.numba_njit
(
inline
=
"always"
)
def
_copy_to_fortran_order_even_if_1d
(
x
):
# Numba's _copy_to_fortran_order doesn't do anything for vectors
return
x
.
copy
()
if
x
.
ndim
==
1
else
_copy_to_fortran_order
(
x
)
@numba_basic.numba_njit
(
inline
=
"always"
)
def
_solve_check
(
n
,
info
,
lamch
=
False
,
rcond
=
None
):
"""
Check arguments during the different steps of the solution phase
Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38
"""
if
info
<
0
:
# TODO: figure out how to do an fstring here
msg
=
"LAPACK reported an illegal value in input"
raise
ValueError
(
msg
)
elif
0
<
info
:
raise
LinAlgError
(
"Matrix is singular."
)
if
lamch
:
E
=
_xlamch
(
"E"
)
if
rcond
<
E
:
# TODO: This should be a warning, but we can't raise warnings in numba mode
print
(
# noqa: T201
"Ill-conditioned matrix, rcond="
,
rcond
,
", result may not be accurate."
)
def
_check_scipy_linalg_matrix
(
a
,
func_name
):
"""
Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
"""
prefix
=
"scipy.linalg"
# Unpack optional type
if
isinstance
(
a
,
types
.
Optional
):
a
=
a
.
type
if
not
isinstance
(
a
,
types
.
Array
):
msg
=
f
"{prefix}.{func_name}() only supported for array types"
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
if
a
.
ndim
not
in
[
1
,
2
]:
msg
=
(
f
"{prefix}.{func_name}() only supported on 1d or 2d arrays, found {a.ndim}."
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
if
not
isinstance
(
a
.
dtype
,
types
.
Float
|
types
.
Complex
):
msg
=
f
"{prefix}.{func_name}() only supported on float and complex arrays."
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
def
_solve_triangular
(
A
,
B
,
trans
=
0
,
lower
=
False
,
unit_diagonal
=
False
,
b_ndim
=
1
,
overwrite_b
=
False
):
"""
Thin wrapper around scipy.linalg.solve_triangular.
This function is overloaded instead of the original scipy function to avoid unexpected side-effects to users who
import pytensor.
The signature must be the same as solve_triangular_impl, so b_ndim is included, although this argument is not
used by scipy.linalg.solve_triangular.
"""
return
linalg
.
solve_triangular
(
A
,
B
,
trans
=
trans
,
lower
=
lower
,
unit_diagonal
=
unit_diagonal
,
overwrite_b
=
overwrite_b
,
)
@numba_basic.numba_njit
(
inline
=
"always"
)
def
_trans_char_to_int
(
trans
):
if
trans
not
in
[
0
,
1
,
2
]:
raise
ValueError
(
'Parameter "trans" should be one of 0, 1, 2'
)
if
trans
==
0
:
return
ord
(
"N"
)
elif
trans
==
1
:
return
ord
(
"T"
)
else
:
return
ord
(
"C"
)
@numba_basic.numba_njit
(
inline
=
"always"
)
def
_solve_check_input_shapes
(
A
,
B
):
if
A
.
shape
[
0
]
!=
B
.
shape
[
0
]:
raise
linalg
.
LinAlgError
(
"Dimensions of A and B do not conform"
)
if
A
.
shape
[
-
2
]
!=
A
.
shape
[
-
1
]:
raise
linalg
.
LinAlgError
(
"Last 2 dimensions of A must be square"
)
@overload
(
_solve_triangular
)
def
solve_triangular_impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
b_ndim
,
overwrite_b
):
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve_triangular"
)
_check_scipy_linalg_matrix
(
B
,
"solve_triangular"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_trtrs
=
_LAPACK
()
.
numba_xtrtrs
(
dtype
)
if
isinstance
(
dtype
,
types
.
Complex
):
# If you want to make this work with complex numbers make sure you handle the c_contiguous trick correctly
raise
TypeError
(
"This function is not expected to work with complex numbers"
)
def
impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
b_ndim
,
overwrite_b
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
_solve_check_input_shapes
(
A
,
B
)
# Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type
# could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
B_is_1d
=
B
.
ndim
==
1
if
A
.
flags
.
f_contiguous
or
(
A
.
flags
.
c_contiguous
and
trans
in
(
0
,
1
)):
A_f
=
A
if
A
.
flags
.
c_contiguous
:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
# Is this valid for complex matrices that were .conj().mT by PyTensor?
lower
=
not
lower
trans
=
1
-
trans
else
:
A_f
=
np
.
asfortranarray
(
A
)
if
overwrite_b
and
B
.
flags
.
f_contiguous
:
B_copy
=
B
else
:
B_copy
=
_copy_to_fortran_order_even_if_1d
(
B
)
if
B_is_1d
:
B_copy
=
np
.
expand_dims
(
B_copy
,
-
1
)
NRHS
=
1
if
B_is_1d
else
int
(
B_copy
.
shape
[
-
1
])
UPLO
=
val_to_int_ptr
(
ord
(
"L"
)
if
lower
else
ord
(
"U"
))
TRANS
=
val_to_int_ptr
(
_trans_char_to_int
(
trans
))
DIAG
=
val_to_int_ptr
(
ord
(
"U"
)
if
unit_diagonal
else
ord
(
"N"
))
N
=
val_to_int_ptr
(
_N
)
NRHS
=
val_to_int_ptr
(
NRHS
)
LDA
=
val_to_int_ptr
(
_N
)
LDB
=
val_to_int_ptr
(
_N
)
INFO
=
val_to_int_ptr
(
0
)
numba_trtrs
(
UPLO
,
TRANS
,
DIAG
,
N
,
NRHS
,
A_f
.
view
(
w_type
)
.
ctypes
,
LDA
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
INFO
,
)
_solve_check
(
int_ptr_to_val
(
LDA
),
int_ptr_to_val
(
INFO
))
if
B_is_1d
:
return
B_copy
[
...
,
0
]
return
B_copy
return
impl
@numba_funcify.register
(
SolveTriangular
)
def
numba_funcify_SolveTriangular
(
op
,
node
,
**
kwargs
):
lower
=
op
.
lower
unit_diagonal
=
op
.
unit_diagonal
check_finite
=
op
.
check_finite
overwrite_b
=
op
.
overwrite_b
b_ndim
=
op
.
b_ndim
dtype
=
node
.
inputs
[
0
]
.
dtype
if
dtype
in
complex_dtypes
:
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
"Solve Triangular"
)
)
@numba_basic.numba_njit
(
inline
=
"always"
)
def
solve_triangular
(
a
,
b
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input A to solve_triangular"
)
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
b
),
np
.
isnan
(
b
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input b to solve_triangular"
)
res
=
_solve_triangular
(
a
,
b
,
trans
=
0
,
# transposing is handled explicitly on the graph, so we never use this argument
lower
=
lower
,
unit_diagonal
=
unit_diagonal
,
overwrite_b
=
overwrite_b
,
b_ndim
=
b_ndim
,
)
return
res
return
solve_triangular
def
_cholesky
(
a
,
lower
=
False
,
overwrite_a
=
False
,
check_finite
=
True
):
return
(
linalg
.
cholesky
(
a
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
),
0
,
)
@overload
(
_cholesky
)
def
cholesky_impl
(
A
,
lower
=
0
,
overwrite_a
=
False
,
check_finite
=
True
):
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"cholesky"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_potrf
=
_LAPACK
()
.
numba_xpotrf
(
dtype
)
def
impl
(
A
,
lower
=
0
,
overwrite_a
=
False
,
check_finite
=
True
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
if
A
.
shape
[
-
2
]
!=
_N
:
raise
linalg
.
LinAlgError
(
"Last 2 dimensions of A must be square"
)
UPLO
=
val_to_int_ptr
(
ord
(
"L"
)
if
lower
else
ord
(
"U"
))
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
INFO
=
val_to_int_ptr
(
0
)
if
overwrite_a
and
A
.
flags
.
f_contiguous
:
A_copy
=
A
else
:
A_copy
=
_copy_to_fortran_order
(
A
)
numba_potrf
(
UPLO
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
INFO
,
)
if
lower
:
for
j
in
range
(
1
,
_N
):
for
i
in
range
(
j
):
A_copy
[
i
,
j
]
=
0.0
else
:
for
j
in
range
(
_N
):
for
i
in
range
(
j
+
1
,
_N
):
A_copy
[
i
,
j
]
=
0.0
return
A_copy
,
int_ptr_to_val
(
INFO
)
return
impl
@numba_funcify.register
(
Cholesky
)
def
numba_funcify_Cholesky
(
op
,
node
,
**
kwargs
):
"""
...
...
@@ -309,8 +42,8 @@ def numba_funcify_Cholesky(op, node, **kwargs):
if
dtype
in
complex_dtypes
:
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
@numba_
basic.numba_njit
(
inline
=
"always"
)
def
nb_
cholesky
(
a
):
@numba_
njit
def
cholesky
(
a
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
...
...
@@ -333,7 +66,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
return
res
return
nb_
cholesky
return
cholesky
@numba_funcify.register
(
BlockDiagonal
)
...
...
@@ -341,7 +74,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
dtype
=
node
.
outputs
[
0
]
.
dtype
# TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
@numba_
basic.numba_njit
(
inline
=
"never"
)
@numba_
njit
def
block_diag
(
*
arrs
):
shapes
=
np
.
array
([
a
.
shape
for
a
in
arrs
],
dtype
=
"int"
)
out_shape
=
[
int
(
s
)
for
s
in
np
.
sum
(
shapes
,
axis
=
0
)]
...
...
@@ -359,731 +92,6 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
return
block_diag
def
_xlamch
(
kind
:
str
=
"E"
):
"""
Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs.
"""
pass
@overload
(
_xlamch
)
def
xlamch_impl
(
kind
:
str
=
"E"
)
->
Callable
[[
str
],
float
]:
"""
Compute the machine precision for a given floating point type.
"""
from
pytensor
import
config
ensure_lapack
()
w_type
=
_get_underlying_float
(
config
.
floatX
)
if
w_type
==
"float32"
:
dtype
=
types
.
float32
elif
w_type
==
"float64"
:
dtype
=
types
.
float64
else
:
raise
NotImplementedError
(
"Unsupported dtype"
)
numba_lamch
=
_LAPACK
()
.
numba_xlamch
(
dtype
)
def
impl
(
kind
:
str
=
"E"
)
->
float
:
KIND
=
val_to_int_ptr
(
ord
(
kind
))
return
numba_lamch
(
KIND
)
# type: ignore
return
impl
def
_xlange
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
)
->
float
:
"""
Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode.
"""
return
# type: ignore
@overload
(
_xlange
)
def
xlange_impl
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
)
->
Callable
[[
np
.
ndarray
,
str
],
float
]:
"""
xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity norm, or the element of
largest absolute value of a matrix A.
"""
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"norm"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_lange
=
_LAPACK
()
.
numba_xlange
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
):
_M
,
_N
=
np
.
int32
(
A
.
shape
[
-
2
:])
# type: ignore
A_copy
=
_copy_to_fortran_order
(
A
)
M
=
val_to_int_ptr
(
_M
)
# type: ignore
N
=
val_to_int_ptr
(
_N
)
# type: ignore
LDA
=
val_to_int_ptr
(
_M
)
# type: ignore
NORM
=
(
val_to_int_ptr
(
ord
(
order
))
if
order
is
not
None
else
val_to_int_ptr
(
ord
(
"1"
))
)
WORK
=
np
.
empty
(
_M
,
dtype
=
dtype
)
# type: ignore
result
=
numba_lange
(
NORM
,
M
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
WORK
.
view
(
w_type
)
.
ctypes
)
return
result
return
impl
def
_xgecon
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify
graphs.
"""
return
# type: ignore
@overload
(
_xgecon
)
def
xgecon_impl
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
Callable
[[
np
.
ndarray
,
float
,
str
],
tuple
[
np
.
ndarray
,
int
]]:
"""
Compute the condition number of a matrix A.
"""
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"gecon"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_gecon
=
_LAPACK
()
.
numba_xgecon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
A_NORM
=
np
.
array
(
A_norm
,
dtype
=
dtype
)
NORM
=
val_to_int_ptr
(
ord
(
norm
))
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
4
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
1
)
numba_gecon
(
NORM
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
A_NORM
.
view
(
w_type
)
.
ctypes
,
RCOND
.
view
(
w_type
)
.
ctypes
,
WORK
.
view
(
w_type
)
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_getrf
(
A
,
overwrite_a
=
False
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
"""
Placeholder for LU factorization; used by linalg.solve.
# TODO: Implement an LU_factor Op, then dispatch to this function in numba mode.
"""
return
# type: ignore
@overload
(
_getrf
)
def
getrf_impl
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
)
->
Callable
[[
np
.
ndarray
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"getrf"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_getrf
=
_LAPACK
()
.
numba_xgetrf
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
_M
,
_N
=
np
.
int32
(
A
.
shape
[
-
2
:])
# type: ignore
if
overwrite_a
and
A
.
flags
.
f_contiguous
:
A_copy
=
A
else
:
A_copy
=
_copy_to_fortran_order
(
A
)
M
=
val_to_int_ptr
(
_M
)
# type: ignore
N
=
val_to_int_ptr
(
_N
)
# type: ignore
LDA
=
val_to_int_ptr
(
_M
)
# type: ignore
IPIV
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
# type: ignore
INFO
=
val_to_int_ptr
(
0
)
numba_getrf
(
M
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
INFO
)
return
A_copy
,
IPIV
,
int_ptr_to_val
(
INFO
)
return
impl
def
_getrs
(
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
int
,
overwrite_b
:
bool
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for solving a linear system with a matrix that has been LU-factored; used by linalg.solve.
# TODO: Implement an LU_solve Op, then dispatch to this function in numba mode.
"""
return
# type: ignore
@overload
(
_getrs
)
def
getrs_impl
(
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
int
,
overwrite_b
:
bool
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
int
,
bool
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
LU
,
"getrs"
)
_check_scipy_linalg_matrix
(
B
,
"getrs"
)
dtype
=
LU
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_getrs
=
_LAPACK
()
.
numba_xgetrs
(
dtype
)
def
impl
(
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
int
,
overwrite_b
:
bool
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
LU
.
shape
[
-
1
])
_solve_check_input_shapes
(
LU
,
B
)
B_is_1d
=
B
.
ndim
==
1
if
overwrite_b
and
B
.
flags
.
f_contiguous
:
B_copy
=
B
else
:
B_copy
=
_copy_to_fortran_order_even_if_1d
(
B
)
if
B_is_1d
:
B_copy
=
np
.
expand_dims
(
B_copy
,
-
1
)
NRHS
=
1
if
B_is_1d
else
int
(
B_copy
.
shape
[
-
1
])
TRANS
=
val_to_int_ptr
(
_trans_char_to_int
(
trans
))
N
=
val_to_int_ptr
(
_N
)
NRHS
=
val_to_int_ptr
(
NRHS
)
LDA
=
val_to_int_ptr
(
_N
)
LDB
=
val_to_int_ptr
(
_N
)
IPIV
=
_copy_to_fortran_order
(
IPIV
)
INFO
=
val_to_int_ptr
(
0
)
numba_getrs
(
TRANS
,
N
,
NRHS
,
LU
.
view
(
w_type
)
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
INFO
,
)
if
B_is_1d
:
B_copy
=
B_copy
[
...
,
0
]
return
B_copy
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_gen
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
):
"""Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects
for users who import pytensor."""
return
linalg
.
solve
(
A
,
B
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
,
assume_a
=
"gen"
,
transposed
=
transposed
,
)
@overload
(
_solve_gen
)
def
solve_gen_impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve"
)
_check_scipy_linalg_matrix
(
B
,
"solve"
)
def
impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
_solve_check_input_shapes
(
A
,
B
)
if
overwrite_a
and
A
.
flags
.
c_contiguous
:
# Work with the transposed system to avoid copying A
A
=
A
.
T
transposed
=
not
transposed
order
=
"I"
if
transposed
else
"1"
norm
=
_xlange
(
A
,
order
=
order
)
N
=
A
.
shape
[
1
]
LU
,
IPIV
,
INFO
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
_solve_check
(
N
,
INFO
)
X
,
INFO
=
_getrs
(
LU
=
LU
,
B
=
B
,
IPIV
=
IPIV
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
)
_solve_check
(
N
,
INFO
)
RCOND
,
INFO
=
_xgecon
(
LU
,
norm
,
"1"
)
_solve_check
(
N
,
INFO
,
True
,
RCOND
)
return
X
return
impl
def
_sysv
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
int
]:
"""
Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve.
"""
return
# type: ignore
@overload
(
_sysv
)
def
sysv_impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
)
->
Callable
[
[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
int
],
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"sysv"
)
_check_scipy_linalg_matrix
(
B
,
"sysv"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_sysv
=
_LAPACK
()
.
numba_xsysv
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
):
_LDA
,
_N
=
np
.
int32
(
A
.
shape
[
-
2
:])
# type: ignore
_solve_check_input_shapes
(
A
,
B
)
if
overwrite_a
and
(
A
.
flags
.
f_contiguous
or
A
.
flags
.
c_contiguous
):
A_copy
=
A
if
A
.
flags
.
c_contiguous
:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower
=
not
lower
else
:
A_copy
=
_copy_to_fortran_order
(
A
)
B_is_1d
=
B
.
ndim
==
1
if
overwrite_b
and
B
.
flags
.
f_contiguous
:
B_copy
=
B
else
:
B_copy
=
_copy_to_fortran_order_even_if_1d
(
B
)
if
B_is_1d
:
B_copy
=
np
.
expand_dims
(
B_copy
,
-
1
)
NRHS
=
1
if
B_is_1d
else
int
(
B
.
shape
[
-
1
])
UPLO
=
val_to_int_ptr
(
ord
(
"L"
)
if
lower
else
ord
(
"U"
))
N
=
val_to_int_ptr
(
_N
)
# type: ignore
NRHS
=
val_to_int_ptr
(
NRHS
)
LDA
=
val_to_int_ptr
(
_LDA
)
# type: ignore
IPIV
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
# type: ignore
LDB
=
val_to_int_ptr
(
_N
)
# type: ignore
WORK
=
np
.
empty
(
1
,
dtype
=
dtype
)
LWORK
=
val_to_int_ptr
(
-
1
)
INFO
=
val_to_int_ptr
(
0
)
# Workspace query
numba_sysv
(
UPLO
,
N
,
NRHS
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
WORK
.
view
(
w_type
)
.
ctypes
,
LWORK
,
INFO
,
)
WS_SIZE
=
np
.
int32
(
WORK
[
0
]
.
real
)
LWORK
=
val_to_int_ptr
(
WS_SIZE
)
WORK
=
np
.
empty
(
WS_SIZE
,
dtype
=
dtype
)
# Actual solve
numba_sysv
(
UPLO
,
N
,
NRHS
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
WORK
.
view
(
w_type
)
.
ctypes
,
LWORK
,
INFO
,
)
if
B_is_1d
:
B_copy
=
B_copy
[
...
,
0
]
return
A_copy
,
B_copy
,
IPIV
,
int_ptr_to_val
(
INFO
)
return
impl
def
_sycon
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in
python mode.
"""
return
# type: ignore
@overload
(
_sycon
)
def
sycon_impl
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
float
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"sycon"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_sycon
=
_LAPACK
()
.
numba_xsycon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
UPLO
=
val_to_int_ptr
(
ord
(
"U"
))
ANORM
=
np
.
array
(
anorm
,
dtype
=
dtype
)
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
2
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
0
)
numba_sycon
(
UPLO
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
ipiv
.
ctypes
,
ANORM
.
view
(
w_type
)
.
ctypes
,
RCOND
.
view
(
w_type
)
.
ctypes
,
WORK
.
view
(
w_type
)
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_symmetric
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
):
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
unexpected side-effects when users import pytensor."""
return
linalg
.
solve
(
A
,
B
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
,
assume_a
=
"sym"
,
transposed
=
transposed
,
)
@overload
(
_solve_symmetric
)
def
solve_symmetric_impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve"
)
_check_scipy_linalg_matrix
(
B
,
"solve"
)
def
impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
_solve_check_input_shapes
(
A
,
B
)
lu
,
x
,
ipiv
,
info
=
_sysv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
)
_solve_check
(
A
.
shape
[
-
1
],
info
)
rcond
,
info
=
_sycon
(
lu
,
ipiv
,
_xlange
(
A
,
order
=
"I"
))
_solve_check
(
A
.
shape
[
-
1
],
info
,
True
,
rcond
)
return
x
return
impl
def
_posv
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
"""
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
"""
return
# type: ignore
@overload
(
_posv
)
def
posv_impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[
[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
],
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve"
)
_check_scipy_linalg_matrix
(
B
,
"solve"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_posv
=
_LAPACK
()
.
numba_xposv
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
_solve_check_input_shapes
(
A
,
B
)
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
if
overwrite_a
and
(
A
.
flags
.
f_contiguous
or
A
.
flags
.
c_contiguous
):
A_copy
=
A
if
A
.
flags
.
c_contiguous
:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower
=
not
lower
else
:
A_copy
=
_copy_to_fortran_order
(
A
)
B_is_1d
=
B
.
ndim
==
1
if
overwrite_b
and
B
.
flags
.
f_contiguous
:
B_copy
=
B
else
:
B_copy
=
_copy_to_fortran_order_even_if_1d
(
B
)
if
B_is_1d
:
B_copy
=
np
.
expand_dims
(
B_copy
,
-
1
)
UPLO
=
val_to_int_ptr
(
ord
(
"L"
)
if
lower
else
ord
(
"U"
))
NRHS
=
1
if
B_is_1d
else
int
(
B
.
shape
[
-
1
])
N
=
val_to_int_ptr
(
_N
)
NRHS
=
val_to_int_ptr
(
NRHS
)
LDA
=
val_to_int_ptr
(
_N
)
LDB
=
val_to_int_ptr
(
_N
)
INFO
=
val_to_int_ptr
(
0
)
numba_posv
(
UPLO
,
N
,
NRHS
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
INFO
,
)
if
B_is_1d
:
B_copy
=
B_copy
[
...
,
0
]
return
A_copy
,
B_copy
,
int_ptr_to_val
(
INFO
)
return
impl
def
_pocon
(
A
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by
linalg.solve when assume_a = "pos".
"""
return
# type: ignore
@overload
(
_pocon
)
def
pocon_impl
(
A
:
np
.
ndarray
,
anorm
:
float
)
->
Callable
[[
np
.
ndarray
,
float
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"pocon"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_pocon
=
_LAPACK
()
.
numba_xpocon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
anorm
:
float
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
UPLO
=
val_to_int_ptr
(
ord
(
"L"
))
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
ANORM
=
np
.
array
(
anorm
,
dtype
=
dtype
)
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
3
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
0
)
numba_pocon
(
UPLO
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
ANORM
.
view
(
w_type
)
.
ctypes
,
RCOND
.
view
(
w_type
)
.
ctypes
,
WORK
.
view
(
w_type
)
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_psd
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
):
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
avoid unexpected side-effects when users import pytensor."""
return
linalg
.
solve
(
A
,
B
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
,
transposed
=
transposed
,
assume_a
=
"pos"
,
)
@overload
(
_solve_psd
)
def
solve_psd_impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve"
)
_check_scipy_linalg_matrix
(
B
,
"solve"
)
def
impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
_solve_check_input_shapes
(
A
,
B
)
C
,
x
,
info
=
_posv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
,
check_finite
,
transposed
)
_solve_check
(
A
.
shape
[
-
1
],
info
)
rcond
,
info
=
_pocon
(
C
,
_xlange
(
A
))
_solve_check
(
A
.
shape
[
-
1
],
info
=
info
,
lamch
=
True
,
rcond
=
rcond
)
return
x
return
impl
@numba_funcify.register
(
Solve
)
def
numba_funcify_Solve
(
op
,
node
,
**
kwargs
):
assume_a
=
op
.
assume_a
...
...
@@ -1109,12 +117,12 @@ def numba_funcify_Solve(op, node, **kwargs):
else
:
warnings
.
warn
(
f
"Numba assume_a={assume_a} not implemented. Falling back to general solve.
\n
"
f
"If appropriate, you may want to set assume_a to one of 'sym', 'pos',
or 'he
r' to improve performance."
,
f
"If appropriate, you may want to set assume_a to one of 'sym', 'pos',
'her', or 'triangula
r' to improve performance."
,
UserWarning
,
)
solve_fn
=
_solve_gen
@numba_
basic.numba_njit
(
inline
=
"always"
)
@numba_
njit
def
solve
(
a
,
b
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
...
...
@@ -1132,74 +140,45 @@ def numba_funcify_Solve(op, node, **kwargs):
return
solve
def
_cho_solve
(
C
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
):
"""
Solve a positive-definite linear system using the Cholesky decomposition.
"""
return
linalg
.
cho_solve
(
(
C
,
lower
),
b
=
B
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
)
@overload
(
_cho_solve
)
def
cho_solve_impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
,
check_finite
=
True
):
ensure_lapack
()
_check_scipy_linalg_matrix
(
C
,
"cho_solve"
)
_check_scipy_linalg_matrix
(
B
,
"cho_solve"
)
dtype
=
C
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_potrs
=
_LAPACK
()
.
numba_xpotrs
(
dtype
)
def
impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
,
check_finite
=
True
):
_solve_check_input_shapes
(
C
,
B
)
_N
=
np
.
int32
(
C
.
shape
[
-
1
])
if
C
.
flags
.
f_contiguous
or
C
.
flags
.
c_contiguous
:
C_f
=
C
if
C
.
flags
.
c_contiguous
:
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
lower
=
not
lower
else
:
C_f
=
np
.
asfortranarray
(
C
)
if
overwrite_b
and
B
.
flags
.
f_contiguous
:
B_copy
=
B
else
:
B_copy
=
_copy_to_fortran_order_even_if_1d
(
B
)
B_is_1d
=
B
.
ndim
==
1
if
B_is_1d
:
B_copy
=
np
.
expand_dims
(
B_copy
,
-
1
)
NRHS
=
1
if
B_is_1d
else
int
(
B
.
shape
[
-
1
])
@numba_funcify.register
(
SolveTriangular
)
def
numba_funcify_SolveTriangular
(
op
,
node
,
**
kwargs
):
lower
=
op
.
lower
unit_diagonal
=
op
.
unit_diagonal
check_finite
=
op
.
check_finite
overwrite_b
=
op
.
overwrite_b
b_ndim
=
op
.
b_ndim
UPLO
=
val_to_int_ptr
(
ord
(
"L"
)
if
lower
else
ord
(
"U"
))
N
=
val_to_int_ptr
(
_N
)
NRHS
=
val_to_int_ptr
(
NRHS
)
LDA
=
val_to_int_ptr
(
_N
)
LDB
=
val_to_int_ptr
(
_N
)
INFO
=
val_to_int_ptr
(
0
)
dtype
=
node
.
inputs
[
0
]
.
dtype
if
dtype
in
complex_dtypes
:
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
"Solve Triangular"
)
)
numba_potrs
(
UPLO
,
N
,
NRHS
,
C_f
.
view
(
w_type
)
.
ctypes
,
LDA
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
INFO
,
@numba_njit
def
solve_triangular
(
a
,
b
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input A to solve_triangular"
)
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
b
),
np
.
isnan
(
b
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input b to solve_triangular"
)
_solve_check
(
_N
,
int_ptr_to_val
(
INFO
))
res
=
_solve_triangular
(
a
,
b
,
trans
=
0
,
# transposing is handled explicitly on the graph, so we never use this argument
lower
=
lower
,
unit_diagonal
=
unit_diagonal
,
overwrite_b
=
overwrite_b
,
b_ndim
=
b_ndim
,
)
if
B_is_1d
:
return
B_copy
[
...
,
0
]
return
B_copy
return
res
return
impl
return
solve_triangular
@numba_funcify.register
(
CholeskySolve
)
...
...
@@ -1212,7 +191,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
if
dtype
in
complex_dtypes
:
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
@numba_
basic.numba_njit
(
inline
=
"always"
)
@numba_
njit
def
cho_solve
(
c
,
b
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
c
),
np
.
isnan
(
c
))):
...
...
pytensor/tensor/slinalg.py
浏览文件 @
19023545
...
...
@@ -566,7 +566,8 @@ class Solve(SolveBase):
if
1
in
allowed_inplace_inputs
:
# Give preference to overwrite_b
new_props
[
"overwrite_b"
]
=
True
else
:
# allowed inputs == [0]
# We can't overwrite_a if we're assuming tridiagonal
elif
not
self
.
assume_a
==
"tridiagonal"
:
# allowed inputs == [0]
new_props
[
"overwrite_a"
]
=
True
return
type
(
self
)(
**
new_props
)
...
...
tests/link/numba/test_slinalg.py
浏览文件 @
19023545
...
...
@@ -12,6 +12,8 @@ from pytensor.tensor.slinalg import Cholesky, CholeskySolve, Solve, SolveTriangu
from
tests.link.numba.test_basic
import
compare_numba_and_py
,
numba_inplace_mode
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
numba
=
pytest
.
importorskip
(
"numba"
)
floatX
=
config
.
floatX
...
...
@@ -22,7 +24,7 @@ rng = np.random.default_rng(42849)
def
test_lamch
():
from
scipy.linalg
import
get_lapack_funcs
from
pytensor.link.numba.dispatch.
slinalg
import
_xlamch
from
pytensor.link.numba.dispatch.
linalg.utils
import
_xlamch
@numba.njit
()
def
xlamch
(
kind
):
...
...
@@ -45,7 +47,7 @@ def test_xlange(ord_numba, ord_scipy):
# xlange is called internally only, we don't dispatch pt.linalg.norm to it
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.
slinalg
import
_xlange
from
pytensor.link.numba.dispatch.
linalg.solve.norm
import
_xlange
@numba.njit
()
def
xlange
(
x
,
ord
):
...
...
@@ -60,7 +62,8 @@ def test_xgecon(ord_numba, ord_scipy):
# gecon is called internally only, we don't dispatch pt.linalg.norm to it
from
scipy.linalg
import
get_lapack_funcs
from
pytensor.link.numba.dispatch.slinalg
import
_xgecon
,
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.general
import
_xgecon
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
@numba.njit
()
def
gecon
(
x
,
norm
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论