Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
672a4829
提交
672a4829
authored
1月 08, 2026
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
1月 11, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Do not raise in linalg Ops
上级
b2d8bc24
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
23 个修改的文件
包含
145 行增加
和
888 行删除
+145
-888
slinalg.py
pytensor/link/jax/dispatch/slinalg.py
+4
-10
_LAPACK.py
pytensor/link/numba/dispatch/linalg/_LAPACK.py
+0
-201
cholesky.py
...nsor/link/numba/dispatch/linalg/decomposition/cholesky.py
+11
-13
lu.py
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
+6
-15
lu_factor.py
...sor/link/numba/dispatch/linalg/decomposition/lu_factor.py
+4
-3
qr.py
pytensor/link/numba/dispatch/linalg/decomposition/qr.py
+12
-34
cholesky.py
pytensor/link/numba/dispatch/linalg/solve/cholesky.py
+9
-8
general.py
pytensor/link/numba/dispatch/linalg/solve/general.py
+12
-74
lu_solve.py
pytensor/link/numba/dispatch/linalg/solve/lu_solve.py
+5
-12
norm.py
pytensor/link/numba/dispatch/linalg/solve/norm.py
+0
-55
posdef.py
pytensor/link/numba/dispatch/linalg/solve/posdef.py
+6
-67
symmetric.py
pytensor/link/numba/dispatch/linalg/solve/symmetric.py
+5
-59
triangular.py
pytensor/link/numba/dispatch/linalg/solve/triangular.py
+6
-8
tridiagonal.py
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
+16
-92
utils.py
pytensor/link/numba/dispatch/linalg/utils.py
+2
-64
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+0
-0
tensor_basic.py
pytensor/link/numba/dispatch/tensor_basic.py
+2
-2
rewriting.py
pytensor/tensor/_linalg/solve/rewriting.py
+4
-18
slinalg.py
pytensor/tensor/slinalg.py
+0
-0
linalg.py
pytensor/xtensor/linalg.py
+8
-19
test_slinalg.py
tests/link/numba/test_slinalg.py
+14
-75
test_rewriting.py
tests/tensor/linalg/test_rewriting.py
+0
-44
test_slinalg.py
tests/tensor/test_slinalg.py
+19
-15
没有找到文件。
pytensor/link/jax/dispatch/slinalg.py
浏览文件 @
672a4829
...
...
@@ -92,7 +92,6 @@ def jax_funcify_Solve(op, **kwargs):
def
jax_funcify_SolveTriangular
(
op
,
**
kwargs
):
lower
=
op
.
lower
unit_diagonal
=
op
.
unit_diagonal
check_finite
=
op
.
check_finite
def
solve_triangular
(
A
,
b
):
return
jax
.
scipy
.
linalg
.
solve_triangular
(
...
...
@@ -101,7 +100,7 @@ def jax_funcify_SolveTriangular(op, **kwargs):
lower
=
lower
,
trans
=
0
,
# this is handled by explicitly transposing A, so it will always be 0 when we get to here.
unit_diagonal
=
unit_diagonal
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
)
return
solve_triangular
...
...
@@ -132,27 +131,23 @@ def jax_funcify_PivotToPermutation(op, **kwargs):
def
jax_funcify_LU
(
op
,
**
kwargs
):
permute_l
=
op
.
permute_l
p_indices
=
op
.
p_indices
check_finite
=
op
.
check_finite
if
p_indices
:
raise
ValueError
(
"JAX does not support the p_indices argument"
)
def
lu
(
*
inputs
):
return
jax
.
scipy
.
linalg
.
lu
(
*
inputs
,
permute_l
=
permute_l
,
check_finite
=
check_finite
)
return
jax
.
scipy
.
linalg
.
lu
(
*
inputs
,
permute_l
=
permute_l
,
check_finite
=
False
)
return
lu
@jax_funcify.register
(
LUFactor
)
def
jax_funcify_LUFactor
(
op
,
**
kwargs
):
check_finite
=
op
.
check_finite
overwrite_a
=
op
.
overwrite_a
def
lu_factor
(
a
):
return
jax
.
scipy
.
linalg
.
lu_factor
(
a
,
check_finite
=
check_finit
e
,
overwrite_a
=
overwrite_a
a
,
check_finite
=
Fals
e
,
overwrite_a
=
overwrite_a
)
return
lu_factor
...
...
@@ -161,12 +156,11 @@ def jax_funcify_LUFactor(op, **kwargs):
@jax_funcify.register
(
CholeskySolve
)
def
jax_funcify_ChoSolve
(
op
,
**
kwargs
):
lower
=
op
.
lower
check_finite
=
op
.
check_finite
overwrite_b
=
op
.
overwrite_b
def
cho_solve
(
c
,
b
):
return
jax
.
scipy
.
linalg
.
cho_solve
(
(
c
,
lower
),
b
,
check_finite
=
check_finit
e
,
overwrite_b
=
overwrite_b
(
c
,
lower
),
b
,
check_finite
=
Fals
e
,
overwrite_b
=
overwrite_b
)
return
cho_solve
...
...
pytensor/link/numba/dispatch/linalg/_LAPACK.py
浏览文件 @
672a4829
...
...
@@ -263,122 +263,6 @@ class _LAPACK:
return
potrs
@classmethod
def
numba_xlange
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of
a general M-by-N matrix A.
Called by scipy.linalg.solve, but doesn't correspond to any Op in pytensor.
"""
kind
=
get_blas_kind
(
dtype
)
float_type
=
_get_nb_float_from_dtype
(
kind
,
return_pointer
=
False
)
float_pointer
=
_get_nb_float_from_dtype
(
kind
,
return_pointer
=
True
)
unique_func_name
=
f
"scipy.lapack.{kind}lange"
@numba_basic.numba_njit
def
get_lange_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"lange"
)
return
ptr
lange_function_type
=
types
.
FunctionType
(
float_type
(
nb_i32p
,
# NORM
nb_i32p
,
# M
nb_i32p
,
# N
float_pointer
,
# A
nb_i32p
,
# LDA
float_pointer
,
# WORK
)
)
@numba_basic.numba_njit
def
lange
(
NORM
,
M
,
N
,
A
,
LDA
,
WORK
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_lange_pointer
,
func_type_ref
=
lange_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
return
fn
(
NORM
,
M
,
N
,
A
,
LDA
,
WORK
)
return
lange
@classmethod
def
numba_xlamch
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Determine machine precision for floating point arithmetic.
"""
kind
=
get_blas_kind
(
dtype
)
float_type
=
_get_nb_float_from_dtype
(
kind
,
return_pointer
=
False
)
unique_func_name
=
f
"scipy.lapack.{kind}lamch"
@numba_basic.numba_njit
def
get_lamch_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"lamch"
)
return
ptr
lamch_function_type
=
types
.
FunctionType
(
float_type
(
# Return type
nb_i32p
,
# CMACH
)
)
@numba_basic.numba_njit
def
lamch
(
CMACH
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_lamch_pointer
,
func_type_ref
=
lamch_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
res
=
fn
(
CMACH
)
return
res
return
lamch
@classmethod
def
numba_xgecon
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Estimates the condition number of a matrix A, using the LU factorization computed by numba_getrf.
Called by scipy.linalg.solve when assume_a == "gen"
"""
kind
=
get_blas_kind
(
dtype
)
float_pointer
=
_get_nb_float_from_dtype
(
kind
)
unique_func_name
=
f
"scipy.lapack.{kind}gecon"
@numba_basic.numba_njit
def
get_gecon_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"gecon"
)
return
ptr
gecon_function_type
=
types
.
FunctionType
(
types
.
void
(
nb_i32p
,
# NORM
nb_i32p
,
# N
float_pointer
,
# A
nb_i32p
,
# LDA
float_pointer
,
# ANORM
float_pointer
,
# RCOND
float_pointer
,
# WORK
nb_i32p
,
# IWORK
nb_i32p
,
# INFO
)
)
@numba_basic.numba_njit
def
gecon
(
NORM
,
N
,
A
,
LDA
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_gecon_pointer
,
func_type_ref
=
gecon_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
fn
(
NORM
,
N
,
A
,
LDA
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
)
return
gecon
@classmethod
def
numba_xgetrf
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
...
...
@@ -506,91 +390,6 @@ class _LAPACK:
return
sysv
@classmethod
def
numba_xsycon
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization
computed by xSYTRF.
"""
kind
=
get_blas_kind
(
dtype
)
float_pointer
=
_get_nb_float_from_dtype
(
kind
)
unique_func_name
=
f
"scipy.lapack.{kind}sycon"
@numba_basic.numba_njit
def
get_sycon_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"sycon"
)
return
ptr
sycon_function_type
=
types
.
FunctionType
(
types
.
void
(
nb_i32p
,
# UPLO
nb_i32p
,
# N
float_pointer
,
# A
nb_i32p
,
# LDA
nb_i32p
,
# IPIV
float_pointer
,
# ANORM
float_pointer
,
# RCOND
float_pointer
,
# WORK
nb_i32p
,
# IWORK
nb_i32p
,
# INFO
)
)
@numba_basic.numba_njit
def
sycon
(
UPLO
,
N
,
A
,
LDA
,
IPIV
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_sycon_pointer
,
func_type_ref
=
sycon_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
fn
(
UPLO
,
N
,
A
,
LDA
,
IPIV
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
)
return
sycon
@classmethod
def
numba_xpocon
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization
computed by potrf.
Called by scipy.linalg.solve when assume_a == "pos"
"""
kind
=
get_blas_kind
(
dtype
)
float_pointer
=
_get_nb_float_from_dtype
(
kind
)
unique_func_name
=
f
"scipy.lapack.{kind}pocon"
@numba_basic.numba_njit
def
get_pocon_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"pocon"
)
return
ptr
pocon_function_type
=
types
.
FunctionType
(
types
.
void
(
nb_i32p
,
# UPLO
nb_i32p
,
# N
float_pointer
,
# A
nb_i32p
,
# LDA
float_pointer
,
# ANORM
float_pointer
,
# RCOND
float_pointer
,
# WORK
nb_i32p
,
# IWORK
nb_i32p
,
# INFO
)
)
@numba_basic.numba_njit
def
pocon
(
UPLO
,
N
,
A
,
LDA
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_pocon_pointer
,
func_type_ref
=
pocon_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
fn
(
UPLO
,
N
,
A
,
LDA
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
)
return
pocon
@classmethod
def
numba_xposv
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
...
...
pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py
浏览文件 @
672a4829
...
...
@@ -12,24 +12,19 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_linalg_matrix
def
_cholesky
(
a
,
lower
=
False
,
overwrite_a
=
False
,
check_finite
=
True
):
return
(
linalg
.
cholesky
(
a
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
),
0
,
)
def
_cholesky
(
a
,
lower
=
False
,
overwrite_a
=
False
):
return
linalg
.
cholesky
(
a
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
check_finite
=
False
)
@overload
(
_cholesky
)
def
cholesky_impl
(
A
,
lower
=
0
,
overwrite_a
=
False
,
check_finite
=
True
):
def
cholesky_impl
(
A
,
lower
=
0
,
overwrite_a
=
False
):
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"cholesky"
)
dtype
=
A
.
dtype
numba_potrf
=
_LAPACK
()
.
numba_xpotrf
(
dtype
)
def
impl
(
A
,
lower
=
False
,
overwrite_a
=
False
,
check_finite
=
True
):
def
impl
(
A
,
lower
=
False
,
overwrite_a
=
False
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
if
A
.
shape
[
-
2
]
!=
_N
:
raise
linalg
.
LinAlgError
(
"Last 2 dimensions of A must be square"
)
...
...
@@ -58,6 +53,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
INFO
,
)
if
int_ptr_to_val
(
INFO
)
!=
0
:
A_copy
=
np
.
full_like
(
A_copy
,
np
.
nan
)
return
A_copy
if
lower
:
for
j
in
range
(
1
,
_N
):
for
i
in
range
(
j
):
...
...
@@ -67,10 +66,9 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
for
i
in
range
(
j
+
1
,
_N
):
A_copy
[
i
,
j
]
=
0.0
info_int
=
int_ptr_to_val
(
INFO
)
if
transposed
:
return
A_copy
.
T
,
info_int
return
A_copy
,
info_int
return
A_copy
.
T
else
:
return
A_copy
return
impl
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
浏览文件 @
672a4829
...
...
@@ -39,7 +39,6 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
def
_lu_1
(
a
:
np
.
ndarray
,
permute_l
:
Literal
[
True
],
check_finite
:
bool
,
p_indices
:
Literal
[
False
],
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
...
...
@@ -52,7 +51,7 @@ def _lu_1(
return
linalg
.
lu
(
# type: ignore[no-any-return]
a
,
permute_l
=
permute_l
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
)
...
...
@@ -61,7 +60,6 @@ def _lu_1(
def
_lu_2
(
a
:
np
.
ndarray
,
permute_l
:
Literal
[
False
],
check_finite
:
bool
,
p_indices
:
Literal
[
True
],
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
...
...
@@ -74,7 +72,7 @@ def _lu_2(
return
linalg
.
lu
(
# type: ignore[no-any-return]
a
,
permute_l
=
permute_l
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
)
...
...
@@ -83,7 +81,6 @@ def _lu_2(
def
_lu_3
(
a
:
np
.
ndarray
,
permute_l
:
Literal
[
False
],
check_finite
:
bool
,
p_indices
:
Literal
[
False
],
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
...
...
@@ -96,7 +93,7 @@ def _lu_3(
return
linalg
.
lu
(
# type: ignore[no-any-return]
a
,
permute_l
=
permute_l
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
)
...
...
@@ -106,11 +103,10 @@ def _lu_3(
def
lu_impl_1
(
a
:
np
.
ndarray
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
)
->
Callable
[
[
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]
[
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]
]:
"""
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
...
...
@@ -123,7 +119,6 @@ def lu_impl_1(
def
impl
(
a
:
np
.
ndarray
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
...
...
@@ -137,10 +132,9 @@ def lu_impl_1(
def
lu_impl_2
(
a
:
np
.
ndarray
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
]]:
)
->
Callable
[[
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
]]:
"""
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is False and p_indices is
True. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L.
...
...
@@ -153,7 +147,6 @@ def lu_impl_2(
def
impl
(
a
:
np
.
ndarray
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
...
...
@@ -169,11 +162,10 @@ def lu_impl_2(
def
lu_impl_3
(
a
:
np
.
ndarray
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
)
->
Callable
[
[
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]
[
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]
]:
"""
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
...
...
@@ -186,7 +178,6 @@ def lu_impl_3(
def
impl
(
a
:
np
.
ndarray
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
...
...
pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py
浏览文件 @
672a4829
...
...
@@ -79,11 +79,12 @@ def lu_factor_impl(
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"lu_factor"
)
def
impl
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
A_copy
,
IPIV
,
INFO
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
A_copy
,
IPIV
,
info
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
IPIV
-=
1
# LAPACK uses 1-based indexing, convert to 0-based
if
INFO
!=
0
:
raise
np
.
linalg
.
LinAlgError
(
"LU decomposition failed"
)
if
info
!=
0
:
A_copy
=
np
.
full_like
(
A_copy
,
np
.
nan
)
return
A_copy
,
IPIV
return
impl
pytensor/link/numba/dispatch/linalg/decomposition/qr.py
浏览文件 @
672a4829
...
...
@@ -228,7 +228,6 @@ def _qr_full_pivot(
mode
:
Literal
[
"full"
,
"economic"
]
=
"full"
,
pivoting
:
Literal
[
True
]
=
True
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
):
"""
...
...
@@ -243,7 +242,7 @@ def _qr_full_pivot(
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
)
...
...
@@ -253,7 +252,6 @@ def _qr_full_no_pivot(
mode
:
Literal
[
"full"
,
"economic"
]
=
"full"
,
pivoting
:
Literal
[
False
]
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
):
"""
...
...
@@ -267,7 +265,7 @@ def _qr_full_no_pivot(
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
)
...
...
@@ -277,7 +275,6 @@ def _qr_r_pivot(
mode
:
Literal
[
"r"
,
"raw"
]
=
"r"
,
pivoting
:
Literal
[
True
]
=
True
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
):
"""
...
...
@@ -291,7 +288,7 @@ def _qr_r_pivot(
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
)
...
...
@@ -301,7 +298,6 @@ def _qr_r_no_pivot(
mode
:
Literal
[
"r"
,
"raw"
]
=
"r"
,
pivoting
:
Literal
[
False
]
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
):
"""
...
...
@@ -315,7 +311,7 @@ def _qr_r_no_pivot(
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
)
...
...
@@ -325,7 +321,6 @@ def _qr_raw_no_pivot(
mode
:
Literal
[
"raw"
]
=
"raw"
,
pivoting
:
Literal
[
False
]
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
):
"""
...
...
@@ -339,7 +334,7 @@ def _qr_raw_no_pivot(
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
)
...
...
@@ -351,7 +346,6 @@ def _qr_raw_pivot(
mode
:
Literal
[
"raw"
]
=
"raw"
,
pivoting
:
Literal
[
True
]
=
True
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
):
"""
...
...
@@ -365,7 +359,7 @@ def _qr_raw_pivot(
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
)
...
...
@@ -373,9 +367,7 @@ def _qr_raw_pivot(
@overload
(
_qr_full_pivot
)
def
qr_full_pivot_impl
(
x
,
mode
=
"full"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
def
qr_full_pivot_impl
(
x
,
mode
=
"full"
,
pivoting
=
True
,
overwrite_a
=
False
,
lwork
=
None
):
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
dtype
=
x
.
dtype
...
...
@@ -395,7 +387,6 @@ def qr_full_pivot_impl(
mode
=
"full"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
):
M
=
np
.
int32
(
x
.
shape
[
0
])
...
...
@@ -529,7 +520,7 @@ def qr_full_pivot_impl(
@overload
(
_qr_full_no_pivot
)
def
qr_full_no_pivot_impl
(
x
,
mode
=
"full"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
x
,
mode
=
"full"
,
pivoting
=
False
,
overwrite_a
=
False
,
lwork
=
None
):
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
...
...
@@ -546,7 +537,6 @@ def qr_full_no_pivot_impl(
mode
=
"full"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
):
M
=
np
.
int32
(
x
.
shape
[
0
])
...
...
@@ -645,9 +635,7 @@ def qr_full_no_pivot_impl(
@overload
(
_qr_r_pivot
)
def
qr_r_pivot_impl
(
x
,
mode
=
"r"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
def
qr_r_pivot_impl
(
x
,
mode
=
"r"
,
pivoting
=
True
,
overwrite_a
=
False
,
lwork
=
None
):
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
dtype
=
x
.
dtype
...
...
@@ -658,7 +646,6 @@ def qr_r_pivot_impl(
mode
=
"r"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
):
M
=
np
.
int32
(
x
.
shape
[
0
])
...
...
@@ -720,9 +707,7 @@ def qr_r_pivot_impl(
@overload
(
_qr_r_no_pivot
)
def
qr_r_no_pivot_impl
(
x
,
mode
=
"r"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
def
qr_r_no_pivot_impl
(
x
,
mode
=
"r"
,
pivoting
=
False
,
overwrite_a
=
False
,
lwork
=
None
):
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
dtype
=
x
.
dtype
...
...
@@ -733,7 +718,6 @@ def qr_r_no_pivot_impl(
mode
=
"r"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
):
M
=
np
.
int32
(
x
.
shape
[
0
])
...
...
@@ -792,9 +776,7 @@ def qr_r_no_pivot_impl(
@overload
(
_qr_raw_no_pivot
)
def
qr_raw_no_pivot_impl
(
x
,
mode
=
"raw"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
def
qr_raw_no_pivot_impl
(
x
,
mode
=
"raw"
,
pivoting
=
False
,
overwrite_a
=
False
,
lwork
=
None
):
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
dtype
=
x
.
dtype
...
...
@@ -805,7 +787,6 @@ def qr_raw_no_pivot_impl(
mode
=
"raw"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
):
M
=
np
.
int32
(
x
.
shape
[
0
])
...
...
@@ -863,9 +844,7 @@ def qr_raw_no_pivot_impl(
@overload
(
_qr_raw_pivot
)
def
qr_raw_pivot_impl
(
x
,
mode
=
"raw"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
def
qr_raw_pivot_impl
(
x
,
mode
=
"raw"
,
pivoting
=
True
,
overwrite_a
=
False
,
lwork
=
None
):
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
...
...
@@ -880,7 +859,6 @@ def qr_raw_pivot_impl(
mode
=
"raw"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
):
M
=
np
.
int32
(
x
.
shape
[
0
])
...
...
pytensor/link/numba/dispatch/linalg/solve/cholesky.py
浏览文件 @
672a4829
...
...
@@ -14,23 +14,23 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
def
_cho_solve
(
C
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
):
def
_cho_solve
(
C
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_b
:
bool
):
"""
Solve a positive-definite linear system using the Cholesky decomposition.
"""
return
linalg
.
cho_solve
(
(
C
,
lower
),
b
=
B
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
(
C
,
lower
),
b
=
B
,
overwrite_b
=
overwrite_b
,
check_finite
=
False
,
)
@overload
(
_cho_solve
)
def
cho_solve_impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
,
check_finite
=
True
):
def
cho_solve_impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
):
ensure_lapack
()
_check_linalg_matrix
(
C
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"cho_solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"cho_solve"
)
...
...
@@ -38,7 +38,7 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
dtype
=
C
.
dtype
numba_potrs
=
_LAPACK
()
.
numba_xpotrs
(
dtype
)
def
impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
,
check_finite
=
True
):
def
impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
):
_solve_check_input_shapes
(
C
,
B
)
_N
=
np
.
int32
(
C
.
shape
[
-
1
])
...
...
@@ -79,7 +79,8 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
INFO
,
)
_solve_check
(
_N
,
int_ptr_to_val
(
INFO
))
if
int_ptr_to_val
(
INFO
)
!=
0
:
B_copy
=
np
.
full_like
(
B_copy
,
np
.
nan
)
if
B_is_1d
:
return
B_copy
[
...
,
0
]
...
...
pytensor/link/numba/dispatch/linalg/solve/general.py
浏览文件 @
672a4829
...
...
@@ -3,82 +3,24 @@ from collections.abc import Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
numba.np.linalg
import
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.decomposition.lu_factor
import
_getrf
from
pytensor.link.numba.dispatch.linalg.solve.lu_solve
import
_getrs
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_dtypes_match
,
_check_linalg_matrix
,
_solve_check
,
)
def
_xgecon
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify
graphs.
"""
return
# type: ignore
@overload
(
_xgecon
)
def
xgecon_impl
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
Callable
[[
np
.
ndarray
,
float
,
str
],
tuple
[
np
.
ndarray
,
int
]]:
"""
Compute the condition number of a matrix A.
"""
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"gecon"
)
dtype
=
A
.
dtype
numba_gecon
=
_LAPACK
()
.
numba_xgecon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
A_NORM
=
np
.
array
(
A_norm
,
dtype
=
dtype
)
NORM
=
val_to_int_ptr
(
ord
(
norm
))
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
4
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
1
)
numba_gecon
(
NORM
,
N
,
A_copy
.
ctypes
,
LDA
,
A_NORM
.
ctypes
,
RCOND
.
ctypes
,
WORK
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_gen
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
):
"""Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects
...
...
@@ -89,7 +31,7 @@ def _solve_gen(
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
assume_a
=
"gen"
,
transposed
=
transposed
,
)
...
...
@@ -102,9 +44,8 @@ def solve_gen_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
...
...
@@ -116,7 +57,6 @@ def solve_gen_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
...
...
@@ -127,20 +67,18 @@ def solve_gen_impl(
A
=
A
.
T
transposed
=
not
transposed
order
=
"I"
if
transposed
else
"1"
norm
=
_xlange
(
A
,
order
=
order
)
N
=
A
.
shape
[
1
]
LU
,
IPIV
,
INFO
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
_solve_check
(
N
,
INFO
)
LU
,
IPIV
,
INFO1
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
X
,
INFO
=
_getrs
(
LU
=
LU
,
B
=
B
,
IPIV
=
IPIV
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
X
,
INFO2
=
_getrs
(
LU
=
LU
,
B
=
B
,
IPIV
=
IPIV
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
,
)
_solve_check
(
N
,
INFO
)
RCOND
,
INFO
=
_xgecon
(
LU
,
norm
,
"1"
)
_solve_check
(
N
,
INFO
,
True
,
RCOND
)
if
INFO1
!=
0
or
INFO2
!=
0
:
X
=
np
.
full_like
(
X
,
np
.
nan
)
return
X
...
...
pytensor/link/numba/dispatch/linalg/solve/lu_solve.py
浏览文件 @
672a4829
...
...
@@ -17,7 +17,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
)
...
...
@@ -107,14 +106,11 @@ def _lu_solve(
b
:
np
.
ndarray
,
trans
:
_Trans
,
overwrite_b
:
bool
,
check_finite
:
bool
,
):
"""
Thin wrapper around scipy.lu_solve, used to avoid side effects from numba overloads on users who import Pytensor.
"""
return
linalg
.
lu_solve
(
lu_and_piv
,
b
,
trans
=
trans
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
)
return
linalg
.
lu_solve
(
lu_and_piv
,
b
,
trans
=
trans
,
overwrite_b
=
overwrite_b
)
@overload
(
_lu_solve
)
...
...
@@ -123,8 +119,7 @@ def lu_solve_impl(
b
:
np
.
ndarray
,
trans
:
_Trans
,
overwrite_b
:
bool
,
check_finite
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
_Trans
,
bool
,
bool
],
np
.
ndarray
]:
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
_Trans
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
lu
,
_piv
=
lu_and_piv
_check_linalg_matrix
(
lu
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"lu_solve"
)
...
...
@@ -137,13 +132,11 @@ def lu_solve_impl(
b
:
np
.
ndarray
,
trans
:
_Trans
,
overwrite_b
:
bool
,
check_finite
:
bool
,
)
->
np
.
ndarray
:
n
=
np
.
int32
(
lu
.
shape
[
0
]
)
X
,
info
=
_getrs
(
LU
=
lu
,
B
=
b
,
IPIV
=
piv
,
trans
=
trans
,
overwrite_b
=
overwrite_b
)
X
,
INFO
=
_getrs
(
LU
=
lu
,
B
=
b
,
IPIV
=
piv
,
trans
=
trans
,
overwrite_b
=
overwrite_b
)
_solve_check
(
n
,
INFO
)
if
info
!=
0
:
X
=
np
.
full_like
(
X
,
np
.
nan
)
return
X
...
...
pytensor/link/numba/dispatch/linalg/solve/norm.py
deleted
100644 → 0
浏览文件 @
b2d8bc24
from
collections.abc
import
Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_linalg_matrix
def
_xlange
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
)
->
float
:
"""
Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode.
"""
return
# type: ignore
@overload
(
_xlange
)
def
xlange_impl
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
)
->
Callable
[[
np
.
ndarray
,
str
],
float
]:
"""
xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity norm, or the element of
largest absolute value of a matrix A.
"""
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"norm"
)
dtype
=
A
.
dtype
numba_lange
=
_LAPACK
()
.
numba_xlange
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
):
_M
,
_N
=
np
.
int32
(
A
.
shape
[
-
2
:])
# type: ignore
A_copy
=
_copy_to_fortran_order
(
A
)
M
=
val_to_int_ptr
(
_M
)
# type: ignore
N
=
val_to_int_ptr
(
_N
)
# type: ignore
LDA
=
val_to_int_ptr
(
_M
)
# type: ignore
NORM
=
(
val_to_int_ptr
(
ord
(
order
))
if
order
is
not
None
else
val_to_int_ptr
(
ord
(
"1"
))
)
WORK
=
np
.
empty
(
_M
,
dtype
=
dtype
)
# type: ignore
result
=
numba_lange
(
NORM
,
M
,
N
,
A_copy
.
ctypes
,
LDA
,
WORK
.
ctypes
)
return
result
return
impl
pytensor/link/numba/dispatch/linalg/solve/posdef.py
浏览文件 @
672a4829
...
...
@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
...
...
@@ -27,8 +25,6 @@ def _posv(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
"""
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
...
...
@@ -43,10 +39,8 @@ def posv_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[
[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
],
]:
ensure_lapack
()
...
...
@@ -62,8 +56,6 @@ def posv_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
_solve_check_input_shapes
(
A
,
B
)
...
...
@@ -115,60 +107,12 @@ def posv_impl(
return
impl
def
_pocon
(
A
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by
linalg.solve when assume_a = "pos".
"""
return
# type: ignore
@overload
(
_pocon
)
def
pocon_impl
(
A
:
np
.
ndarray
,
anorm
:
float
)
->
Callable
[[
np
.
ndarray
,
float
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"pocon"
)
dtype
=
A
.
dtype
numba_pocon
=
_LAPACK
()
.
numba_xpocon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
anorm
:
float
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
UPLO
=
val_to_int_ptr
(
ord
(
"L"
))
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
ANORM
=
np
.
array
(
anorm
,
dtype
=
dtype
)
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
3
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
0
)
numba_pocon
(
UPLO
,
N
,
A_copy
.
ctypes
,
LDA
,
ANORM
.
ctypes
,
RCOND
.
ctypes
,
WORK
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_psd
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
):
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
...
...
@@ -179,7 +123,7 @@ def _solve_psd(
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
transposed
=
transposed
,
assume_a
=
"pos"
,
)
...
...
@@ -192,9 +136,8 @@ def solve_psd_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
...
...
@@ -206,18 +149,14 @@ def solve_psd_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
_solve_check_input_shapes
(
A
,
B
)
C
,
x
,
info
=
_posv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
,
check_finite
,
transposed
)
_solve_check
(
A
.
shape
[
-
1
],
info
)
_C
,
x
,
info
=
_posv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
)
rcond
,
info
=
_pocon
(
C
,
_xlange
(
A
))
_solve_check
(
A
.
shape
[
-
1
],
info
=
info
,
lamch
=
True
,
rcond
=
rcond
)
if
info
!=
0
:
x
=
np
.
full_like
(
x
,
np
.
nan
)
return
x
...
...
pytensor/link/numba/dispatch/linalg/solve/symmetric.py
浏览文件 @
672a4829
...
...
@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
...
...
@@ -121,61 +119,12 @@ def sysv_impl(
return
impl
def
_sycon
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in
python mode.
"""
return
# type: ignore
@overload
(
_sycon
)
def
sycon_impl
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
float
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"sycon"
)
dtype
=
A
.
dtype
numba_sycon
=
_LAPACK
()
.
numba_xsycon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
UPLO
=
val_to_int_ptr
(
ord
(
"U"
))
ANORM
=
np
.
array
(
anorm
,
dtype
=
dtype
)
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
2
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
0
)
numba_sycon
(
UPLO
,
N
,
A_copy
.
ctypes
,
LDA
,
ipiv
.
ctypes
,
ANORM
.
ctypes
,
RCOND
.
ctypes
,
WORK
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_symmetric
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
):
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
...
...
@@ -186,7 +135,7 @@ def _solve_symmetric(
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
assume_a
=
"sym"
,
transposed
=
transposed
,
)
...
...
@@ -199,9 +148,8 @@ def solve_symmetric_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
...
...
@@ -213,16 +161,14 @@ def solve_symmetric_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
_solve_check_input_shapes
(
A
,
B
)
lu
,
x
,
ipiv
,
info
=
_sysv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
)
_solve_check
(
A
.
shape
[
-
1
],
info
)
_lu
,
x
,
_ipiv
,
info
=
_sysv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
)
rcond
,
info
=
_sycon
(
lu
,
ipiv
,
_xlange
(
A
,
order
=
"I"
))
_solve_check
(
A
.
shape
[
-
1
],
info
,
True
,
rcond
)
if
info
!=
0
:
x
=
np
.
full_like
(
x
,
np
.
nan
)
return
x
...
...
pytensor/link/numba/dispatch/linalg/solve/triangular.py
浏览文件 @
672a4829
...
...
@@ -15,13 +15,12 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
)
def
_solve_triangular
(
A
,
B
,
trans
=
0
,
lower
=
False
,
unit_diagonal
=
False
,
b_ndim
=
1
,
overwrite_b
=
False
A
,
B
,
trans
=
0
,
lower
=
False
,
unit_diagonal
=
False
,
overwrite_b
=
False
):
"""
Thin wrapper around scipy.linalg.solve_triangular.
...
...
@@ -39,11 +38,12 @@ def _solve_triangular(
lower
=
lower
,
unit_diagonal
=
unit_diagonal
,
overwrite_b
=
overwrite_b
,
check_finite
=
False
,
)
@overload
(
_solve_triangular
)
def
solve_triangular_impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
b_ndim
,
overwrite_b
):
def
solve_triangular_impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
overwrite_b
):
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve_triangular"
)
...
...
@@ -57,12 +57,10 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
"This function is not expected to work with complex numbers yet"
)
def
impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
b_ndim
,
overwrite_b
):
def
impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
overwrite_b
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
_solve_check_input_shapes
(
A
,
B
)
# Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type
# could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
B_is_1d
=
B
.
ndim
==
1
if
A
.
flags
.
f_contiguous
or
(
A
.
flags
.
c_contiguous
and
trans
in
(
0
,
1
)):
...
...
@@ -106,8 +104,8 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
LDB
,
INFO
,
)
_solve_check
(
int_ptr_to_val
(
LDA
),
int_ptr_to_val
(
INFO
)
)
if
int_ptr_to_val
(
INFO
)
!=
0
:
B_copy
=
np
.
full_like
(
B_copy
,
np
.
nan
)
if
B_is_1d
:
return
B_copy
[
...
,
0
]
...
...
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
浏览文件 @
672a4829
...
...
@@ -23,7 +23,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
)
from
pytensor.tensor._linalg.solve.tridiagonal
import
(
...
...
@@ -202,83 +201,12 @@ def gttrs_impl(
return
impl
def
_gtcon
(
dl
:
ndarray
,
d
:
ndarray
,
du
:
ndarray
,
du2
:
ndarray
,
ipiv
:
ndarray
,
anorm
:
float
,
norm
:
str
,
)
->
tuple
[
ndarray
,
int
]:
"""Placeholder for computing the condition number of a tridiagonal system."""
return
# type: ignore
@overload
(
_gtcon
)
def
gtcon_impl
(
dl
:
ndarray
,
d
:
ndarray
,
du
:
ndarray
,
du2
:
ndarray
,
ipiv
:
ndarray
,
anorm
:
float
,
norm
:
str
,
)
->
Callable
[
[
ndarray
,
ndarray
,
ndarray
,
ndarray
,
ndarray
,
float
,
str
],
tuple
[
ndarray
,
int
]
]:
ensure_lapack
()
_check_linalg_matrix
(
dl
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_linalg_matrix
(
d
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_linalg_matrix
(
du
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_linalg_matrix
(
du2
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_dtypes_match
((
dl
,
d
,
du
,
du2
),
func_name
=
"gtcon"
)
_check_linalg_matrix
(
ipiv
,
ndim
=
1
,
dtype
=
int32
,
func_name
=
"gtcon"
)
dtype
=
d
.
dtype
numba_gtcon
=
_LAPACK
()
.
numba_xgtcon
(
dtype
)
def
impl
(
dl
:
ndarray
,
d
:
ndarray
,
du
:
ndarray
,
du2
:
ndarray
,
ipiv
:
ndarray
,
anorm
:
float
,
norm
:
str
,
)
->
tuple
[
ndarray
,
int
]:
n
=
np
.
int32
(
d
.
shape
[
-
1
])
rcond
=
np
.
empty
(
1
,
dtype
=
dtype
)
work
=
np
.
empty
(
2
*
n
,
dtype
=
dtype
)
iwork
=
np
.
empty
(
n
,
dtype
=
np
.
int32
)
info
=
val_to_int_ptr
(
0
)
numba_gtcon
(
val_to_int_ptr
(
ord
(
norm
)),
val_to_int_ptr
(
n
),
dl
.
ctypes
,
d
.
ctypes
,
du
.
ctypes
,
du2
.
ctypes
,
ipiv
.
ctypes
,
np
.
array
(
anorm
,
dtype
=
dtype
)
.
ctypes
,
rcond
.
ctypes
,
work
.
ctypes
,
iwork
.
ctypes
,
info
,
)
return
rcond
,
int_ptr_to_val
(
info
)
return
impl
def
_solve_tridiagonal
(
a
:
ndarray
,
b
:
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
):
"""
...
...
@@ -290,7 +218,7 @@ def _solve_tridiagonal(
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
transposed
=
transposed
,
assume_a
=
"tridiagonal"
,
)
...
...
@@ -303,9 +231,8 @@ def _tridiagonal_solve_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
ndarray
,
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
ndarray
]:
)
->
Callable
[[
ndarray
,
ndarray
,
bool
,
bool
,
bool
,
bool
],
ndarray
]:
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
...
...
@@ -317,31 +244,24 @@ def _tridiagonal_solve_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
ndarray
:
n
=
np
.
int32
(
A
.
shape
[
-
1
])
_solve_check_input_shapes
(
A
,
B
)
norm
=
"1"
if
transposed
:
A
=
A
.
T
dl
,
d
,
du
=
np
.
diag
(
A
,
-
1
),
np
.
diag
(
A
,
0
),
np
.
diag
(
A
,
1
)
anorm
=
tridiagonal_norm
(
du
,
d
,
dl
)
dl
,
d
,
du
,
du2
,
IPIV
,
INFO
=
_gttrf
(
dl
,
d
,
du
,
du2
,
ipiv
,
info1
=
_gttrf
(
dl
,
d
,
du
,
overwrite_dl
=
True
,
overwrite_d
=
True
,
overwrite_du
=
True
)
_solve_check
(
n
,
INFO
)
X
,
INFO
=
_gttrs
(
dl
,
d
,
du
,
du2
,
IPIV
,
B
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
X
,
info2
=
_gttrs
(
dl
,
d
,
du
,
du2
,
ipiv
,
B
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
)
_solve_check
(
n
,
INFO
)
RCOND
,
INFO
=
_gtcon
(
dl
,
d
,
du
,
du2
,
IPIV
,
anorm
,
norm
)
_solve_check
(
n
,
INFO
,
True
,
RCOND
)
if
info1
!=
0
or
info2
!=
0
:
X
=
np
.
full_like
(
X
,
np
.
nan
)
return
X
...
...
@@ -391,8 +311,8 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
)
return
dl
,
d
,
du
,
du2
,
ipiv
cache_
key
=
1
return
lu_factor_tridiagonal
,
cache_
key
cache_
version
=
2
return
lu_factor_tridiagonal
,
cache_
version
@register_funcify_default_op_cache_key
(
SolveLUFactorTridiagonal
)
...
...
@@ -434,7 +354,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
ipiv
=
ipiv
.
astype
(
np
.
int32
)
if
cast_b
:
b
=
b
.
astype
(
out_dtype
)
x
,
_
=
_gttrs
(
x
,
info
=
_gttrs
(
dl
,
d
,
du
,
...
...
@@ -444,7 +364,11 @@ def numba_funcify_SolveLUFactorTridiagonal(
overwrite_b
=
overwrite_b
,
trans
=
transposed
,
)
if
info
!=
0
:
x
=
np
.
full_like
(
x
,
np
.
nan
)
return
x
cache_
key
=
1
return
solve_lu_factor_tridiagonal
,
cache_
key
cache_
version
=
2
return
solve_lu_factor_tridiagonal
,
cache_
version
pytensor/link/numba/dispatch/linalg/utils.py
浏览文件 @
672a4829
from
collections.abc
import
Callable
,
Sequence
from
collections.abc
import
Sequence
import
numba
from
numba.core
import
types
from
numba.core.extending
import
overload
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
numpy.linalg
import
LinAlgError
from
numba.np.linalg
import
_copy_to_fortran_order
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
val_to_int_ptr
,
)
@numba_basic.numba_njit
(
inline
=
"always"
)
...
...
@@ -61,58 +54,3 @@ def _check_dtypes_match(arrays: Sequence, func_name="cho_solve"):
if
first_dtype
!=
other_dtype
:
msg
=
f
"{func_name} only supported for matching dtypes, got {dtypes}"
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
@numba_basic.numba_njit
(
inline
=
"always"
)
def
_solve_check
(
n
,
info
,
lamch
=
False
,
rcond
=
None
):
"""
Check arguments during the different steps of the solution phase
Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38
"""
if
info
<
0
:
# TODO: figure out how to do an fstring here
msg
=
"LAPACK reported an illegal value in input"
raise
ValueError
(
msg
)
elif
0
<
info
:
raise
LinAlgError
(
"Matrix is singular."
)
if
lamch
:
E
=
_xlamch
(
"E"
)
if
rcond
<
E
:
# TODO: This should be a warning, but we can't raise warnings in numba mode
print
(
# noqa: T201
"Ill-conditioned matrix, rcond="
,
rcond
,
", result may not be accurate."
)
def
_xlamch
(
kind
:
str
=
"E"
):
"""
Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs.
"""
pass
@overload
(
_xlamch
)
def
xlamch_impl
(
kind
:
str
=
"E"
)
->
Callable
[[
str
],
float
]:
"""
Compute the machine precision for a given floating point type.
"""
from
pytensor
import
config
ensure_lapack
()
w_type
=
_get_underlying_float
(
config
.
floatX
)
if
w_type
==
"float32"
:
dtype
=
types
.
float32
elif
w_type
==
"float64"
:
dtype
=
types
.
float64
else
:
raise
NotImplementedError
(
"Unsupported dtype"
)
numba_lamch
=
_LAPACK
()
.
numba_xlamch
(
dtype
)
def
impl
(
kind
:
str
=
"E"
)
->
float
:
KIND
=
val_to_int_ptr
(
ord
(
kind
))
return
numba_lamch
(
KIND
)
# type: ignore
return
impl
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
672a4829
差异被折叠。
点击展开。
pytensor/link/numba/dispatch/tensor_basic.py
浏览文件 @
672a4829
...
...
@@ -181,8 +181,8 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
out
[
...
,
i
]
=
new_entry
return
out
cache_
key
=
1
return
extract_diag
,
cache_
key
cache_
version
=
1
return
extract_diag
,
cache_
version
@register_funcify_default_op_cache_key
(
Eye
)
...
...
pytensor/tensor/_linalg/solve/rewriting.py
浏览文件 @
672a4829
...
...
@@ -20,14 +20,13 @@ from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_so
from
pytensor.tensor.variable
import
TensorVariable
def
decompose_A
(
A
,
assume_a
,
check_finite
,
lower
):
def
decompose_A
(
A
,
assume_a
,
lower
):
if
assume_a
==
"gen"
:
return
lu_factor
(
A
,
check_finite
=
check_finite
)
return
lu_factor
(
A
)
elif
assume_a
==
"tridiagonal"
:
# We didn't implement check_finite for tridiagonal LU factorization
return
tridiagonal_lu_factor
(
A
)
elif
assume_a
==
"pos"
:
return
cholesky
(
A
,
lower
=
lower
,
check_finite
=
check_finite
)
return
cholesky
(
A
,
lower
=
lower
)
else
:
raise
NotImplementedError
...
...
@@ -36,7 +35,6 @@ def solve_decomposed_system(
A_decomp
,
b
,
transposed
=
False
,
lower
=
False
,
*
,
core_solve_op
:
Solve
):
b_ndim
=
core_solve_op
.
b_ndim
check_finite
=
core_solve_op
.
check_finite
assume_a
=
core_solve_op
.
assume_a
if
assume_a
==
"gen"
:
...
...
@@ -45,10 +43,8 @@ def solve_decomposed_system(
b
,
b_ndim
=
b_ndim
,
trans
=
transposed
,
check_finite
=
check_finite
,
)
elif
assume_a
==
"tridiagonal"
:
# We didn't implement check_finite for tridiagonal LU solve
return
tridiagonal_lu_solve
(
A_decomp
,
b
,
...
...
@@ -61,7 +57,6 @@ def solve_decomposed_system(
(
A_decomp
,
lower
),
b
,
b_ndim
=
b_ndim
,
check_finite
=
check_finite
,
)
else
:
raise
NotImplementedError
...
...
@@ -141,17 +136,8 @@ def _split_decomp_and_solve_steps(
):
return
None
# If any Op had check_finite=True, we also do it for the LU decomposition
check_finite_decomp
=
False
for
client
,
_
in
A_solve_clients_and_transpose
:
if
client
.
op
.
core_op
.
check_finite
:
check_finite_decomp
=
True
break
lower
=
node
.
op
.
core_op
.
lower
A_decomp
=
decompose_A
(
A
,
assume_a
=
assume_a
,
check_finite
=
check_finite_decomp
,
lower
=
lower
)
A_decomp
=
decompose_A
(
A
,
assume_a
=
assume_a
,
lower
=
lower
)
replacements
=
{}
for
client
,
transposed
in
A_solve_clients_and_transpose
:
...
...
pytensor/tensor/slinalg.py
浏览文件 @
672a4829
差异被折叠。
点击展开。
pytensor/xtensor/linalg.py
浏览文件 @
672a4829
from
collections.abc
import
Sequence
from
typing
import
Literal
from
pytensor.tensor.slinalg
import
Cholesky
,
Solve
from
pytensor.xtensor.type
import
as_xtensor
...
...
@@ -10,8 +9,7 @@ def cholesky(
x
,
lower
:
bool
=
True
,
*
,
check_finite
:
bool
=
False
,
on_error
:
Literal
[
"raise"
,
"nan"
]
=
"raise"
,
check_finite
:
bool
=
True
,
dims
:
Sequence
[
str
],
):
"""Compute the Cholesky decomposition of an XTensorVariable.
...
...
@@ -22,22 +20,15 @@ def cholesky(
The input variable to decompose.
lower : bool, optional
Whether to return the lower triangular matrix. Default is True.
check_finite : bool, optional
Whether to check that the input is finite. Default is False.
on_error : {'raise', 'nan'}, optional
What to do if the input is not positive definite. If 'raise', an error is raised.
If 'nan', the output will contain NaNs. Default is 'raise'.
check_finite : bool
Unused by PyTensor. PyTensor will return nan if the operation fails.
dims : Sequence[str]
The two core dimensions of the input variable, over which the Cholesky decomposition is computed.
"""
if
len
(
dims
)
!=
2
:
raise
ValueError
(
f
"Cholesky needs two dims, got {len(dims)}"
)
core_op
=
Cholesky
(
lower
=
lower
,
check_finite
=
check_finite
,
on_error
=
on_error
,
)
core_op
=
Cholesky
(
lower
=
lower
)
core_dims
=
(
((
dims
[
0
],
dims
[
1
]),),
((
dims
[
0
],
dims
[
1
]),),
...
...
@@ -52,7 +43,7 @@ def solve(
dims
:
Sequence
[
str
],
assume_a
=
"gen"
,
lower
:
bool
=
False
,
check_finite
:
bool
=
Fals
e
,
check_finite
:
bool
=
Tru
e
,
):
"""Solve a system of linear equations using XTensorVariables.
...
...
@@ -75,8 +66,8 @@ def solve(
Long form options can also be used ["general", "symmetric", "hermitian", "positive_definite"].
lower : bool, optional
Whether `a` is lower triangular. Default is False. Only relevant if `assume_a` is "sym", "her", or "pos".
check_finite : bool
, optional
Whether to check that the input is finite. Default is False
.
check_finite : bool
Unused by PyTensor. PyTensor will return nan if the operation fails
.
"""
a
,
b
=
as_xtensor
(
a
),
as_xtensor
(
b
)
input_core_dims
:
tuple
[
tuple
[
str
,
str
],
tuple
[
str
]
|
tuple
[
str
,
str
]]
...
...
@@ -98,9 +89,7 @@ def solve(
else
:
raise
ValueError
(
"Solve dims must have length 2 or 3"
)
core_op
=
Solve
(
b_ndim
=
b_ndim
,
assume_a
=
assume_a
,
lower
=
lower
,
check_finite
=
check_finite
)
core_op
=
Solve
(
b_ndim
=
b_ndim
,
assume_a
=
assume_a
,
lower
=
lower
)
x_op
=
XBlockwise
(
core_op
,
core_dims
=
(
input_core_dims
,
output_core_dims
),
...
...
tests/link/numba/test_slinalg.py
浏览文件 @
672a4829
import
re
from
typing
import
Literal
import
numpy
as
np
...
...
@@ -36,70 +35,6 @@ floatX = config.floatX
rng
=
np
.
random
.
default_rng
(
42849
)
def
test_lamch
():
from
scipy.linalg
import
get_lapack_funcs
from
pytensor.link.numba.dispatch.linalg.utils
import
_xlamch
@numba.njit
()
def
xlamch
(
kind
):
return
_xlamch
(
kind
)
lamch
=
get_lapack_funcs
(
"lamch"
,
(
np
.
array
([
0.0
],
dtype
=
floatX
),))
np
.
testing
.
assert_allclose
(
xlamch
(
"E"
),
lamch
(
"E"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"S"
),
lamch
(
"S"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"P"
),
lamch
(
"P"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"B"
),
lamch
(
"B"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"R"
),
lamch
(
"R"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"M"
),
lamch
(
"M"
))
@pytest.mark.parametrize
(
"ord_numba, ord_scipy"
,
[(
"F"
,
"fro"
),
(
"1"
,
1
),
(
"I"
,
np
.
inf
)]
)
def
test_xlange
(
ord_numba
,
ord_scipy
):
# xlange is called internally only, we don't dispatch pt.linalg.norm to it
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
@numba.njit
()
def
xlange
(
x
,
ord
):
return
_xlange
(
x
,
ord
)
x
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
np
.
testing
.
assert_allclose
(
xlange
(
x
,
ord_numba
),
linalg
.
norm
(
x
,
ord_scipy
))
@pytest.mark.parametrize
(
"ord_numba, ord_scipy"
,
[(
"1"
,
1
),
(
"I"
,
np
.
inf
)])
def
test_xgecon
(
ord_numba
,
ord_scipy
):
# gecon is called internally only, we don't dispatch pt.linalg.norm to it
from
scipy.linalg
import
get_lapack_funcs
from
pytensor.link.numba.dispatch.linalg.solve.general
import
_xgecon
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
@numba.njit
()
def
gecon
(
x
,
norm
):
anorm
=
_xlange
(
x
,
norm
)
cond
,
info
=
_xgecon
(
x
,
anorm
,
norm
)
return
cond
,
info
x
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
rcond
,
info
=
gecon
(
x
,
norm
=
ord_numba
)
# Test against direct call to the underlying LAPACK functions
# Solution does **not** agree with 1 / np.linalg.cond(x) !
lange
,
gecon
=
get_lapack_funcs
((
"lange"
,
"gecon"
),
(
x
,))
norm
=
lange
(
ord_numba
,
x
)
rcond2
,
_
=
gecon
(
x
,
norm
,
norm
=
ord_numba
)
assert
info
==
0
np
.
testing
.
assert_allclose
(
rcond
,
rcond2
)
class
TestSolves
:
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
],
ids
=
lambda
x
:
f
"lower={x}"
)
@pytest.mark.parametrize
(
...
...
@@ -323,7 +258,7 @@ class TestSolves:
np
.
testing
.
assert_allclose
(
b_val_not_contig
,
b_val
)
@pytest.mark.parametrize
(
"value"
,
[
np
.
nan
,
np
.
inf
])
def
test_solve_triangular_
raises
_on_nan_inf
(
self
,
value
):
def
test_solve_triangular_
does_not_raise
_on_nan_inf
(
self
,
value
):
A
=
pt
.
matrix
(
"A"
)
b
=
pt
.
matrix
(
"b"
)
...
...
@@ -335,11 +270,8 @@ class TestSolves:
A_tri
=
np
.
linalg
.
cholesky
(
A_sym
)
.
astype
(
floatX
)
b
=
np
.
full
((
5
,
1
),
value
)
.
astype
(
floatX
)
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
,
match
=
re
.
escape
(
"Non-numeric values"
),
):
f
(
A_tri
,
b
)
# Not checking everything is nan, because, with inf, LAPACK returns a mix of inf/nan, but does not set info != 0
assert
not
np
.
isfinite
(
f
(
A_tri
,
b
))
.
any
()
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
],
ids
=
lambda
x
:
f
"lower = {x}"
)
@pytest.mark.parametrize
(
...
...
@@ -567,10 +499,13 @@ class TestDecompositions:
x
=
pt
.
tensor
(
dtype
=
floatX
,
shape
=
(
3
,
3
))
x
=
x
.
T
.
dot
(
x
)
g
=
pt
.
linalg
.
cholesky
(
x
,
check_finite
=
True
)
with
pytest
.
warns
(
FutureWarning
):
g
=
pt
.
linalg
.
cholesky
(
x
,
check_finite
=
True
,
on_error
=
"raise"
)
f
=
pytensor
.
function
([
x
],
g
,
mode
=
"NUMBA"
)
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
,
match
=
r"Non-numeric values"
):
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
,
match
=
r"Matrix is not positive definite"
):
f
(
test_value
)
@pytest.mark.parametrize
(
"on_error"
,
[
"nan"
,
"raise"
])
...
...
@@ -578,13 +513,17 @@ class TestDecompositions:
test_value
=
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
floatX
)
x
=
pt
.
tensor
(
dtype
=
floatX
,
shape
=
(
3
,
3
))
g
=
pt
.
linalg
.
cholesky
(
x
,
on_error
=
on_error
)
if
on_error
==
"raise"
:
with
pytest
.
warns
(
FutureWarning
):
g
=
pt
.
linalg
.
cholesky
(
x
,
on_error
=
on_error
)
else
:
g
=
pt
.
linalg
.
cholesky
(
x
,
on_error
=
on_error
)
f
=
pytensor
.
function
([
x
],
g
,
mode
=
"NUMBA"
)
if
on_error
==
"raise"
:
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
,
match
=
r"
Input to cholesky
is not positive definite"
,
match
=
r"
Matrix
is not positive definite"
,
):
f
(
test_value
)
else
:
...
...
tests/tensor/linalg/test_rewriting.py
浏览文件 @
672a4829
...
...
@@ -213,47 +213,3 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
resx1
=
fn_opt
(
A_test
,
x0_test
)
rtol
=
1e-7
if
config
.
floatX
==
"float64"
else
1e-4
np
.
testing
.
assert_allclose
(
resx0
,
resx1
,
rtol
=
rtol
)
@pytest.mark.parametrize
(
"assume_a, counter"
,
(
(
"gen"
,
LUOpCounter
),
(
"pos"
,
CholeskyOpCounter
),
),
)
def
test_decomposition_reused_preserves_check_finite
(
assume_a
,
counter
):
# Check that the LU decomposition rewrite preserves the check_finite flag
rewrite_name
=
reuse_decomposition_multiple_solves
.
__name__
A
=
tensor
(
"A"
,
shape
=
(
2
,
2
))
b1
=
tensor
(
"b1"
,
shape
=
(
2
,))
b2
=
tensor
(
"b2"
,
shape
=
(
2
,))
x1
=
solve
(
A
,
b1
,
assume_a
=
assume_a
,
check_finite
=
True
)
x2
=
solve
(
A
,
b2
,
assume_a
=
assume_a
,
check_finite
=
False
)
fn_opt
=
function
(
[
A
,
b1
,
b2
],
[
x1
,
x2
],
mode
=
get_default_mode
()
.
including
(
rewrite_name
)
)
opt_nodes
=
fn_opt
.
maker
.
fgraph
.
apply_nodes
assert
counter
.
count_vanilla_solve_nodes
(
opt_nodes
)
==
0
assert
counter
.
count_decomp_nodes
(
opt_nodes
)
==
1
assert
counter
.
count_solve_nodes
(
opt_nodes
)
==
2
# We should get an error if A or b1 is non finite
A_valid
=
np
.
array
([[
1
,
0
],
[
0
,
1
]],
dtype
=
A
.
type
.
dtype
)
b1_valid
=
np
.
array
([
1
,
1
],
dtype
=
b1
.
type
.
dtype
)
b2_valid
=
np
.
array
([
1
,
1
],
dtype
=
b2
.
type
.
dtype
)
assert
fn_opt
(
A_valid
,
b1_valid
,
b2_valid
)
# Fine
assert
fn_opt
(
A_valid
,
b1_valid
,
b2_valid
*
np
.
nan
)
# Should not raise (also fine on most LAPACK implementations?)
err_msg
=
(
"(array must not contain infs or NaNs"
r"|Non-numeric values \(nan or inf\))"
)
with
pytest
.
raises
((
ValueError
,
np
.
linalg
.
LinAlgError
),
match
=
err_msg
):
assert
fn_opt
(
A_valid
,
b1_valid
*
np
.
nan
,
b2_valid
)
with
pytest
.
raises
((
ValueError
,
np
.
linalg
.
LinAlgError
),
match
=
err_msg
):
assert
fn_opt
(
A_valid
*
np
.
nan
,
b1_valid
,
b2_valid
)
tests/tensor/test_slinalg.py
浏览文件 @
672a4829
...
...
@@ -74,9 +74,6 @@ def test_cholesky():
chol
=
Cholesky
(
lower
=
False
)(
x
)
ch_f
=
function
([
x
],
chol
)
check_upper_triangular
(
pd
,
ch_f
)
chol
=
Cholesky
(
lower
=
False
,
on_error
=
"nan"
)(
x
)
ch_f
=
function
([
x
],
chol
)
check_upper_triangular
(
pd
,
ch_f
)
def
test_cholesky_performance
(
benchmark
):
...
...
@@ -102,12 +99,15 @@ def test_cholesky_empty():
def
test_cholesky_indef
():
x
=
matrix
()
mat
=
np
.
array
([[
1
,
0.2
],
[
0.2
,
-
2
]])
.
astype
(
config
.
floatX
)
cholesky
=
Cholesky
(
lower
=
True
,
on_error
=
"raise"
)
chol_f
=
function
([
x
],
cholesky
(
x
))
with
pytest
.
warns
(
FutureWarning
):
out
=
cholesky
(
x
,
lower
=
True
,
on_error
=
"raise"
)
chol_f
=
function
([
x
],
out
)
with
pytest
.
raises
(
scipy
.
linalg
.
LinAlgError
):
chol_f
(
mat
)
cholesky
=
Cholesky
(
lower
=
True
,
on_error
=
"nan"
)
chol_f
=
function
([
x
],
cholesky
(
x
))
out
=
cholesky
(
x
,
lower
=
True
,
on_error
=
"nan"
)
chol_f
=
function
([
x
],
out
)
assert
np
.
all
(
np
.
isnan
(
chol_f
(
mat
)))
...
...
@@ -143,12 +143,16 @@ def test_cholesky_grad():
def
test_cholesky_grad_indef
():
x
=
matrix
()
mat
=
np
.
array
([[
1
,
0.2
],
[
0.2
,
-
2
]])
.
astype
(
config
.
floatX
)
cholesky
=
Cholesky
(
lower
=
True
,
on_error
=
"raise"
)
chol_f
=
function
([
x
],
grad
(
cholesky
(
x
)
.
sum
(),
[
x
]))
with
pytest
.
raises
(
scipy
.
linalg
.
LinAlgError
):
chol_f
(
mat
)
cholesky
=
Cholesky
(
lower
=
True
,
on_error
=
"nan"
)
chol_f
=
function
([
x
],
grad
(
cholesky
(
x
)
.
sum
(),
[
x
]))
with
pytest
.
warns
(
FutureWarning
):
out
=
cholesky
(
x
,
lower
=
True
,
on_error
=
"raise"
)
chol_f
=
function
([
x
],
grad
(
out
.
sum
(),
[
x
]),
mode
=
"FAST_RUN"
)
# original cholesky doesn't show up in the grad (if mode="FAST_RUN"), so it does not raise
assert
np
.
all
(
np
.
isnan
(
chol_f
(
mat
)))
out
=
cholesky
(
x
,
lower
=
True
,
on_error
=
"nan"
)
chol_f
=
function
([
x
],
grad
(
out
.
sum
(),
[
x
]))
assert
np
.
all
(
np
.
isnan
(
chol_f
(
mat
)))
...
...
@@ -237,7 +241,7 @@ class TestSolveBase:
y
=
self
.
SolveTest
(
b_ndim
=
2
)(
A
,
b
)
assert
(
y
.
__repr__
()
==
"SolveTest{lower=False,
check_finite=True,
b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
==
"SolveTest{lower=False, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
)
...
...
@@ -549,7 +553,7 @@ class TestCholeskySolve(utt.InferShapeTester):
def
test_repr
(
self
):
assert
(
repr
(
CholeskySolve
(
lower
=
True
,
b_ndim
=
1
))
==
"CholeskySolve(lower=True,
check_finite=True,
b_ndim=1,overwrite_b=False)"
==
"CholeskySolve(lower=True,b_ndim=1,overwrite_b=False)"
)
def
test_infer_shape
(
self
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论