Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
672a4829
提交
672a4829
authored
1月 08, 2026
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
1月 11, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Do not raise in linalg Ops
上级
b2d8bc24
隐藏空白字符变更
内嵌
并排
正在显示
23 个修改的文件
包含
226 行增加
和
1132 行删除
+226
-1132
slinalg.py
pytensor/link/jax/dispatch/slinalg.py
+4
-10
_LAPACK.py
pytensor/link/numba/dispatch/linalg/_LAPACK.py
+0
-201
cholesky.py
...nsor/link/numba/dispatch/linalg/decomposition/cholesky.py
+11
-13
lu.py
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
+6
-15
lu_factor.py
...sor/link/numba/dispatch/linalg/decomposition/lu_factor.py
+4
-3
qr.py
pytensor/link/numba/dispatch/linalg/decomposition/qr.py
+12
-34
cholesky.py
pytensor/link/numba/dispatch/linalg/solve/cholesky.py
+9
-8
general.py
pytensor/link/numba/dispatch/linalg/solve/general.py
+12
-74
lu_solve.py
pytensor/link/numba/dispatch/linalg/solve/lu_solve.py
+5
-12
norm.py
pytensor/link/numba/dispatch/linalg/solve/norm.py
+0
-55
posdef.py
pytensor/link/numba/dispatch/linalg/solve/posdef.py
+6
-67
symmetric.py
pytensor/link/numba/dispatch/linalg/solve/symmetric.py
+5
-59
triangular.py
pytensor/link/numba/dispatch/linalg/solve/triangular.py
+6
-8
tridiagonal.py
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
+16
-92
utils.py
pytensor/link/numba/dispatch/linalg/utils.py
+2
-64
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+23
-111
tensor_basic.py
pytensor/link/numba/dispatch/tensor_basic.py
+2
-2
rewriting.py
pytensor/tensor/_linalg/solve/rewriting.py
+4
-18
slinalg.py
pytensor/tensor/slinalg.py
+58
-133
linalg.py
pytensor/xtensor/linalg.py
+8
-19
test_slinalg.py
tests/link/numba/test_slinalg.py
+14
-75
test_rewriting.py
tests/tensor/linalg/test_rewriting.py
+0
-44
test_slinalg.py
tests/tensor/test_slinalg.py
+19
-15
没有找到文件。
pytensor/link/jax/dispatch/slinalg.py
浏览文件 @
672a4829
...
...
@@ -92,7 +92,6 @@ def jax_funcify_Solve(op, **kwargs):
def
jax_funcify_SolveTriangular
(
op
,
**
kwargs
):
lower
=
op
.
lower
unit_diagonal
=
op
.
unit_diagonal
check_finite
=
op
.
check_finite
def
solve_triangular
(
A
,
b
):
return
jax
.
scipy
.
linalg
.
solve_triangular
(
...
...
@@ -101,7 +100,7 @@ def jax_funcify_SolveTriangular(op, **kwargs):
lower
=
lower
,
trans
=
0
,
# this is handled by explicitly transposing A, so it will always be 0 when we get to here.
unit_diagonal
=
unit_diagonal
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
)
return
solve_triangular
...
...
@@ -132,27 +131,23 @@ def jax_funcify_PivotToPermutation(op, **kwargs):
def
jax_funcify_LU
(
op
,
**
kwargs
):
permute_l
=
op
.
permute_l
p_indices
=
op
.
p_indices
check_finite
=
op
.
check_finite
if
p_indices
:
raise
ValueError
(
"JAX does not support the p_indices argument"
)
def
lu
(
*
inputs
):
return
jax
.
scipy
.
linalg
.
lu
(
*
inputs
,
permute_l
=
permute_l
,
check_finite
=
check_finite
)
return
jax
.
scipy
.
linalg
.
lu
(
*
inputs
,
permute_l
=
permute_l
,
check_finite
=
False
)
return
lu
@jax_funcify.register
(
LUFactor
)
def
jax_funcify_LUFactor
(
op
,
**
kwargs
):
check_finite
=
op
.
check_finite
overwrite_a
=
op
.
overwrite_a
def
lu_factor
(
a
):
return
jax
.
scipy
.
linalg
.
lu_factor
(
a
,
check_finite
=
check_finit
e
,
overwrite_a
=
overwrite_a
a
,
check_finite
=
Fals
e
,
overwrite_a
=
overwrite_a
)
return
lu_factor
...
...
@@ -161,12 +156,11 @@ def jax_funcify_LUFactor(op, **kwargs):
@jax_funcify.register
(
CholeskySolve
)
def
jax_funcify_ChoSolve
(
op
,
**
kwargs
):
lower
=
op
.
lower
check_finite
=
op
.
check_finite
overwrite_b
=
op
.
overwrite_b
def
cho_solve
(
c
,
b
):
return
jax
.
scipy
.
linalg
.
cho_solve
(
(
c
,
lower
),
b
,
check_finite
=
check_finit
e
,
overwrite_b
=
overwrite_b
(
c
,
lower
),
b
,
check_finite
=
Fals
e
,
overwrite_b
=
overwrite_b
)
return
cho_solve
...
...
pytensor/link/numba/dispatch/linalg/_LAPACK.py
浏览文件 @
672a4829
...
...
@@ -263,122 +263,6 @@ class _LAPACK:
return
potrs
@classmethod
def
numba_xlange
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of
a general M-by-N matrix A.
Called by scipy.linalg.solve, but doesn't correspond to any Op in pytensor.
"""
kind
=
get_blas_kind
(
dtype
)
float_type
=
_get_nb_float_from_dtype
(
kind
,
return_pointer
=
False
)
float_pointer
=
_get_nb_float_from_dtype
(
kind
,
return_pointer
=
True
)
unique_func_name
=
f
"scipy.lapack.{kind}lange"
@numba_basic.numba_njit
def
get_lange_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"lange"
)
return
ptr
lange_function_type
=
types
.
FunctionType
(
float_type
(
nb_i32p
,
# NORM
nb_i32p
,
# M
nb_i32p
,
# N
float_pointer
,
# A
nb_i32p
,
# LDA
float_pointer
,
# WORK
)
)
@numba_basic.numba_njit
def
lange
(
NORM
,
M
,
N
,
A
,
LDA
,
WORK
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_lange_pointer
,
func_type_ref
=
lange_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
return
fn
(
NORM
,
M
,
N
,
A
,
LDA
,
WORK
)
return
lange
@classmethod
def
numba_xlamch
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Determine machine precision for floating point arithmetic.
"""
kind
=
get_blas_kind
(
dtype
)
float_type
=
_get_nb_float_from_dtype
(
kind
,
return_pointer
=
False
)
unique_func_name
=
f
"scipy.lapack.{kind}lamch"
@numba_basic.numba_njit
def
get_lamch_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"lamch"
)
return
ptr
lamch_function_type
=
types
.
FunctionType
(
float_type
(
# Return type
nb_i32p
,
# CMACH
)
)
@numba_basic.numba_njit
def
lamch
(
CMACH
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_lamch_pointer
,
func_type_ref
=
lamch_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
res
=
fn
(
CMACH
)
return
res
return
lamch
@classmethod
def
numba_xgecon
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Estimates the condition number of a matrix A, using the LU factorization computed by numba_getrf.
Called by scipy.linalg.solve when assume_a == "gen"
"""
kind
=
get_blas_kind
(
dtype
)
float_pointer
=
_get_nb_float_from_dtype
(
kind
)
unique_func_name
=
f
"scipy.lapack.{kind}gecon"
@numba_basic.numba_njit
def
get_gecon_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"gecon"
)
return
ptr
gecon_function_type
=
types
.
FunctionType
(
types
.
void
(
nb_i32p
,
# NORM
nb_i32p
,
# N
float_pointer
,
# A
nb_i32p
,
# LDA
float_pointer
,
# ANORM
float_pointer
,
# RCOND
float_pointer
,
# WORK
nb_i32p
,
# IWORK
nb_i32p
,
# INFO
)
)
@numba_basic.numba_njit
def
gecon
(
NORM
,
N
,
A
,
LDA
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_gecon_pointer
,
func_type_ref
=
gecon_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
fn
(
NORM
,
N
,
A
,
LDA
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
)
return
gecon
@classmethod
def
numba_xgetrf
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
...
...
@@ -506,91 +390,6 @@ class _LAPACK:
return
sysv
@classmethod
def
numba_xsycon
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization
computed by xSYTRF.
"""
kind
=
get_blas_kind
(
dtype
)
float_pointer
=
_get_nb_float_from_dtype
(
kind
)
unique_func_name
=
f
"scipy.lapack.{kind}sycon"
@numba_basic.numba_njit
def
get_sycon_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"sycon"
)
return
ptr
sycon_function_type
=
types
.
FunctionType
(
types
.
void
(
nb_i32p
,
# UPLO
nb_i32p
,
# N
float_pointer
,
# A
nb_i32p
,
# LDA
nb_i32p
,
# IPIV
float_pointer
,
# ANORM
float_pointer
,
# RCOND
float_pointer
,
# WORK
nb_i32p
,
# IWORK
nb_i32p
,
# INFO
)
)
@numba_basic.numba_njit
def
sycon
(
UPLO
,
N
,
A
,
LDA
,
IPIV
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_sycon_pointer
,
func_type_ref
=
sycon_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
fn
(
UPLO
,
N
,
A
,
LDA
,
IPIV
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
)
return
sycon
@classmethod
def
numba_xpocon
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization
computed by potrf.
Called by scipy.linalg.solve when assume_a == "pos"
"""
kind
=
get_blas_kind
(
dtype
)
float_pointer
=
_get_nb_float_from_dtype
(
kind
)
unique_func_name
=
f
"scipy.lapack.{kind}pocon"
@numba_basic.numba_njit
def
get_pocon_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"pocon"
)
return
ptr
pocon_function_type
=
types
.
FunctionType
(
types
.
void
(
nb_i32p
,
# UPLO
nb_i32p
,
# N
float_pointer
,
# A
nb_i32p
,
# LDA
float_pointer
,
# ANORM
float_pointer
,
# RCOND
float_pointer
,
# WORK
nb_i32p
,
# IWORK
nb_i32p
,
# INFO
)
)
@numba_basic.numba_njit
def
pocon
(
UPLO
,
N
,
A
,
LDA
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_pocon_pointer
,
func_type_ref
=
pocon_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
fn
(
UPLO
,
N
,
A
,
LDA
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
)
return
pocon
@classmethod
def
numba_xposv
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
...
...
pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py
浏览文件 @
672a4829
...
...
@@ -12,24 +12,19 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_linalg_matrix
def
_cholesky
(
a
,
lower
=
False
,
overwrite_a
=
False
,
check_finite
=
True
):
return
(
linalg
.
cholesky
(
a
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
),
0
,
)
def
_cholesky
(
a
,
lower
=
False
,
overwrite_a
=
False
):
return
linalg
.
cholesky
(
a
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
check_finite
=
False
)
@overload
(
_cholesky
)
def
cholesky_impl
(
A
,
lower
=
0
,
overwrite_a
=
False
,
check_finite
=
True
):
def
cholesky_impl
(
A
,
lower
=
0
,
overwrite_a
=
False
):
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"cholesky"
)
dtype
=
A
.
dtype
numba_potrf
=
_LAPACK
()
.
numba_xpotrf
(
dtype
)
def
impl
(
A
,
lower
=
False
,
overwrite_a
=
False
,
check_finite
=
True
):
def
impl
(
A
,
lower
=
False
,
overwrite_a
=
False
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
if
A
.
shape
[
-
2
]
!=
_N
:
raise
linalg
.
LinAlgError
(
"Last 2 dimensions of A must be square"
)
...
...
@@ -58,6 +53,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
INFO
,
)
if
int_ptr_to_val
(
INFO
)
!=
0
:
A_copy
=
np
.
full_like
(
A_copy
,
np
.
nan
)
return
A_copy
if
lower
:
for
j
in
range
(
1
,
_N
):
for
i
in
range
(
j
):
...
...
@@ -67,10 +66,9 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
for
i
in
range
(
j
+
1
,
_N
):
A_copy
[
i
,
j
]
=
0.0
info_int
=
int_ptr_to_val
(
INFO
)
if
transposed
:
return
A_copy
.
T
,
info_int
return
A_copy
,
info_int
return
A_copy
.
T
else
:
return
A_copy
return
impl
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
浏览文件 @
672a4829
...
...
@@ -39,7 +39,6 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
def
_lu_1
(
a
:
np
.
ndarray
,
permute_l
:
Literal
[
True
],
check_finite
:
bool
,
p_indices
:
Literal
[
False
],
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
...
...
@@ -52,7 +51,7 @@ def _lu_1(
return
linalg
.
lu
(
# type: ignore[no-any-return]
a
,
permute_l
=
permute_l
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
)
...
...
@@ -61,7 +60,6 @@ def _lu_1(
def
_lu_2
(
a
:
np
.
ndarray
,
permute_l
:
Literal
[
False
],
check_finite
:
bool
,
p_indices
:
Literal
[
True
],
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
...
...
@@ -74,7 +72,7 @@ def _lu_2(
return
linalg
.
lu
(
# type: ignore[no-any-return]
a
,
permute_l
=
permute_l
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
)
...
...
@@ -83,7 +81,6 @@ def _lu_2(
def
_lu_3
(
a
:
np
.
ndarray
,
permute_l
:
Literal
[
False
],
check_finite
:
bool
,
p_indices
:
Literal
[
False
],
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
...
...
@@ -96,7 +93,7 @@ def _lu_3(
return
linalg
.
lu
(
# type: ignore[no-any-return]
a
,
permute_l
=
permute_l
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
)
...
...
@@ -106,11 +103,10 @@ def _lu_3(
def
lu_impl_1
(
a
:
np
.
ndarray
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
)
->
Callable
[
[
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]
[
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]
]:
"""
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
...
...
@@ -123,7 +119,6 @@ def lu_impl_1(
def
impl
(
a
:
np
.
ndarray
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
...
...
@@ -137,10 +132,9 @@ def lu_impl_1(
def
lu_impl_2
(
a
:
np
.
ndarray
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
]]:
)
->
Callable
[[
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
]]:
"""
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is False and p_indices is
True. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L.
...
...
@@ -153,7 +147,6 @@ def lu_impl_2(
def
impl
(
a
:
np
.
ndarray
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
...
...
@@ -169,11 +162,10 @@ def lu_impl_2(
def
lu_impl_3
(
a
:
np
.
ndarray
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
)
->
Callable
[
[
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]
[
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]
]:
"""
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
...
...
@@ -186,7 +178,6 @@ def lu_impl_3(
def
impl
(
a
:
np
.
ndarray
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
...
...
pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py
浏览文件 @
672a4829
...
...
@@ -79,11 +79,12 @@ def lu_factor_impl(
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"lu_factor"
)
def
impl
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
A_copy
,
IPIV
,
INFO
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
A_copy
,
IPIV
,
info
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
IPIV
-=
1
# LAPACK uses 1-based indexing, convert to 0-based
if
INFO
!=
0
:
raise
np
.
linalg
.
LinAlgError
(
"LU decomposition failed"
)
if
info
!=
0
:
A_copy
=
np
.
full_like
(
A_copy
,
np
.
nan
)
return
A_copy
,
IPIV
return
impl
pytensor/link/numba/dispatch/linalg/decomposition/qr.py
浏览文件 @
672a4829
...
...
@@ -228,7 +228,6 @@ def _qr_full_pivot(
mode
:
Literal
[
"full"
,
"economic"
]
=
"full"
,
pivoting
:
Literal
[
True
]
=
True
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
):
"""
...
...
@@ -243,7 +242,7 @@ def _qr_full_pivot(
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
)
...
...
@@ -253,7 +252,6 @@ def _qr_full_no_pivot(
mode
:
Literal
[
"full"
,
"economic"
]
=
"full"
,
pivoting
:
Literal
[
False
]
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
):
"""
...
...
@@ -267,7 +265,7 @@ def _qr_full_no_pivot(
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
)
...
...
@@ -277,7 +275,6 @@ def _qr_r_pivot(
mode
:
Literal
[
"r"
,
"raw"
]
=
"r"
,
pivoting
:
Literal
[
True
]
=
True
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
):
"""
...
...
@@ -291,7 +288,7 @@ def _qr_r_pivot(
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
)
...
...
@@ -301,7 +298,6 @@ def _qr_r_no_pivot(
mode
:
Literal
[
"r"
,
"raw"
]
=
"r"
,
pivoting
:
Literal
[
False
]
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
):
"""
...
...
@@ -315,7 +311,7 @@ def _qr_r_no_pivot(
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
)
...
...
@@ -325,7 +321,6 @@ def _qr_raw_no_pivot(
mode
:
Literal
[
"raw"
]
=
"raw"
,
pivoting
:
Literal
[
False
]
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
):
"""
...
...
@@ -339,7 +334,7 @@ def _qr_raw_no_pivot(
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
)
...
...
@@ -351,7 +346,6 @@ def _qr_raw_pivot(
mode
:
Literal
[
"raw"
]
=
"raw"
,
pivoting
:
Literal
[
True
]
=
True
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
):
"""
...
...
@@ -365,7 +359,7 @@ def _qr_raw_pivot(
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
)
...
...
@@ -373,9 +367,7 @@ def _qr_raw_pivot(
@overload
(
_qr_full_pivot
)
def
qr_full_pivot_impl
(
x
,
mode
=
"full"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
def
qr_full_pivot_impl
(
x
,
mode
=
"full"
,
pivoting
=
True
,
overwrite_a
=
False
,
lwork
=
None
):
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
dtype
=
x
.
dtype
...
...
@@ -395,7 +387,6 @@ def qr_full_pivot_impl(
mode
=
"full"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
):
M
=
np
.
int32
(
x
.
shape
[
0
])
...
...
@@ -529,7 +520,7 @@ def qr_full_pivot_impl(
@overload
(
_qr_full_no_pivot
)
def
qr_full_no_pivot_impl
(
x
,
mode
=
"full"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
x
,
mode
=
"full"
,
pivoting
=
False
,
overwrite_a
=
False
,
lwork
=
None
):
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
...
...
@@ -546,7 +537,6 @@ def qr_full_no_pivot_impl(
mode
=
"full"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
):
M
=
np
.
int32
(
x
.
shape
[
0
])
...
...
@@ -645,9 +635,7 @@ def qr_full_no_pivot_impl(
@overload
(
_qr_r_pivot
)
def
qr_r_pivot_impl
(
x
,
mode
=
"r"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
def
qr_r_pivot_impl
(
x
,
mode
=
"r"
,
pivoting
=
True
,
overwrite_a
=
False
,
lwork
=
None
):
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
dtype
=
x
.
dtype
...
...
@@ -658,7 +646,6 @@ def qr_r_pivot_impl(
mode
=
"r"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
):
M
=
np
.
int32
(
x
.
shape
[
0
])
...
...
@@ -720,9 +707,7 @@ def qr_r_pivot_impl(
@overload
(
_qr_r_no_pivot
)
def
qr_r_no_pivot_impl
(
x
,
mode
=
"r"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
def
qr_r_no_pivot_impl
(
x
,
mode
=
"r"
,
pivoting
=
False
,
overwrite_a
=
False
,
lwork
=
None
):
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
dtype
=
x
.
dtype
...
...
@@ -733,7 +718,6 @@ def qr_r_no_pivot_impl(
mode
=
"r"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
):
M
=
np
.
int32
(
x
.
shape
[
0
])
...
...
@@ -792,9 +776,7 @@ def qr_r_no_pivot_impl(
@overload
(
_qr_raw_no_pivot
)
def
qr_raw_no_pivot_impl
(
x
,
mode
=
"raw"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
def
qr_raw_no_pivot_impl
(
x
,
mode
=
"raw"
,
pivoting
=
False
,
overwrite_a
=
False
,
lwork
=
None
):
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
dtype
=
x
.
dtype
...
...
@@ -805,7 +787,6 @@ def qr_raw_no_pivot_impl(
mode
=
"raw"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
):
M
=
np
.
int32
(
x
.
shape
[
0
])
...
...
@@ -863,9 +844,7 @@ def qr_raw_no_pivot_impl(
@overload
(
_qr_raw_pivot
)
def
qr_raw_pivot_impl
(
x
,
mode
=
"raw"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
def
qr_raw_pivot_impl
(
x
,
mode
=
"raw"
,
pivoting
=
True
,
overwrite_a
=
False
,
lwork
=
None
):
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
...
...
@@ -880,7 +859,6 @@ def qr_raw_pivot_impl(
mode
=
"raw"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
):
M
=
np
.
int32
(
x
.
shape
[
0
])
...
...
pytensor/link/numba/dispatch/linalg/solve/cholesky.py
浏览文件 @
672a4829
...
...
@@ -14,23 +14,23 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
def
_cho_solve
(
C
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
):
def
_cho_solve
(
C
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_b
:
bool
):
"""
Solve a positive-definite linear system using the Cholesky decomposition.
"""
return
linalg
.
cho_solve
(
(
C
,
lower
),
b
=
B
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
(
C
,
lower
),
b
=
B
,
overwrite_b
=
overwrite_b
,
check_finite
=
False
,
)
@overload
(
_cho_solve
)
def
cho_solve_impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
,
check_finite
=
True
):
def
cho_solve_impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
):
ensure_lapack
()
_check_linalg_matrix
(
C
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"cho_solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"cho_solve"
)
...
...
@@ -38,7 +38,7 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
dtype
=
C
.
dtype
numba_potrs
=
_LAPACK
()
.
numba_xpotrs
(
dtype
)
def
impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
,
check_finite
=
True
):
def
impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
):
_solve_check_input_shapes
(
C
,
B
)
_N
=
np
.
int32
(
C
.
shape
[
-
1
])
...
...
@@ -79,7 +79,8 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
INFO
,
)
_solve_check
(
_N
,
int_ptr_to_val
(
INFO
))
if
int_ptr_to_val
(
INFO
)
!=
0
:
B_copy
=
np
.
full_like
(
B_copy
,
np
.
nan
)
if
B_is_1d
:
return
B_copy
[
...
,
0
]
...
...
pytensor/link/numba/dispatch/linalg/solve/general.py
浏览文件 @
672a4829
...
...
@@ -3,82 +3,24 @@ from collections.abc import Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
numba.np.linalg
import
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.decomposition.lu_factor
import
_getrf
from
pytensor.link.numba.dispatch.linalg.solve.lu_solve
import
_getrs
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_dtypes_match
,
_check_linalg_matrix
,
_solve_check
,
)
def
_xgecon
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify
graphs.
"""
return
# type: ignore
@overload
(
_xgecon
)
def
xgecon_impl
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
Callable
[[
np
.
ndarray
,
float
,
str
],
tuple
[
np
.
ndarray
,
int
]]:
"""
Compute the condition number of a matrix A.
"""
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"gecon"
)
dtype
=
A
.
dtype
numba_gecon
=
_LAPACK
()
.
numba_xgecon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
A_NORM
=
np
.
array
(
A_norm
,
dtype
=
dtype
)
NORM
=
val_to_int_ptr
(
ord
(
norm
))
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
4
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
1
)
numba_gecon
(
NORM
,
N
,
A_copy
.
ctypes
,
LDA
,
A_NORM
.
ctypes
,
RCOND
.
ctypes
,
WORK
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_gen
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
):
"""Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects
...
...
@@ -89,7 +31,7 @@ def _solve_gen(
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
assume_a
=
"gen"
,
transposed
=
transposed
,
)
...
...
@@ -102,9 +44,8 @@ def solve_gen_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
...
...
@@ -116,7 +57,6 @@ def solve_gen_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
...
...
@@ -127,20 +67,18 @@ def solve_gen_impl(
A
=
A
.
T
transposed
=
not
transposed
order
=
"I"
if
transposed
else
"1"
norm
=
_xlange
(
A
,
order
=
order
)
N
=
A
.
shape
[
1
]
LU
,
IPIV
,
INFO
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
_solve_check
(
N
,
INFO
)
LU
,
IPIV
,
INFO1
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
X
,
INFO
=
_getrs
(
LU
=
LU
,
B
=
B
,
IPIV
=
IPIV
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
X
,
INFO2
=
_getrs
(
LU
=
LU
,
B
=
B
,
IPIV
=
IPIV
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
,
)
_solve_check
(
N
,
INFO
)
RCOND
,
INFO
=
_xgecon
(
LU
,
norm
,
"1"
)
_solve_check
(
N
,
INFO
,
True
,
RCOND
)
if
INFO1
!=
0
or
INFO2
!=
0
:
X
=
np
.
full_like
(
X
,
np
.
nan
)
return
X
...
...
pytensor/link/numba/dispatch/linalg/solve/lu_solve.py
浏览文件 @
672a4829
...
...
@@ -17,7 +17,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
)
...
...
@@ -107,14 +106,11 @@ def _lu_solve(
b
:
np
.
ndarray
,
trans
:
_Trans
,
overwrite_b
:
bool
,
check_finite
:
bool
,
):
"""
Thin wrapper around scipy.lu_solve, used to avoid side effects from numba overloads on users who import Pytensor.
"""
return
linalg
.
lu_solve
(
lu_and_piv
,
b
,
trans
=
trans
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
)
return
linalg
.
lu_solve
(
lu_and_piv
,
b
,
trans
=
trans
,
overwrite_b
=
overwrite_b
)
@overload
(
_lu_solve
)
...
...
@@ -123,8 +119,7 @@ def lu_solve_impl(
b
:
np
.
ndarray
,
trans
:
_Trans
,
overwrite_b
:
bool
,
check_finite
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
_Trans
,
bool
,
bool
],
np
.
ndarray
]:
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
_Trans
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
lu
,
_piv
=
lu_and_piv
_check_linalg_matrix
(
lu
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"lu_solve"
)
...
...
@@ -137,13 +132,11 @@ def lu_solve_impl(
b
:
np
.
ndarray
,
trans
:
_Trans
,
overwrite_b
:
bool
,
check_finite
:
bool
,
)
->
np
.
ndarray
:
n
=
np
.
int32
(
lu
.
shape
[
0
]
)
X
,
info
=
_getrs
(
LU
=
lu
,
B
=
b
,
IPIV
=
piv
,
trans
=
trans
,
overwrite_b
=
overwrite_b
)
X
,
INFO
=
_getrs
(
LU
=
lu
,
B
=
b
,
IPIV
=
piv
,
trans
=
trans
,
overwrite_b
=
overwrite_b
)
_solve_check
(
n
,
INFO
)
if
info
!=
0
:
X
=
np
.
full_like
(
X
,
np
.
nan
)
return
X
...
...
pytensor/link/numba/dispatch/linalg/solve/norm.py
deleted
100644 → 0
浏览文件 @
b2d8bc24
from
collections.abc
import
Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_linalg_matrix
def
_xlange
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
)
->
float
:
"""
Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode.
"""
return
# type: ignore
@overload
(
_xlange
)
def
xlange_impl
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
)
->
Callable
[[
np
.
ndarray
,
str
],
float
]:
"""
xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity norm, or the element of
largest absolute value of a matrix A.
"""
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"norm"
)
dtype
=
A
.
dtype
numba_lange
=
_LAPACK
()
.
numba_xlange
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
):
_M
,
_N
=
np
.
int32
(
A
.
shape
[
-
2
:])
# type: ignore
A_copy
=
_copy_to_fortran_order
(
A
)
M
=
val_to_int_ptr
(
_M
)
# type: ignore
N
=
val_to_int_ptr
(
_N
)
# type: ignore
LDA
=
val_to_int_ptr
(
_M
)
# type: ignore
NORM
=
(
val_to_int_ptr
(
ord
(
order
))
if
order
is
not
None
else
val_to_int_ptr
(
ord
(
"1"
))
)
WORK
=
np
.
empty
(
_M
,
dtype
=
dtype
)
# type: ignore
result
=
numba_lange
(
NORM
,
M
,
N
,
A_copy
.
ctypes
,
LDA
,
WORK
.
ctypes
)
return
result
return
impl
pytensor/link/numba/dispatch/linalg/solve/posdef.py
浏览文件 @
672a4829
...
...
@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
...
...
@@ -27,8 +25,6 @@ def _posv(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
"""
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
...
...
@@ -43,10 +39,8 @@ def posv_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[
[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
],
]:
ensure_lapack
()
...
...
@@ -62,8 +56,6 @@ def posv_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
_solve_check_input_shapes
(
A
,
B
)
...
...
@@ -115,60 +107,12 @@ def posv_impl(
return
impl
def
_pocon
(
A
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by
linalg.solve when assume_a = "pos".
"""
return
# type: ignore
@overload
(
_pocon
)
def
pocon_impl
(
A
:
np
.
ndarray
,
anorm
:
float
)
->
Callable
[[
np
.
ndarray
,
float
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"pocon"
)
dtype
=
A
.
dtype
numba_pocon
=
_LAPACK
()
.
numba_xpocon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
anorm
:
float
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
UPLO
=
val_to_int_ptr
(
ord
(
"L"
))
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
ANORM
=
np
.
array
(
anorm
,
dtype
=
dtype
)
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
3
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
0
)
numba_pocon
(
UPLO
,
N
,
A_copy
.
ctypes
,
LDA
,
ANORM
.
ctypes
,
RCOND
.
ctypes
,
WORK
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_psd
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
):
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
...
...
@@ -179,7 +123,7 @@ def _solve_psd(
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
transposed
=
transposed
,
assume_a
=
"pos"
,
)
...
...
@@ -192,9 +136,8 @@ def solve_psd_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
...
...
@@ -206,18 +149,14 @@ def solve_psd_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
_solve_check_input_shapes
(
A
,
B
)
C
,
x
,
info
=
_posv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
,
check_finite
,
transposed
)
_solve_check
(
A
.
shape
[
-
1
],
info
)
_C
,
x
,
info
=
_posv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
)
rcond
,
info
=
_pocon
(
C
,
_xlange
(
A
))
_solve_check
(
A
.
shape
[
-
1
],
info
=
info
,
lamch
=
True
,
rcond
=
rcond
)
if
info
!=
0
:
x
=
np
.
full_like
(
x
,
np
.
nan
)
return
x
...
...
pytensor/link/numba/dispatch/linalg/solve/symmetric.py
浏览文件 @
672a4829
...
...
@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
...
...
@@ -121,61 +119,12 @@ def sysv_impl(
return
impl
def
_sycon
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in
python mode.
"""
return
# type: ignore
@overload
(
_sycon
)
def
sycon_impl
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
float
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"sycon"
)
dtype
=
A
.
dtype
numba_sycon
=
_LAPACK
()
.
numba_xsycon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
UPLO
=
val_to_int_ptr
(
ord
(
"U"
))
ANORM
=
np
.
array
(
anorm
,
dtype
=
dtype
)
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
2
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
0
)
numba_sycon
(
UPLO
,
N
,
A_copy
.
ctypes
,
LDA
,
ipiv
.
ctypes
,
ANORM
.
ctypes
,
RCOND
.
ctypes
,
WORK
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_symmetric
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
):
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
...
...
@@ -186,7 +135,7 @@ def _solve_symmetric(
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
assume_a
=
"sym"
,
transposed
=
transposed
,
)
...
...
@@ -199,9 +148,8 @@ def solve_symmetric_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
...
...
@@ -213,16 +161,14 @@ def solve_symmetric_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
_solve_check_input_shapes
(
A
,
B
)
lu
,
x
,
ipiv
,
info
=
_sysv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
)
_solve_check
(
A
.
shape
[
-
1
],
info
)
_lu
,
x
,
_ipiv
,
info
=
_sysv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
)
rcond
,
info
=
_sycon
(
lu
,
ipiv
,
_xlange
(
A
,
order
=
"I"
))
_solve_check
(
A
.
shape
[
-
1
],
info
,
True
,
rcond
)
if
info
!=
0
:
x
=
np
.
full_like
(
x
,
np
.
nan
)
return
x
...
...
pytensor/link/numba/dispatch/linalg/solve/triangular.py
浏览文件 @
672a4829
...
...
@@ -15,13 +15,12 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
)
def
_solve_triangular
(
A
,
B
,
trans
=
0
,
lower
=
False
,
unit_diagonal
=
False
,
b_ndim
=
1
,
overwrite_b
=
False
A
,
B
,
trans
=
0
,
lower
=
False
,
unit_diagonal
=
False
,
overwrite_b
=
False
):
"""
Thin wrapper around scipy.linalg.solve_triangular.
...
...
@@ -39,11 +38,12 @@ def _solve_triangular(
lower
=
lower
,
unit_diagonal
=
unit_diagonal
,
overwrite_b
=
overwrite_b
,
check_finite
=
False
,
)
@overload
(
_solve_triangular
)
def
solve_triangular_impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
b_ndim
,
overwrite_b
):
def
solve_triangular_impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
overwrite_b
):
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve_triangular"
)
...
...
@@ -57,12 +57,10 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
"This function is not expected to work with complex numbers yet"
)
def
impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
b_ndim
,
overwrite_b
):
def
impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
overwrite_b
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
_solve_check_input_shapes
(
A
,
B
)
# Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type
# could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
B_is_1d
=
B
.
ndim
==
1
if
A
.
flags
.
f_contiguous
or
(
A
.
flags
.
c_contiguous
and
trans
in
(
0
,
1
)):
...
...
@@ -106,8 +104,8 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
LDB
,
INFO
,
)
_solve_check
(
int_ptr_to_val
(
LDA
),
int_ptr_to_val
(
INFO
)
)
if
int_ptr_to_val
(
INFO
)
!=
0
:
B_copy
=
np
.
full_like
(
B_copy
,
np
.
nan
)
if
B_is_1d
:
return
B_copy
[
...
,
0
]
...
...
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
浏览文件 @
672a4829
...
...
@@ -23,7 +23,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
)
from
pytensor.tensor._linalg.solve.tridiagonal
import
(
...
...
@@ -202,83 +201,12 @@ def gttrs_impl(
return
impl
def
_gtcon
(
dl
:
ndarray
,
d
:
ndarray
,
du
:
ndarray
,
du2
:
ndarray
,
ipiv
:
ndarray
,
anorm
:
float
,
norm
:
str
,
)
->
tuple
[
ndarray
,
int
]:
"""Placeholder for computing the condition number of a tridiagonal system."""
return
# type: ignore
@overload
(
_gtcon
)
def
gtcon_impl
(
dl
:
ndarray
,
d
:
ndarray
,
du
:
ndarray
,
du2
:
ndarray
,
ipiv
:
ndarray
,
anorm
:
float
,
norm
:
str
,
)
->
Callable
[
[
ndarray
,
ndarray
,
ndarray
,
ndarray
,
ndarray
,
float
,
str
],
tuple
[
ndarray
,
int
]
]:
ensure_lapack
()
_check_linalg_matrix
(
dl
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_linalg_matrix
(
d
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_linalg_matrix
(
du
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_linalg_matrix
(
du2
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_dtypes_match
((
dl
,
d
,
du
,
du2
),
func_name
=
"gtcon"
)
_check_linalg_matrix
(
ipiv
,
ndim
=
1
,
dtype
=
int32
,
func_name
=
"gtcon"
)
dtype
=
d
.
dtype
numba_gtcon
=
_LAPACK
()
.
numba_xgtcon
(
dtype
)
def
impl
(
dl
:
ndarray
,
d
:
ndarray
,
du
:
ndarray
,
du2
:
ndarray
,
ipiv
:
ndarray
,
anorm
:
float
,
norm
:
str
,
)
->
tuple
[
ndarray
,
int
]:
n
=
np
.
int32
(
d
.
shape
[
-
1
])
rcond
=
np
.
empty
(
1
,
dtype
=
dtype
)
work
=
np
.
empty
(
2
*
n
,
dtype
=
dtype
)
iwork
=
np
.
empty
(
n
,
dtype
=
np
.
int32
)
info
=
val_to_int_ptr
(
0
)
numba_gtcon
(
val_to_int_ptr
(
ord
(
norm
)),
val_to_int_ptr
(
n
),
dl
.
ctypes
,
d
.
ctypes
,
du
.
ctypes
,
du2
.
ctypes
,
ipiv
.
ctypes
,
np
.
array
(
anorm
,
dtype
=
dtype
)
.
ctypes
,
rcond
.
ctypes
,
work
.
ctypes
,
iwork
.
ctypes
,
info
,
)
return
rcond
,
int_ptr_to_val
(
info
)
return
impl
def
_solve_tridiagonal
(
a
:
ndarray
,
b
:
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
):
"""
...
...
@@ -290,7 +218,7 @@ def _solve_tridiagonal(
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
transposed
=
transposed
,
assume_a
=
"tridiagonal"
,
)
...
...
@@ -303,9 +231,8 @@ def _tridiagonal_solve_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
ndarray
,
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
ndarray
]:
)
->
Callable
[[
ndarray
,
ndarray
,
bool
,
bool
,
bool
,
bool
],
ndarray
]:
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
...
...
@@ -317,31 +244,24 @@ def _tridiagonal_solve_impl(
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
ndarray
:
n
=
np
.
int32
(
A
.
shape
[
-
1
])
_solve_check_input_shapes
(
A
,
B
)
norm
=
"1"
if
transposed
:
A
=
A
.
T
dl
,
d
,
du
=
np
.
diag
(
A
,
-
1
),
np
.
diag
(
A
,
0
),
np
.
diag
(
A
,
1
)
anorm
=
tridiagonal_norm
(
du
,
d
,
dl
)
dl
,
d
,
du
,
du2
,
IPIV
,
INFO
=
_gttrf
(
dl
,
d
,
du
,
du2
,
ipiv
,
info1
=
_gttrf
(
dl
,
d
,
du
,
overwrite_dl
=
True
,
overwrite_d
=
True
,
overwrite_du
=
True
)
_solve_check
(
n
,
INFO
)
X
,
INFO
=
_gttrs
(
dl
,
d
,
du
,
du2
,
IPIV
,
B
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
X
,
info2
=
_gttrs
(
dl
,
d
,
du
,
du2
,
ipiv
,
B
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
)
_solve_check
(
n
,
INFO
)
RCOND
,
INFO
=
_gtcon
(
dl
,
d
,
du
,
du2
,
IPIV
,
anorm
,
norm
)
_solve_check
(
n
,
INFO
,
True
,
RCOND
)
if
info1
!=
0
or
info2
!=
0
:
X
=
np
.
full_like
(
X
,
np
.
nan
)
return
X
...
...
@@ -391,8 +311,8 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
)
return
dl
,
d
,
du
,
du2
,
ipiv
cache_
key
=
1
return
lu_factor_tridiagonal
,
cache_
key
cache_
version
=
2
return
lu_factor_tridiagonal
,
cache_
version
@register_funcify_default_op_cache_key
(
SolveLUFactorTridiagonal
)
...
...
@@ -434,7 +354,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
ipiv
=
ipiv
.
astype
(
np
.
int32
)
if
cast_b
:
b
=
b
.
astype
(
out_dtype
)
x
,
_
=
_gttrs
(
x
,
info
=
_gttrs
(
dl
,
d
,
du
,
...
...
@@ -444,7 +364,11 @@ def numba_funcify_SolveLUFactorTridiagonal(
overwrite_b
=
overwrite_b
,
trans
=
transposed
,
)
if
info
!=
0
:
x
=
np
.
full_like
(
x
,
np
.
nan
)
return
x
cache_
key
=
1
return
solve_lu_factor_tridiagonal
,
cache_
key
cache_
version
=
2
return
solve_lu_factor_tridiagonal
,
cache_
version
pytensor/link/numba/dispatch/linalg/utils.py
浏览文件 @
672a4829
from
collections.abc
import
Callable
,
Sequence
from
collections.abc
import
Sequence
import
numba
from
numba.core
import
types
from
numba.core.extending
import
overload
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
numpy.linalg
import
LinAlgError
from
numba.np.linalg
import
_copy_to_fortran_order
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
val_to_int_ptr
,
)
@numba_basic.numba_njit
(
inline
=
"always"
)
...
...
@@ -61,58 +54,3 @@ def _check_dtypes_match(arrays: Sequence, func_name="cho_solve"):
if
first_dtype
!=
other_dtype
:
msg
=
f
"{func_name} only supported for matching dtypes, got {dtypes}"
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
@numba_basic.numba_njit
(
inline
=
"always"
)
def
_solve_check
(
n
,
info
,
lamch
=
False
,
rcond
=
None
):
"""
Check arguments during the different steps of the solution phase
Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38
"""
if
info
<
0
:
# TODO: figure out how to do an fstring here
msg
=
"LAPACK reported an illegal value in input"
raise
ValueError
(
msg
)
elif
0
<
info
:
raise
LinAlgError
(
"Matrix is singular."
)
if
lamch
:
E
=
_xlamch
(
"E"
)
if
rcond
<
E
:
# TODO: This should be a warning, but we can't raise warnings in numba mode
print
(
# noqa: T201
"Ill-conditioned matrix, rcond="
,
rcond
,
", result may not be accurate."
)
def
_xlamch
(
kind
:
str
=
"E"
):
"""
Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs.
"""
pass
@overload
(
_xlamch
)
def
xlamch_impl
(
kind
:
str
=
"E"
)
->
Callable
[[
str
],
float
]:
"""
Compute the machine precision for a given floating point type.
"""
from
pytensor
import
config
ensure_lapack
()
w_type
=
_get_underlying_float
(
config
.
floatX
)
if
w_type
==
"float32"
:
dtype
=
types
.
float32
elif
w_type
==
"float64"
:
dtype
=
types
.
float64
else
:
raise
NotImplementedError
(
"Unsupported dtype"
)
numba_lamch
=
_LAPACK
()
.
numba_xlamch
(
dtype
)
def
impl
(
kind
:
str
=
"E"
)
->
float
:
KIND
=
val_to_int_ptr
(
ord
(
kind
))
return
numba_lamch
(
KIND
)
# type: ignore
return
impl
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
672a4829
...
...
@@ -58,8 +58,6 @@ def numba_funcify_Cholesky(op, node, **kwargs):
"""
lower
=
op
.
lower
overwrite_a
=
op
.
overwrite_a
check_finite
=
op
.
check_finite
on_error
=
op
.
on_error
inp_dtype
=
node
.
inputs
[
0
]
.
type
.
numpy_dtype
if
inp_dtype
.
kind
==
"c"
:
...
...
@@ -77,30 +75,11 @@ def numba_funcify_Cholesky(op, node, **kwargs):
if
discrete_inp
:
a
=
a
.
astype
(
out_dtype
)
elif
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) found in input to cholesky"
)
res
,
info
=
_cholesky
(
a
,
lower
,
overwrite_a
,
check_finite
)
if
on_error
==
"raise"
:
if
info
>
0
:
raise
np
.
linalg
.
LinAlgError
(
"Input to cholesky is not positive definite"
)
if
info
<
0
:
raise
ValueError
(
'LAPACK reported an illegal value in input on entry to "POTRF."'
)
else
:
if
info
!=
0
:
res
=
np
.
full_like
(
res
,
np
.
nan
)
return
res
return
_cholesky
(
a
,
lower
,
overwrite_a
)
cache_
key
=
1
return
cholesky
,
cache_
key
cache_
version
=
2
return
cholesky
,
cache_
version
@register_funcify_default_op_cache_key
(
PivotToPermutations
)
...
...
@@ -116,8 +95,8 @@ def pivot_to_permutation(op, node, **kwargs):
return
np
.
argsort
(
p_inv
)
cache_
key
=
1
return
numba_pivot_to_permutation
,
cache_
key
cache_
version
=
2
return
numba_pivot_to_permutation
,
cache_
version
@register_funcify_default_op_cache_key
(
LU
)
...
...
@@ -131,7 +110,6 @@ def numba_funcify_LU(op, node, **kwargs):
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
permute_l
=
op
.
permute_l
check_finite
=
op
.
check_finite
p_indices
=
op
.
p_indices
overwrite_a
=
op
.
overwrite_a
...
...
@@ -151,17 +129,11 @@ def numba_funcify_LU(op, node, **kwargs):
if
discrete_inp
:
a
=
a
.
astype
(
out_dtype
)
elif
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) found in input to lu"
)
if
p_indices
:
res
=
_lu_1
(
a
,
permute_l
=
permute_l
,
check_finite
=
check_finite
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
)
...
...
@@ -169,7 +141,6 @@ def numba_funcify_LU(op, node, **kwargs):
res
=
_lu_2
(
a
,
permute_l
=
permute_l
,
check_finite
=
check_finite
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
)
...
...
@@ -177,15 +148,14 @@ def numba_funcify_LU(op, node, **kwargs):
res
=
_lu_3
(
a
,
permute_l
=
permute_l
,
check_finite
=
check_finite
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
)
return
res
cache_
key
=
1
return
lu
,
cache_
key
cache_
version
=
2
return
lu
,
cache_
version
@register_funcify_default_op_cache_key
(
LUFactor
)
...
...
@@ -198,7 +168,6 @@ def numba_funcify_LUFactor(op, node, **kwargs):
print
(
"LUFactor requires casting discrete input to float"
)
# noqa: T201
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
check_finite
=
op
.
check_finite
overwrite_a
=
op
.
overwrite_a
@numba_basic.numba_njit
...
...
@@ -211,18 +180,13 @@ def numba_funcify_LUFactor(op, node, **kwargs):
if
discrete_inp
:
a
=
a
.
astype
(
out_dtype
)
elif
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) found in input to cholesky"
)
LU
,
piv
=
_lu_factor
(
a
,
overwrite_a
)
return
LU
,
piv
cache_
key
=
1
return
lu_factor
,
cache_
key
cache_
version
=
2
return
lu_factor
,
cache_
version
@register_funcify_default_op_cache_key
(
BlockDiagonal
)
...
...
@@ -288,8 +252,8 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
globals
()
|
{
"np"
:
np
},
)
cache_
key
=
1
return
numba_basic
.
numba_njit
(
block_diag
),
cache_
key
cache_
version
=
1
return
numba_basic
.
numba_njit
(
block_diag
),
cache_
version
@register_funcify_default_op_cache_key
(
Solve
)
...
...
@@ -306,12 +270,9 @@ def numba_funcify_Solve(op, node, **kwargs):
if
must_cast_B
and
config
.
compiler_verbose
:
print
(
"Solve requires casting second input `b`"
)
# noqa: T201
check_finite
=
op
.
check_finite
overwrite_a
=
op
.
overwrite_a
assume_a
=
op
.
assume_a
lower
=
op
.
lower
check_finite
=
op
.
check_finite
overwrite_a
=
op
.
overwrite_a
overwrite_b
=
op
.
overwrite_b
transposed
=
False
# TODO: Solve doesnt currently allow the transposed argument
...
...
@@ -344,30 +305,18 @@ def numba_funcify_Solve(op, node, **kwargs):
a
=
a
.
astype
(
out_dtype
)
if
must_cast_B
:
b
=
b
.
astype
(
out_dtype
)
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input A to solve"
)
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
b
),
np
.
isnan
(
b
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input b to solve"
)
res
=
solve_fn
(
a
,
b
,
lower
,
overwrite_a
,
overwrite_b
,
check_finite
,
transposed
)
return
res
cache_key
=
1
return
solve
,
cache_key
return
solve_fn
(
a
,
b
,
lower
,
overwrite_a
,
overwrite_b
,
transposed
)
cache_version
=
2
return
solve
,
cache_version
@register_funcify_default_op_cache_key
(
SolveTriangular
)
def
numba_funcify_SolveTriangular
(
op
,
node
,
**
kwargs
):
lower
=
op
.
lower
unit_diagonal
=
op
.
unit_diagonal
check_finite
=
op
.
check_finite
overwrite_b
=
op
.
overwrite_b
b_ndim
=
op
.
b_ndim
A_dtype
,
b_dtype
=
(
i
.
type
.
numpy_dtype
for
i
in
node
.
inputs
)
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
...
...
@@ -389,37 +338,24 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
a
=
a
.
astype
(
out_dtype
)
if
must_cast_B
:
b
=
b
.
astype
(
out_dtype
)
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input A to solve_triangular"
)
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
b
),
np
.
isnan
(
b
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input b to solve_triangular"
)
res
=
_solve_triangular
(
return
_solve_triangular
(
a
,
b
,
trans
=
0
,
# transposing is handled explicitly on the graph, so we never use this argument
lower
=
lower
,
unit_diagonal
=
unit_diagonal
,
overwrite_b
=
overwrite_b
,
b_ndim
=
b_ndim
,
)
return
res
cache_key
=
1
return
solve_triangular
,
cache_key
cache_version
=
2
return
solve_triangular
,
cache_version
@register_funcify_default_op_cache_key
(
CholeskySolve
)
def
numba_funcify_CholeskySolve
(
op
,
node
,
**
kwargs
):
lower
=
op
.
lower
overwrite_b
=
op
.
overwrite_b
check_finite
=
op
.
check_finite
c_dtype
,
b_dtype
=
(
i
.
type
.
numpy_dtype
for
i
in
node
.
inputs
)
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
...
...
@@ -439,36 +375,24 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
return
np
.
zeros
(
b
.
shape
,
dtype
=
out_dtype
)
if
must_cast_c
:
c
=
c
.
astype
(
out_dtype
)
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
c
),
np
.
isnan
(
c
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input A to cho_solve"
)
if
must_cast_b
:
b
=
b
.
astype
(
out_dtype
)
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
b
),
np
.
isnan
(
b
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input b to cho_solve"
)
return
_cho_solve
(
c
,
b
,
lower
=
lower
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
,
)
cache_
key
=
1
return
cho_solve
,
cache_
key
cache_
version
=
2
return
cho_solve
,
cache_
version
@register_funcify_default_op_cache_key
(
QR
)
def
numba_funcify_QR
(
op
,
node
,
**
kwargs
):
mode
=
op
.
mode
check_finite
=
op
.
check_finite
pivoting
=
op
.
pivoting
overwrite_a
=
op
.
overwrite_a
...
...
@@ -481,12 +405,6 @@ def numba_funcify_QR(op, node, **kwargs):
@numba_basic.numba_njit
def
qr
(
a
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) found in input to qr"
)
if
integer_input
:
a
=
a
.
astype
(
out_dtype
)
...
...
@@ -496,7 +414,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
return
Q
,
R
,
P
...
...
@@ -506,7 +423,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
return
Q
,
R
...
...
@@ -516,7 +432,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
return
R
,
P
...
...
@@ -526,7 +441,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
return
R
...
...
@@ -536,7 +450,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
return
H
,
tau
,
R
,
P
...
...
@@ -546,7 +459,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
return
H
,
tau
,
R
...
...
@@ -555,5 +467,5 @@ def numba_funcify_QR(op, node, **kwargs):
f
"QR mode={mode}, pivoting={pivoting} not supported in numba mode."
)
cache_
key
=
1
return
qr
,
cache_
key
cache_
version
=
2
return
qr
,
cache_
version
pytensor/link/numba/dispatch/tensor_basic.py
浏览文件 @
672a4829
...
...
@@ -181,8 +181,8 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
out
[
...
,
i
]
=
new_entry
return
out
cache_
key
=
1
return
extract_diag
,
cache_
key
cache_
version
=
1
return
extract_diag
,
cache_
version
@register_funcify_default_op_cache_key
(
Eye
)
...
...
pytensor/tensor/_linalg/solve/rewriting.py
浏览文件 @
672a4829
...
...
@@ -20,14 +20,13 @@ from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_so
from
pytensor.tensor.variable
import
TensorVariable
def
decompose_A
(
A
,
assume_a
,
check_finite
,
lower
):
def
decompose_A
(
A
,
assume_a
,
lower
):
if
assume_a
==
"gen"
:
return
lu_factor
(
A
,
check_finite
=
check_finite
)
return
lu_factor
(
A
)
elif
assume_a
==
"tridiagonal"
:
# We didn't implement check_finite for tridiagonal LU factorization
return
tridiagonal_lu_factor
(
A
)
elif
assume_a
==
"pos"
:
return
cholesky
(
A
,
lower
=
lower
,
check_finite
=
check_finite
)
return
cholesky
(
A
,
lower
=
lower
)
else
:
raise
NotImplementedError
...
...
@@ -36,7 +35,6 @@ def solve_decomposed_system(
A_decomp
,
b
,
transposed
=
False
,
lower
=
False
,
*
,
core_solve_op
:
Solve
):
b_ndim
=
core_solve_op
.
b_ndim
check_finite
=
core_solve_op
.
check_finite
assume_a
=
core_solve_op
.
assume_a
if
assume_a
==
"gen"
:
...
...
@@ -45,10 +43,8 @@ def solve_decomposed_system(
b
,
b_ndim
=
b_ndim
,
trans
=
transposed
,
check_finite
=
check_finite
,
)
elif
assume_a
==
"tridiagonal"
:
# We didn't implement check_finite for tridiagonal LU solve
return
tridiagonal_lu_solve
(
A_decomp
,
b
,
...
...
@@ -61,7 +57,6 @@ def solve_decomposed_system(
(
A_decomp
,
lower
),
b
,
b_ndim
=
b_ndim
,
check_finite
=
check_finite
,
)
else
:
raise
NotImplementedError
...
...
@@ -141,17 +136,8 @@ def _split_decomp_and_solve_steps(
):
return
None
# If any Op had check_finite=True, we also do it for the LU decomposition
check_finite_decomp
=
False
for
client
,
_
in
A_solve_clients_and_transpose
:
if
client
.
op
.
core_op
.
check_finite
:
check_finite_decomp
=
True
break
lower
=
node
.
op
.
core_op
.
lower
A_decomp
=
decompose_A
(
A
,
assume_a
=
assume_a
,
check_finite
=
check_finite_decomp
,
lower
=
lower
)
A_decomp
=
decompose_A
(
A
,
assume_a
=
assume_a
,
lower
=
lower
)
replacements
=
{}
for
client
,
transposed
in
A_solve_clients_and_transpose
:
...
...
pytensor/tensor/slinalg.py
浏览文件 @
672a4829
...
...
@@ -6,7 +6,7 @@ from typing import Literal, cast
import
numpy
as
np
import
scipy.linalg
as
scipy_linalg
from
scipy.linalg
import
LinAlgError
,
LinAlgWarning
,
get_lapack_funcs
from
scipy.linalg
import
get_lapack_funcs
import
pytensor
from
pytensor
import
ifelse
...
...
@@ -14,7 +14,7 @@ from pytensor import tensor as pt
from
pytensor.gradient
import
DisconnectedType
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.op
import
Op
from
pytensor.raise_op
import
Assert
from
pytensor.raise_op
import
Assert
,
CheckAndRaise
from
pytensor.tensor
import
TensorLike
from
pytensor.tensor
import
basic
as
ptb
from
pytensor.tensor
import
math
as
ptm
...
...
@@ -32,22 +32,16 @@ logger = logging.getLogger(__name__)
class
Cholesky
(
Op
):
# TODO: LAPACK wrapper with in-place behavior, for solve also
__props__
=
(
"lower"
,
"
check_finite"
,
"on_error"
,
"
overwrite_a"
)
__props__
=
(
"lower"
,
"overwrite_a"
)
gufunc_signature
=
"(m,m)->(m,m)"
def
__init__
(
self
,
*
,
lower
:
bool
=
True
,
check_finite
:
bool
=
False
,
on_error
:
Literal
[
"raise"
,
"nan"
]
=
"raise"
,
overwrite_a
:
bool
=
False
,
):
self
.
lower
=
lower
self
.
check_finite
=
check_finite
if
on_error
not
in
(
"raise"
,
"nan"
):
raise
ValueError
(
'on_error must be one of "raise" or ""nan"'
)
self
.
on_error
=
on_error
self
.
overwrite_a
=
overwrite_a
if
self
.
overwrite_a
:
...
...
@@ -77,13 +71,6 @@ class Cholesky(Op):
out
[
0
]
=
np
.
empty_like
(
x
,
dtype
=
potrf
.
dtype
)
return
if
self
.
check_finite
and
not
np
.
isfinite
(
x
)
.
all
():
if
self
.
on_error
==
"nan"
:
out
[
0
]
=
np
.
full
(
x
.
shape
,
np
.
nan
,
dtype
=
potrf
.
dtype
)
return
else
:
raise
ValueError
(
"array must not contain infs or NaNs"
)
# Squareness check
if
x
.
shape
[
0
]
!=
x
.
shape
[
1
]:
raise
ValueError
(
...
...
@@ -104,17 +91,8 @@ class Cholesky(Op):
c
,
info
=
potrf
(
x
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
clean
=
True
)
if
info
!=
0
:
if
self
.
on_error
==
"nan"
:
out
[
0
]
=
np
.
full
(
x
.
shape
,
np
.
nan
,
dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
)
elif
info
>
0
:
raise
scipy_linalg
.
LinAlgError
(
f
"{info}-th leading minor of the array is not positive definite"
)
elif
info
<
0
:
raise
ValueError
(
f
"LAPACK reported an illegal value in {-info}-th argument "
f
'on entry to "POTRF".'
)
c
[
...
]
=
np
.
nan
out
[
0
]
=
c
else
:
# Transpose result if input was transposed
out
[
0
]
=
c
.
T
if
c_contiguous_input
else
c
...
...
@@ -135,13 +113,6 @@ class Cholesky(Op):
dz
=
gradients
[
0
]
chol_x
=
outputs
[
0
]
# Replace the cholesky decomposition with 1 if there are nans
# or solve_upper_triangular will throw a ValueError.
if
self
.
on_error
==
"nan"
:
ok
=
~
ptm
.
any
(
ptm
.
isnan
(
chol_x
))
chol_x
=
ptb
.
switch
(
ok
,
chol_x
,
1
)
dz
=
ptb
.
switch
(
ok
,
dz
,
1
)
# deal with upper triangular by converting to lower triangular
if
not
self
.
lower
:
chol_x
=
chol_x
.
T
...
...
@@ -165,10 +136,7 @@ class Cholesky(Op):
else
:
grad
=
ptb
.
triu
(
s
+
s
.
T
)
-
ptb
.
diag
(
ptb
.
diagonal
(
s
))
if
self
.
on_error
==
"nan"
:
return
[
ptb
.
switch
(
ok
,
grad
,
np
.
nan
)]
else
:
return
[
grad
]
return
[
grad
]
def
inplace_on_inputs
(
self
,
allowed_inplace_inputs
:
list
[
int
])
->
"Op"
:
if
not
allowed_inplace_inputs
:
...
...
@@ -182,9 +150,9 @@ def cholesky(
x
:
"TensorLike"
,
lower
:
bool
=
True
,
*
,
check_finite
:
bool
=
Fals
e
,
check_finite
:
bool
=
Tru
e
,
overwrite_a
:
bool
=
False
,
on_error
:
Literal
[
"raise"
,
"nan"
]
=
"
raise
"
,
on_error
:
Literal
[
"raise"
,
"nan"
]
=
"
nan
"
,
):
"""
Return a triangular matrix square root of positive semi-definite `x`.
...
...
@@ -196,8 +164,8 @@ def cholesky(
x: tensor_like
lower : bool, default=True
Whether to return the lower or upper cholesky factor
check_finite : bool
, default=False
Whether to check that the input matrix contains only finite number
s.
check_finite : bool
Unused by PyTensor. PyTensor will return nan if the operation fail
s.
overwrite_a: bool, ignored
Whether to use the same memory for the output as `a`. This argument is ignored, and is present here only
for consistency with scipy.linalg.cholesky.
...
...
@@ -228,10 +196,19 @@ def cholesky(
assert np.allclose(L_value @ L_value.T, x_value)
"""
res
=
Blockwise
(
Cholesky
(
lower
=
lower
))(
x
)
return
Blockwise
(
Cholesky
(
lower
=
lower
,
on_error
=
on_error
,
check_finite
=
check_finite
)
)(
x
)
if
on_error
==
"raise"
:
# For back-compatibility
warnings
.
warn
(
'Cholesky on_raise == "raise" is deprecated. The operation will return nan when in fails. Setting this argument will fail in the future'
,
FutureWarning
,
)
res
=
CheckAndRaise
(
np
.
linalg
.
LinAlgError
,
"Matrix is not positive definite"
)(
res
,
~
ptm
.
isnan
(
res
)
.
any
()
)
return
res
class
SolveBase
(
Op
):
...
...
@@ -239,7 +216,6 @@ class SolveBase(Op):
__props__
:
tuple
[
str
,
...
]
=
(
"lower"
,
"check_finite"
,
"b_ndim"
,
"overwrite_a"
,
"overwrite_b"
,
...
...
@@ -249,13 +225,11 @@ class SolveBase(Op):
self
,
*
,
lower
=
False
,
check_finite
=
True
,
b_ndim
,
overwrite_a
=
False
,
overwrite_b
=
False
,
):
self
.
lower
=
lower
self
.
check_finite
=
check_finite
assert
b_ndim
in
(
1
,
2
)
self
.
b_ndim
=
b_ndim
...
...
@@ -358,7 +332,6 @@ def _default_b_ndim(b, b_ndim):
class
CholeskySolve
(
SolveBase
):
__props__
=
(
"lower"
,
"check_finite"
,
"b_ndim"
,
"overwrite_b"
,
)
...
...
@@ -366,7 +339,6 @@ class CholeskySolve(SolveBase):
def
__init__
(
self
,
**
kwargs
):
if
kwargs
.
get
(
"overwrite_a"
,
False
):
raise
ValueError
(
"overwrite_a is not supported for CholeskySolve"
)
kwargs
.
setdefault
(
"lower"
,
True
)
super
()
.
__init__
(
**
kwargs
)
def
make_node
(
self
,
*
inputs
):
...
...
@@ -387,9 +359,6 @@ class CholeskySolve(SolveBase):
(
potrs
,)
=
get_lapack_funcs
((
"potrs"
,),
(
c
,
b
))
if
self
.
check_finite
and
not
(
np
.
isfinite
(
c
)
.
all
()
and
np
.
isfinite
(
b
)
.
all
()):
raise
ValueError
(
"array must not contain infs or NaNs"
)
if
c
.
shape
[
0
]
!=
c
.
shape
[
1
]:
raise
ValueError
(
"The factored matrix c is not square."
)
if
c
.
shape
[
1
]
!=
b
.
shape
[
0
]:
...
...
@@ -402,7 +371,7 @@ class CholeskySolve(SolveBase):
x
,
info
=
potrs
(
c
,
b
,
lower
=
self
.
lower
,
overwrite_b
=
self
.
overwrite_b
)
if
info
!=
0
:
raise
ValueError
(
f
"illegal value in {-info}th argument of internal potrs"
)
x
[
...
]
=
np
.
nan
output_storage
[
0
][
0
]
=
x
...
...
@@ -423,7 +392,6 @@ def cho_solve(
c_and_lower
:
tuple
[
TensorLike
,
bool
],
b
:
TensorLike
,
*
,
check_finite
:
bool
=
True
,
b_ndim
:
int
|
None
=
None
,
):
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
...
...
@@ -434,33 +402,26 @@ def cho_solve(
Cholesky factorization of a, as given by cho_factor
b : TensorLike
Right-hand side
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
check_finite : bool
Unused by PyTensor. PyTensor will return nan if the operation fails.
b_ndim : int
Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted.
"""
A
,
lower
=
c_and_lower
b_ndim
=
_default_b_ndim
(
b
,
b_ndim
)
return
Blockwise
(
CholeskySolve
(
lower
=
lower
,
check_finite
=
check_finite
,
b_ndim
=
b_ndim
)
)(
A
,
b
)
return
Blockwise
(
CholeskySolve
(
lower
=
lower
,
b_ndim
=
b_ndim
))(
A
,
b
)
class
LU
(
Op
):
"""Decompose a matrix into lower and upper triangular matrices."""
__props__
=
(
"permute_l"
,
"overwrite_a"
,
"
check_finite"
,
"
p_indices"
)
__props__
=
(
"permute_l"
,
"overwrite_a"
,
"p_indices"
)
def
__init__
(
self
,
*
,
permute_l
=
False
,
overwrite_a
=
False
,
check_finite
=
True
,
p_indices
=
False
):
def
__init__
(
self
,
*
,
permute_l
=
False
,
overwrite_a
=
False
,
p_indices
=
False
):
if
permute_l
and
p_indices
:
raise
ValueError
(
"Only one of permute_l and p_indices can be True"
)
self
.
permute_l
=
permute_l
self
.
check_finite
=
check_finite
self
.
p_indices
=
p_indices
self
.
overwrite_a
=
overwrite_a
...
...
@@ -523,7 +484,6 @@ class LU(Op):
A
,
permute_l
=
self
.
permute_l
,
overwrite_a
=
self
.
overwrite_a
,
check_finite
=
self
.
check_finite
,
p_indices
=
self
.
p_indices
,
)
...
...
@@ -563,7 +523,7 @@ class LU(Op):
# TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
# We need L, not PL. It's not possible to recover it from PL, though. So we need to do a new forward pass
P_or_indices
,
L
,
U
=
lu
(
# type: ignore
A
,
permute_l
=
False
,
check_finite
=
self
.
check_finite
,
p_indices
=
False
A
,
permute_l
=
False
,
p_indices
=
False
)
else
:
...
...
@@ -621,8 +581,8 @@ def lu(
permute_l: bool
If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will
be returned in this case, and PL will not be lower triangular.
check_finite: bool
Whether to check that the input matrix contains only finite number
s.
check_finite
: bool
Unused by PyTensor. PyTensor will return nan if the operation fail
s.
p_indices: bool
If True, return integer matrix indices for the permutation matrix. Otherwise, return the permutation matrix
itself.
...
...
@@ -640,9 +600,7 @@ def lu(
return
cast
(
tuple
[
TensorVariable
,
TensorVariable
,
TensorVariable
]
|
tuple
[
TensorVariable
,
TensorVariable
],
Blockwise
(
LU
(
permute_l
=
permute_l
,
p_indices
=
p_indices
,
check_finite
=
check_finite
)
)(
a
),
Blockwise
(
LU
(
permute_l
=
permute_l
,
p_indices
=
p_indices
))(
a
),
)
...
...
@@ -680,12 +638,11 @@ def pivot_to_permutation(p: TensorLike, inverse=False):
class
LUFactor
(
Op
):
__props__
=
(
"overwrite_a"
,
"check_finite"
)
__props__
=
(
"overwrite_a"
,)
gufunc_signature
=
"(m,m)->(m,m),(m)"
def
__init__
(
self
,
*
,
overwrite_a
=
False
,
check_finite
=
True
):
def
__init__
(
self
,
*
,
overwrite_a
=
False
):
self
.
overwrite_a
=
overwrite_a
self
.
check_finite
=
check_finite
if
self
.
overwrite_a
:
self
.
destroy_map
=
{
1
:
[
0
]}
...
...
@@ -723,21 +680,10 @@ class LUFactor(Op):
outputs
[
1
][
0
]
=
np
.
array
([],
dtype
=
np
.
int32
)
return
if
self
.
check_finite
and
not
np
.
isfinite
(
A
)
.
all
():
raise
ValueError
(
"array must not contain infs or NaNs"
)
(
getrf
,)
=
get_lapack_funcs
((
"getrf"
,),
(
A
,))
LU
,
p
,
info
=
getrf
(
A
,
overwrite_a
=
self
.
overwrite_a
)
if
info
<
0
:
raise
ValueError
(
f
"illegal value in {-info}th argument of internal getrf (lu_factor)"
)
if
info
>
0
:
warnings
.
warn
(
f
"Diagonal number {info} is exactly zero. Singular matrix."
,
LinAlgWarning
,
stacklevel
=
2
,
)
if
info
!=
0
:
LU
[
...
]
=
np
.
nan
outputs
[
0
][
0
]
=
LU
outputs
[
1
][
0
]
=
p
...
...
@@ -782,7 +728,7 @@ def lu_factor(
a: TensorLike
Matrix to be factorized
check_finite: bool
Whether to check that the input matrix contains only finite number
s.
Unused by PyTensor. PyTensor will return nan if the operation fail
s.
overwrite_a: bool
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
...
...
@@ -796,7 +742,7 @@ def lu_factor(
return
cast
(
tuple
[
TensorVariable
,
TensorVariable
],
Blockwise
(
LUFactor
(
check_finite
=
check_finite
))(
a
),
Blockwise
(
LUFactor
())(
a
),
)
...
...
@@ -806,7 +752,6 @@ def _lu_solve(
b
:
TensorLike
,
trans
:
bool
=
False
,
b_ndim
:
int
|
None
=
None
,
check_finite
:
bool
=
True
,
):
b_ndim
=
_default_b_ndim
(
b
,
b_ndim
)
...
...
@@ -824,7 +769,6 @@ def _lu_solve(
unit_diagonal
=
not
trans
,
trans
=
trans
,
b_ndim
=
b_ndim
,
check_finite
=
check_finite
,
)
x
=
solve_triangular
(
...
...
@@ -834,7 +778,6 @@ def _lu_solve(
unit_diagonal
=
trans
,
trans
=
trans
,
b_ndim
=
b_ndim
,
check_finite
=
check_finite
,
)
# TODO: Use PermuteRows(inverse=True) on x
...
...
@@ -867,7 +810,7 @@ def lu_solve(
The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix
of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input.
check_finite: bool
If True, check that the input matrices contain only finite numbers. Default is True
.
Unused by PyTensor. PyTensor will return nan if the operation fails
.
overwrite_b: bool
Ignored by Pytensor. Pytensor will always compute inplace when possible.
"""
...
...
@@ -876,9 +819,7 @@ def lu_solve(
signature
=
"(m,m),(m),(m)->(m)"
else
:
signature
=
"(m,m),(m),(m,n)->(m,n)"
partialled_func
=
partial
(
_lu_solve
,
trans
=
trans
,
b_ndim
=
b_ndim
,
check_finite
=
check_finite
)
partialled_func
=
partial
(
_lu_solve
,
trans
=
trans
,
b_ndim
=
b_ndim
)
return
pt
.
vectorize
(
partialled_func
,
signature
=
signature
)(
*
LU_and_pivots
,
b
)
...
...
@@ -888,7 +829,6 @@ class SolveTriangular(SolveBase):
__props__
=
(
"unit_diagonal"
,
"lower"
,
"check_finite"
,
"b_ndim"
,
"overwrite_b"
,
)
...
...
@@ -905,10 +845,7 @@ class SolveTriangular(SolveBase):
def
perform
(
self
,
node
,
inputs
,
outputs
):
A
,
b
=
inputs
if
self
.
check_finite
and
not
(
np
.
isfinite
(
A
)
.
all
()
and
np
.
isfinite
(
b
)
.
all
()):
raise
ValueError
(
"array must not contain infs or NaNs"
)
if
len
(
A
.
shape
)
!=
2
or
A
.
shape
[
0
]
!=
A
.
shape
[
1
]:
if
A
.
ndim
!=
2
or
A
.
shape
[
0
]
!=
A
.
shape
[
1
]:
raise
ValueError
(
"expected square matrix"
)
if
A
.
shape
[
0
]
!=
b
.
shape
[
0
]:
...
...
@@ -941,12 +878,8 @@ class SolveTriangular(SolveBase):
unitdiag
=
self
.
unit_diagonal
,
)
if
info
>
0
:
raise
LinAlgError
(
f
"singular matrix: resolution failed at diagonal {info - 1}"
)
elif
info
<
0
:
raise
ValueError
(
f
"illegal value in {-info}-th argument of internal trtrs"
)
if
info
!=
0
:
x
[
...
]
=
np
.
nan
outputs
[
0
][
0
]
=
x
...
...
@@ -998,9 +931,7 @@ def solve_triangular(
unit_diagonal: bool, optional
If True, diagonal elements of `a` are assumed to be 1 and will not be referenced.
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
Unused by PyTensor. PyTensor will return nan if the operation fails.
b_ndim : int
Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted.
...
...
@@ -1018,7 +949,6 @@ def solve_triangular(
SolveTriangular
(
lower
=
lower
,
unit_diagonal
=
unit_diagonal
,
check_finite
=
check_finite
,
b_ndim
=
b_ndim
,
)
)(
a
,
b
)
...
...
@@ -1033,7 +963,6 @@ class Solve(SolveBase):
__props__
=
(
"assume_a"
,
"lower"
,
"check_finite"
,
"b_ndim"
,
"overwrite_a"
,
"overwrite_b"
,
...
...
@@ -1073,15 +1002,18 @@ class Solve(SolveBase):
def
perform
(
self
,
node
,
inputs
,
outputs
):
a
,
b
=
inputs
outputs
[
0
][
0
]
=
scipy_linalg
.
solve
(
a
=
a
,
b
=
b
,
lower
=
self
.
lower
,
check_finite
=
self
.
check_finite
,
assume_a
=
self
.
assume_a
,
overwrite_a
=
self
.
overwrite_a
,
overwrite_b
=
self
.
overwrite_b
,
)
try
:
outputs
[
0
][
0
]
=
scipy_linalg
.
solve
(
a
=
a
,
b
=
b
,
lower
=
self
.
lower
,
check_finite
=
False
,
assume_a
=
self
.
assume_a
,
overwrite_a
=
self
.
overwrite_a
,
overwrite_b
=
self
.
overwrite_b
,
)
except
np
.
linalg
.
LinAlgError
:
outputs
[
0
][
0
]
=
np
.
full
(
a
.
shape
,
np
.
nan
,
dtype
=
a
.
dtype
)
def
inplace_on_inputs
(
self
,
allowed_inplace_inputs
:
list
[
int
])
->
"Op"
:
if
not
allowed_inplace_inputs
:
...
...
@@ -1152,10 +1084,8 @@ def solve(
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
overwrite_b : bool
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
check_finite : bool
Unused by PyTensor. PyTensor returns nan if the operation fails.
assume_a : str, optional
Valid entries are explained above.
transposed: bool, default False
...
...
@@ -1174,7 +1104,6 @@ def solve(
b
,
lower
=
lower
,
trans
=
transposed
,
check_finite
=
check_finite
,
b_ndim
=
b_ndim
,
)
...
...
@@ -1195,7 +1124,6 @@ def solve(
return
Blockwise
(
Solve
(
lower
=
lower
,
check_finite
=
check_finite
,
assume_a
=
assume_a
,
b_ndim
=
b_ndim
,
)
...
...
@@ -1779,7 +1707,6 @@ class QR(Op):
"overwrite_a"
,
"mode"
,
"pivoting"
,
"check_finite"
,
)
def
__init__
(
...
...
@@ -1787,12 +1714,10 @@ class QR(Op):
mode
:
Literal
[
"full"
,
"r"
,
"economic"
,
"raw"
]
=
"full"
,
overwrite_a
:
bool
=
False
,
pivoting
:
bool
=
False
,
check_finite
:
bool
=
False
,
):
self
.
mode
=
mode
self
.
overwrite_a
=
overwrite_a
self
.
pivoting
=
pivoting
self
.
check_finite
=
check_finite
self
.
destroy_map
=
{}
...
...
pytensor/xtensor/linalg.py
浏览文件 @
672a4829
from
collections.abc
import
Sequence
from
typing
import
Literal
from
pytensor.tensor.slinalg
import
Cholesky
,
Solve
from
pytensor.xtensor.type
import
as_xtensor
...
...
@@ -10,8 +9,7 @@ def cholesky(
x
,
lower
:
bool
=
True
,
*
,
check_finite
:
bool
=
False
,
on_error
:
Literal
[
"raise"
,
"nan"
]
=
"raise"
,
check_finite
:
bool
=
True
,
dims
:
Sequence
[
str
],
):
"""Compute the Cholesky decomposition of an XTensorVariable.
...
...
@@ -22,22 +20,15 @@ def cholesky(
The input variable to decompose.
lower : bool, optional
Whether to return the lower triangular matrix. Default is True.
check_finite : bool, optional
Whether to check that the input is finite. Default is False.
on_error : {'raise', 'nan'}, optional
What to do if the input is not positive definite. If 'raise', an error is raised.
If 'nan', the output will contain NaNs. Default is 'raise'.
check_finite : bool
Unused by PyTensor. PyTensor will return nan if the operation fails.
dims : Sequence[str]
The two core dimensions of the input variable, over which the Cholesky decomposition is computed.
"""
if
len
(
dims
)
!=
2
:
raise
ValueError
(
f
"Cholesky needs two dims, got {len(dims)}"
)
core_op
=
Cholesky
(
lower
=
lower
,
check_finite
=
check_finite
,
on_error
=
on_error
,
)
core_op
=
Cholesky
(
lower
=
lower
)
core_dims
=
(
((
dims
[
0
],
dims
[
1
]),),
((
dims
[
0
],
dims
[
1
]),),
...
...
@@ -52,7 +43,7 @@ def solve(
dims
:
Sequence
[
str
],
assume_a
=
"gen"
,
lower
:
bool
=
False
,
check_finite
:
bool
=
Fals
e
,
check_finite
:
bool
=
Tru
e
,
):
"""Solve a system of linear equations using XTensorVariables.
...
...
@@ -75,8 +66,8 @@ def solve(
Long form options can also be used ["general", "symmetric", "hermitian", "positive_definite"].
lower : bool, optional
Whether `a` is lower triangular. Default is False. Only relevant if `assume_a` is "sym", "her", or "pos".
check_finite : bool
, optional
Whether to check that the input is finite. Default is False
.
check_finite : bool
Unused by PyTensor. PyTensor will return nan if the operation fails
.
"""
a
,
b
=
as_xtensor
(
a
),
as_xtensor
(
b
)
input_core_dims
:
tuple
[
tuple
[
str
,
str
],
tuple
[
str
]
|
tuple
[
str
,
str
]]
...
...
@@ -98,9 +89,7 @@ def solve(
else
:
raise
ValueError
(
"Solve dims must have length 2 or 3"
)
core_op
=
Solve
(
b_ndim
=
b_ndim
,
assume_a
=
assume_a
,
lower
=
lower
,
check_finite
=
check_finite
)
core_op
=
Solve
(
b_ndim
=
b_ndim
,
assume_a
=
assume_a
,
lower
=
lower
)
x_op
=
XBlockwise
(
core_op
,
core_dims
=
(
input_core_dims
,
output_core_dims
),
...
...
tests/link/numba/test_slinalg.py
浏览文件 @
672a4829
import
re
from
typing
import
Literal
import
numpy
as
np
...
...
@@ -36,70 +35,6 @@ floatX = config.floatX
rng
=
np
.
random
.
default_rng
(
42849
)
def
test_lamch
():
from
scipy.linalg
import
get_lapack_funcs
from
pytensor.link.numba.dispatch.linalg.utils
import
_xlamch
@numba.njit
()
def
xlamch
(
kind
):
return
_xlamch
(
kind
)
lamch
=
get_lapack_funcs
(
"lamch"
,
(
np
.
array
([
0.0
],
dtype
=
floatX
),))
np
.
testing
.
assert_allclose
(
xlamch
(
"E"
),
lamch
(
"E"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"S"
),
lamch
(
"S"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"P"
),
lamch
(
"P"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"B"
),
lamch
(
"B"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"R"
),
lamch
(
"R"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"M"
),
lamch
(
"M"
))
@pytest.mark.parametrize
(
"ord_numba, ord_scipy"
,
[(
"F"
,
"fro"
),
(
"1"
,
1
),
(
"I"
,
np
.
inf
)]
)
def
test_xlange
(
ord_numba
,
ord_scipy
):
# xlange is called internally only, we don't dispatch pt.linalg.norm to it
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
@numba.njit
()
def
xlange
(
x
,
ord
):
return
_xlange
(
x
,
ord
)
x
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
np
.
testing
.
assert_allclose
(
xlange
(
x
,
ord_numba
),
linalg
.
norm
(
x
,
ord_scipy
))
@pytest.mark.parametrize
(
"ord_numba, ord_scipy"
,
[(
"1"
,
1
),
(
"I"
,
np
.
inf
)])
def
test_xgecon
(
ord_numba
,
ord_scipy
):
# gecon is called internally only, we don't dispatch pt.linalg.norm to it
from
scipy.linalg
import
get_lapack_funcs
from
pytensor.link.numba.dispatch.linalg.solve.general
import
_xgecon
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
@numba.njit
()
def
gecon
(
x
,
norm
):
anorm
=
_xlange
(
x
,
norm
)
cond
,
info
=
_xgecon
(
x
,
anorm
,
norm
)
return
cond
,
info
x
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
rcond
,
info
=
gecon
(
x
,
norm
=
ord_numba
)
# Test against direct call to the underlying LAPACK functions
# Solution does **not** agree with 1 / np.linalg.cond(x) !
lange
,
gecon
=
get_lapack_funcs
((
"lange"
,
"gecon"
),
(
x
,))
norm
=
lange
(
ord_numba
,
x
)
rcond2
,
_
=
gecon
(
x
,
norm
,
norm
=
ord_numba
)
assert
info
==
0
np
.
testing
.
assert_allclose
(
rcond
,
rcond2
)
class
TestSolves
:
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
],
ids
=
lambda
x
:
f
"lower={x}"
)
@pytest.mark.parametrize
(
...
...
@@ -323,7 +258,7 @@ class TestSolves:
np
.
testing
.
assert_allclose
(
b_val_not_contig
,
b_val
)
@pytest.mark.parametrize
(
"value"
,
[
np
.
nan
,
np
.
inf
])
def
test_solve_triangular_
raises
_on_nan_inf
(
self
,
value
):
def
test_solve_triangular_
does_not_raise
_on_nan_inf
(
self
,
value
):
A
=
pt
.
matrix
(
"A"
)
b
=
pt
.
matrix
(
"b"
)
...
...
@@ -335,11 +270,8 @@ class TestSolves:
A_tri
=
np
.
linalg
.
cholesky
(
A_sym
)
.
astype
(
floatX
)
b
=
np
.
full
((
5
,
1
),
value
)
.
astype
(
floatX
)
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
,
match
=
re
.
escape
(
"Non-numeric values"
),
):
f
(
A_tri
,
b
)
# Not checking everything is nan, because, with inf, LAPACK returns a mix of inf/nan, but does not set info != 0
assert
not
np
.
isfinite
(
f
(
A_tri
,
b
))
.
any
()
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
],
ids
=
lambda
x
:
f
"lower = {x}"
)
@pytest.mark.parametrize
(
...
...
@@ -567,10 +499,13 @@ class TestDecompositions:
x
=
pt
.
tensor
(
dtype
=
floatX
,
shape
=
(
3
,
3
))
x
=
x
.
T
.
dot
(
x
)
g
=
pt
.
linalg
.
cholesky
(
x
,
check_finite
=
True
)
with
pytest
.
warns
(
FutureWarning
):
g
=
pt
.
linalg
.
cholesky
(
x
,
check_finite
=
True
,
on_error
=
"raise"
)
f
=
pytensor
.
function
([
x
],
g
,
mode
=
"NUMBA"
)
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
,
match
=
r"Non-numeric values"
):
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
,
match
=
r"Matrix is not positive definite"
):
f
(
test_value
)
@pytest.mark.parametrize
(
"on_error"
,
[
"nan"
,
"raise"
])
...
...
@@ -578,13 +513,17 @@ class TestDecompositions:
test_value
=
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
floatX
)
x
=
pt
.
tensor
(
dtype
=
floatX
,
shape
=
(
3
,
3
))
g
=
pt
.
linalg
.
cholesky
(
x
,
on_error
=
on_error
)
if
on_error
==
"raise"
:
with
pytest
.
warns
(
FutureWarning
):
g
=
pt
.
linalg
.
cholesky
(
x
,
on_error
=
on_error
)
else
:
g
=
pt
.
linalg
.
cholesky
(
x
,
on_error
=
on_error
)
f
=
pytensor
.
function
([
x
],
g
,
mode
=
"NUMBA"
)
if
on_error
==
"raise"
:
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
,
match
=
r"
Input to cholesky
is not positive definite"
,
match
=
r"
Matrix
is not positive definite"
,
):
f
(
test_value
)
else
:
...
...
tests/tensor/linalg/test_rewriting.py
浏览文件 @
672a4829
...
...
@@ -213,47 +213,3 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
resx1
=
fn_opt
(
A_test
,
x0_test
)
rtol
=
1e-7
if
config
.
floatX
==
"float64"
else
1e-4
np
.
testing
.
assert_allclose
(
resx0
,
resx1
,
rtol
=
rtol
)
@pytest.mark.parametrize
(
"assume_a, counter"
,
(
(
"gen"
,
LUOpCounter
),
(
"pos"
,
CholeskyOpCounter
),
),
)
def
test_decomposition_reused_preserves_check_finite
(
assume_a
,
counter
):
# Check that the LU decomposition rewrite preserves the check_finite flag
rewrite_name
=
reuse_decomposition_multiple_solves
.
__name__
A
=
tensor
(
"A"
,
shape
=
(
2
,
2
))
b1
=
tensor
(
"b1"
,
shape
=
(
2
,))
b2
=
tensor
(
"b2"
,
shape
=
(
2
,))
x1
=
solve
(
A
,
b1
,
assume_a
=
assume_a
,
check_finite
=
True
)
x2
=
solve
(
A
,
b2
,
assume_a
=
assume_a
,
check_finite
=
False
)
fn_opt
=
function
(
[
A
,
b1
,
b2
],
[
x1
,
x2
],
mode
=
get_default_mode
()
.
including
(
rewrite_name
)
)
opt_nodes
=
fn_opt
.
maker
.
fgraph
.
apply_nodes
assert
counter
.
count_vanilla_solve_nodes
(
opt_nodes
)
==
0
assert
counter
.
count_decomp_nodes
(
opt_nodes
)
==
1
assert
counter
.
count_solve_nodes
(
opt_nodes
)
==
2
# We should get an error if A or b1 is non finite
A_valid
=
np
.
array
([[
1
,
0
],
[
0
,
1
]],
dtype
=
A
.
type
.
dtype
)
b1_valid
=
np
.
array
([
1
,
1
],
dtype
=
b1
.
type
.
dtype
)
b2_valid
=
np
.
array
([
1
,
1
],
dtype
=
b2
.
type
.
dtype
)
assert
fn_opt
(
A_valid
,
b1_valid
,
b2_valid
)
# Fine
assert
fn_opt
(
A_valid
,
b1_valid
,
b2_valid
*
np
.
nan
)
# Should not raise (also fine on most LAPACK implementations?)
err_msg
=
(
"(array must not contain infs or NaNs"
r"|Non-numeric values \(nan or inf\))"
)
with
pytest
.
raises
((
ValueError
,
np
.
linalg
.
LinAlgError
),
match
=
err_msg
):
assert
fn_opt
(
A_valid
,
b1_valid
*
np
.
nan
,
b2_valid
)
with
pytest
.
raises
((
ValueError
,
np
.
linalg
.
LinAlgError
),
match
=
err_msg
):
assert
fn_opt
(
A_valid
*
np
.
nan
,
b1_valid
,
b2_valid
)
tests/tensor/test_slinalg.py
浏览文件 @
672a4829
...
...
@@ -74,9 +74,6 @@ def test_cholesky():
chol
=
Cholesky
(
lower
=
False
)(
x
)
ch_f
=
function
([
x
],
chol
)
check_upper_triangular
(
pd
,
ch_f
)
chol
=
Cholesky
(
lower
=
False
,
on_error
=
"nan"
)(
x
)
ch_f
=
function
([
x
],
chol
)
check_upper_triangular
(
pd
,
ch_f
)
def
test_cholesky_performance
(
benchmark
):
...
...
@@ -102,12 +99,15 @@ def test_cholesky_empty():
def
test_cholesky_indef
():
x
=
matrix
()
mat
=
np
.
array
([[
1
,
0.2
],
[
0.2
,
-
2
]])
.
astype
(
config
.
floatX
)
cholesky
=
Cholesky
(
lower
=
True
,
on_error
=
"raise"
)
chol_f
=
function
([
x
],
cholesky
(
x
))
with
pytest
.
warns
(
FutureWarning
):
out
=
cholesky
(
x
,
lower
=
True
,
on_error
=
"raise"
)
chol_f
=
function
([
x
],
out
)
with
pytest
.
raises
(
scipy
.
linalg
.
LinAlgError
):
chol_f
(
mat
)
cholesky
=
Cholesky
(
lower
=
True
,
on_error
=
"nan"
)
chol_f
=
function
([
x
],
cholesky
(
x
))
out
=
cholesky
(
x
,
lower
=
True
,
on_error
=
"nan"
)
chol_f
=
function
([
x
],
out
)
assert
np
.
all
(
np
.
isnan
(
chol_f
(
mat
)))
...
...
@@ -143,12 +143,16 @@ def test_cholesky_grad():
def
test_cholesky_grad_indef
():
x
=
matrix
()
mat
=
np
.
array
([[
1
,
0.2
],
[
0.2
,
-
2
]])
.
astype
(
config
.
floatX
)
cholesky
=
Cholesky
(
lower
=
True
,
on_error
=
"raise"
)
chol_f
=
function
([
x
],
grad
(
cholesky
(
x
)
.
sum
(),
[
x
]))
with
pytest
.
raises
(
scipy
.
linalg
.
LinAlgError
):
chol_f
(
mat
)
cholesky
=
Cholesky
(
lower
=
True
,
on_error
=
"nan"
)
chol_f
=
function
([
x
],
grad
(
cholesky
(
x
)
.
sum
(),
[
x
]))
with
pytest
.
warns
(
FutureWarning
):
out
=
cholesky
(
x
,
lower
=
True
,
on_error
=
"raise"
)
chol_f
=
function
([
x
],
grad
(
out
.
sum
(),
[
x
]),
mode
=
"FAST_RUN"
)
# original cholesky doesn't show up in the grad (if mode="FAST_RUN"), so it does not raise
assert
np
.
all
(
np
.
isnan
(
chol_f
(
mat
)))
out
=
cholesky
(
x
,
lower
=
True
,
on_error
=
"nan"
)
chol_f
=
function
([
x
],
grad
(
out
.
sum
(),
[
x
]))
assert
np
.
all
(
np
.
isnan
(
chol_f
(
mat
)))
...
...
@@ -237,7 +241,7 @@ class TestSolveBase:
y
=
self
.
SolveTest
(
b_ndim
=
2
)(
A
,
b
)
assert
(
y
.
__repr__
()
==
"SolveTest{lower=False,
check_finite=True,
b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
==
"SolveTest{lower=False, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
)
...
...
@@ -549,7 +553,7 @@ class TestCholeskySolve(utt.InferShapeTester):
def
test_repr
(
self
):
assert
(
repr
(
CholeskySolve
(
lower
=
True
,
b_ndim
=
1
))
==
"CholeskySolve(lower=True,
check_finite=True,
b_ndim=1,overwrite_b=False)"
==
"CholeskySolve(lower=True,b_ndim=1,overwrite_b=False)"
)
def
test_infer_shape
(
self
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论