Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
672a4829
提交
672a4829
authored
1月 08, 2026
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
1月 11, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Do not raise in linalg Ops
上级
b2d8bc24
隐藏空白字符变更
内嵌
并排
正在显示
23 个修改的文件
包含
226 行增加
和
1132 行删除
+226
-1132
slinalg.py
pytensor/link/jax/dispatch/slinalg.py
+4
-10
_LAPACK.py
pytensor/link/numba/dispatch/linalg/_LAPACK.py
+0
-201
cholesky.py
...nsor/link/numba/dispatch/linalg/decomposition/cholesky.py
+11
-13
lu.py
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
+6
-15
lu_factor.py
...sor/link/numba/dispatch/linalg/decomposition/lu_factor.py
+4
-3
qr.py
pytensor/link/numba/dispatch/linalg/decomposition/qr.py
+12
-34
cholesky.py
pytensor/link/numba/dispatch/linalg/solve/cholesky.py
+9
-8
general.py
pytensor/link/numba/dispatch/linalg/solve/general.py
+12
-74
lu_solve.py
pytensor/link/numba/dispatch/linalg/solve/lu_solve.py
+5
-12
norm.py
pytensor/link/numba/dispatch/linalg/solve/norm.py
+0
-55
posdef.py
pytensor/link/numba/dispatch/linalg/solve/posdef.py
+6
-67
symmetric.py
pytensor/link/numba/dispatch/linalg/solve/symmetric.py
+5
-59
triangular.py
pytensor/link/numba/dispatch/linalg/solve/triangular.py
+6
-8
tridiagonal.py
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
+16
-92
utils.py
pytensor/link/numba/dispatch/linalg/utils.py
+2
-64
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+23
-111
tensor_basic.py
pytensor/link/numba/dispatch/tensor_basic.py
+2
-2
rewriting.py
pytensor/tensor/_linalg/solve/rewriting.py
+4
-18
slinalg.py
pytensor/tensor/slinalg.py
+58
-133
linalg.py
pytensor/xtensor/linalg.py
+8
-19
test_slinalg.py
tests/link/numba/test_slinalg.py
+14
-75
test_rewriting.py
tests/tensor/linalg/test_rewriting.py
+0
-44
test_slinalg.py
tests/tensor/test_slinalg.py
+19
-15
没有找到文件。
pytensor/link/jax/dispatch/slinalg.py
浏览文件 @
672a4829
...
@@ -92,7 +92,6 @@ def jax_funcify_Solve(op, **kwargs):
...
@@ -92,7 +92,6 @@ def jax_funcify_Solve(op, **kwargs):
def
jax_funcify_SolveTriangular
(
op
,
**
kwargs
):
def
jax_funcify_SolveTriangular
(
op
,
**
kwargs
):
lower
=
op
.
lower
lower
=
op
.
lower
unit_diagonal
=
op
.
unit_diagonal
unit_diagonal
=
op
.
unit_diagonal
check_finite
=
op
.
check_finite
def
solve_triangular
(
A
,
b
):
def
solve_triangular
(
A
,
b
):
return
jax
.
scipy
.
linalg
.
solve_triangular
(
return
jax
.
scipy
.
linalg
.
solve_triangular
(
...
@@ -101,7 +100,7 @@ def jax_funcify_SolveTriangular(op, **kwargs):
...
@@ -101,7 +100,7 @@ def jax_funcify_SolveTriangular(op, **kwargs):
lower
=
lower
,
lower
=
lower
,
trans
=
0
,
# this is handled by explicitly transposing A, so it will always be 0 when we get to here.
trans
=
0
,
# this is handled by explicitly transposing A, so it will always be 0 when we get to here.
unit_diagonal
=
unit_diagonal
,
unit_diagonal
=
unit_diagonal
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
)
)
return
solve_triangular
return
solve_triangular
...
@@ -132,27 +131,23 @@ def jax_funcify_PivotToPermutation(op, **kwargs):
...
@@ -132,27 +131,23 @@ def jax_funcify_PivotToPermutation(op, **kwargs):
def
jax_funcify_LU
(
op
,
**
kwargs
):
def
jax_funcify_LU
(
op
,
**
kwargs
):
permute_l
=
op
.
permute_l
permute_l
=
op
.
permute_l
p_indices
=
op
.
p_indices
p_indices
=
op
.
p_indices
check_finite
=
op
.
check_finite
if
p_indices
:
if
p_indices
:
raise
ValueError
(
"JAX does not support the p_indices argument"
)
raise
ValueError
(
"JAX does not support the p_indices argument"
)
def
lu
(
*
inputs
):
def
lu
(
*
inputs
):
return
jax
.
scipy
.
linalg
.
lu
(
return
jax
.
scipy
.
linalg
.
lu
(
*
inputs
,
permute_l
=
permute_l
,
check_finite
=
False
)
*
inputs
,
permute_l
=
permute_l
,
check_finite
=
check_finite
)
return
lu
return
lu
@jax_funcify.register
(
LUFactor
)
@jax_funcify.register
(
LUFactor
)
def
jax_funcify_LUFactor
(
op
,
**
kwargs
):
def
jax_funcify_LUFactor
(
op
,
**
kwargs
):
check_finite
=
op
.
check_finite
overwrite_a
=
op
.
overwrite_a
overwrite_a
=
op
.
overwrite_a
def
lu_factor
(
a
):
def
lu_factor
(
a
):
return
jax
.
scipy
.
linalg
.
lu_factor
(
return
jax
.
scipy
.
linalg
.
lu_factor
(
a
,
check_finite
=
check_finit
e
,
overwrite_a
=
overwrite_a
a
,
check_finite
=
Fals
e
,
overwrite_a
=
overwrite_a
)
)
return
lu_factor
return
lu_factor
...
@@ -161,12 +156,11 @@ def jax_funcify_LUFactor(op, **kwargs):
...
@@ -161,12 +156,11 @@ def jax_funcify_LUFactor(op, **kwargs):
@jax_funcify.register
(
CholeskySolve
)
@jax_funcify.register
(
CholeskySolve
)
def
jax_funcify_ChoSolve
(
op
,
**
kwargs
):
def
jax_funcify_ChoSolve
(
op
,
**
kwargs
):
lower
=
op
.
lower
lower
=
op
.
lower
check_finite
=
op
.
check_finite
overwrite_b
=
op
.
overwrite_b
overwrite_b
=
op
.
overwrite_b
def
cho_solve
(
c
,
b
):
def
cho_solve
(
c
,
b
):
return
jax
.
scipy
.
linalg
.
cho_solve
(
return
jax
.
scipy
.
linalg
.
cho_solve
(
(
c
,
lower
),
b
,
check_finite
=
check_finit
e
,
overwrite_b
=
overwrite_b
(
c
,
lower
),
b
,
check_finite
=
Fals
e
,
overwrite_b
=
overwrite_b
)
)
return
cho_solve
return
cho_solve
...
...
pytensor/link/numba/dispatch/linalg/_LAPACK.py
浏览文件 @
672a4829
...
@@ -263,122 +263,6 @@ class _LAPACK:
...
@@ -263,122 +263,6 @@ class _LAPACK:
return
potrs
return
potrs
@classmethod
def
numba_xlange
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of
a general M-by-N matrix A.
Called by scipy.linalg.solve, but doesn't correspond to any Op in pytensor.
"""
kind
=
get_blas_kind
(
dtype
)
float_type
=
_get_nb_float_from_dtype
(
kind
,
return_pointer
=
False
)
float_pointer
=
_get_nb_float_from_dtype
(
kind
,
return_pointer
=
True
)
unique_func_name
=
f
"scipy.lapack.{kind}lange"
@numba_basic.numba_njit
def
get_lange_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"lange"
)
return
ptr
lange_function_type
=
types
.
FunctionType
(
float_type
(
nb_i32p
,
# NORM
nb_i32p
,
# M
nb_i32p
,
# N
float_pointer
,
# A
nb_i32p
,
# LDA
float_pointer
,
# WORK
)
)
@numba_basic.numba_njit
def
lange
(
NORM
,
M
,
N
,
A
,
LDA
,
WORK
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_lange_pointer
,
func_type_ref
=
lange_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
return
fn
(
NORM
,
M
,
N
,
A
,
LDA
,
WORK
)
return
lange
@classmethod
def
numba_xlamch
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Determine machine precision for floating point arithmetic.
"""
kind
=
get_blas_kind
(
dtype
)
float_type
=
_get_nb_float_from_dtype
(
kind
,
return_pointer
=
False
)
unique_func_name
=
f
"scipy.lapack.{kind}lamch"
@numba_basic.numba_njit
def
get_lamch_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"lamch"
)
return
ptr
lamch_function_type
=
types
.
FunctionType
(
float_type
(
# Return type
nb_i32p
,
# CMACH
)
)
@numba_basic.numba_njit
def
lamch
(
CMACH
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_lamch_pointer
,
func_type_ref
=
lamch_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
res
=
fn
(
CMACH
)
return
res
return
lamch
@classmethod
def
numba_xgecon
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Estimates the condition number of a matrix A, using the LU factorization computed by numba_getrf.
Called by scipy.linalg.solve when assume_a == "gen"
"""
kind
=
get_blas_kind
(
dtype
)
float_pointer
=
_get_nb_float_from_dtype
(
kind
)
unique_func_name
=
f
"scipy.lapack.{kind}gecon"
@numba_basic.numba_njit
def
get_gecon_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"gecon"
)
return
ptr
gecon_function_type
=
types
.
FunctionType
(
types
.
void
(
nb_i32p
,
# NORM
nb_i32p
,
# N
float_pointer
,
# A
nb_i32p
,
# LDA
float_pointer
,
# ANORM
float_pointer
,
# RCOND
float_pointer
,
# WORK
nb_i32p
,
# IWORK
nb_i32p
,
# INFO
)
)
@numba_basic.numba_njit
def
gecon
(
NORM
,
N
,
A
,
LDA
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_gecon_pointer
,
func_type_ref
=
gecon_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
fn
(
NORM
,
N
,
A
,
LDA
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
)
return
gecon
@classmethod
@classmethod
def
numba_xgetrf
(
cls
,
dtype
)
->
CPUDispatcher
:
def
numba_xgetrf
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
"""
...
@@ -506,91 +390,6 @@ class _LAPACK:
...
@@ -506,91 +390,6 @@ class _LAPACK:
return
sysv
return
sysv
@classmethod
def
numba_xsycon
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization
computed by xSYTRF.
"""
kind
=
get_blas_kind
(
dtype
)
float_pointer
=
_get_nb_float_from_dtype
(
kind
)
unique_func_name
=
f
"scipy.lapack.{kind}sycon"
@numba_basic.numba_njit
def
get_sycon_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"sycon"
)
return
ptr
sycon_function_type
=
types
.
FunctionType
(
types
.
void
(
nb_i32p
,
# UPLO
nb_i32p
,
# N
float_pointer
,
# A
nb_i32p
,
# LDA
nb_i32p
,
# IPIV
float_pointer
,
# ANORM
float_pointer
,
# RCOND
float_pointer
,
# WORK
nb_i32p
,
# IWORK
nb_i32p
,
# INFO
)
)
@numba_basic.numba_njit
def
sycon
(
UPLO
,
N
,
A
,
LDA
,
IPIV
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_sycon_pointer
,
func_type_ref
=
sycon_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
fn
(
UPLO
,
N
,
A
,
LDA
,
IPIV
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
)
return
sycon
@classmethod
def
numba_xpocon
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization
computed by potrf.
Called by scipy.linalg.solve when assume_a == "pos"
"""
kind
=
get_blas_kind
(
dtype
)
float_pointer
=
_get_nb_float_from_dtype
(
kind
)
unique_func_name
=
f
"scipy.lapack.{kind}pocon"
@numba_basic.numba_njit
def
get_pocon_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"pocon"
)
return
ptr
pocon_function_type
=
types
.
FunctionType
(
types
.
void
(
nb_i32p
,
# UPLO
nb_i32p
,
# N
float_pointer
,
# A
nb_i32p
,
# LDA
float_pointer
,
# ANORM
float_pointer
,
# RCOND
float_pointer
,
# WORK
nb_i32p
,
# IWORK
nb_i32p
,
# INFO
)
)
@numba_basic.numba_njit
def
pocon
(
UPLO
,
N
,
A
,
LDA
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_pocon_pointer
,
func_type_ref
=
pocon_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
fn
(
UPLO
,
N
,
A
,
LDA
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
)
return
pocon
@classmethod
@classmethod
def
numba_xposv
(
cls
,
dtype
)
->
CPUDispatcher
:
def
numba_xposv
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
"""
...
...
pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py
浏览文件 @
672a4829
...
@@ -12,24 +12,19 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
...
@@ -12,24 +12,19 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_linalg_matrix
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_linalg_matrix
def
_cholesky
(
a
,
lower
=
False
,
overwrite_a
=
False
,
check_finite
=
True
):
def
_cholesky
(
a
,
lower
=
False
,
overwrite_a
=
False
):
return
(
return
linalg
.
cholesky
(
a
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
check_finite
=
False
)
linalg
.
cholesky
(
a
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
),
0
,
)
@overload
(
_cholesky
)
@overload
(
_cholesky
)
def
cholesky_impl
(
A
,
lower
=
0
,
overwrite_a
=
False
,
check_finite
=
True
):
def
cholesky_impl
(
A
,
lower
=
0
,
overwrite_a
=
False
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"cholesky"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"cholesky"
)
dtype
=
A
.
dtype
dtype
=
A
.
dtype
numba_potrf
=
_LAPACK
()
.
numba_xpotrf
(
dtype
)
numba_potrf
=
_LAPACK
()
.
numba_xpotrf
(
dtype
)
def
impl
(
A
,
lower
=
False
,
overwrite_a
=
False
,
check_finite
=
True
):
def
impl
(
A
,
lower
=
False
,
overwrite_a
=
False
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
if
A
.
shape
[
-
2
]
!=
_N
:
if
A
.
shape
[
-
2
]
!=
_N
:
raise
linalg
.
LinAlgError
(
"Last 2 dimensions of A must be square"
)
raise
linalg
.
LinAlgError
(
"Last 2 dimensions of A must be square"
)
...
@@ -58,6 +53,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
...
@@ -58,6 +53,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
INFO
,
INFO
,
)
)
if
int_ptr_to_val
(
INFO
)
!=
0
:
A_copy
=
np
.
full_like
(
A_copy
,
np
.
nan
)
return
A_copy
if
lower
:
if
lower
:
for
j
in
range
(
1
,
_N
):
for
j
in
range
(
1
,
_N
):
for
i
in
range
(
j
):
for
i
in
range
(
j
):
...
@@ -67,10 +66,9 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
...
@@ -67,10 +66,9 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
for
i
in
range
(
j
+
1
,
_N
):
for
i
in
range
(
j
+
1
,
_N
):
A_copy
[
i
,
j
]
=
0.0
A_copy
[
i
,
j
]
=
0.0
info_int
=
int_ptr_to_val
(
INFO
)
if
transposed
:
if
transposed
:
return
A_copy
.
T
,
info_int
return
A_copy
.
T
return
A_copy
,
info_int
else
:
return
A_copy
return
impl
return
impl
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
浏览文件 @
672a4829
...
@@ -39,7 +39,6 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
...
@@ -39,7 +39,6 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
def
_lu_1
(
def
_lu_1
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
Literal
[
True
],
permute_l
:
Literal
[
True
],
check_finite
:
bool
,
p_indices
:
Literal
[
False
],
p_indices
:
Literal
[
False
],
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
...
@@ -52,7 +51,7 @@ def _lu_1(
...
@@ -52,7 +51,7 @@ def _lu_1(
return
linalg
.
lu
(
# type: ignore[no-any-return]
return
linalg
.
lu
(
# type: ignore[no-any-return]
a
,
a
,
permute_l
=
permute_l
,
permute_l
=
permute_l
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
p_indices
=
p_indices
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
)
)
...
@@ -61,7 +60,6 @@ def _lu_1(
...
@@ -61,7 +60,6 @@ def _lu_1(
def
_lu_2
(
def
_lu_2
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
Literal
[
False
],
permute_l
:
Literal
[
False
],
check_finite
:
bool
,
p_indices
:
Literal
[
True
],
p_indices
:
Literal
[
True
],
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
...
@@ -74,7 +72,7 @@ def _lu_2(
...
@@ -74,7 +72,7 @@ def _lu_2(
return
linalg
.
lu
(
# type: ignore[no-any-return]
return
linalg
.
lu
(
# type: ignore[no-any-return]
a
,
a
,
permute_l
=
permute_l
,
permute_l
=
permute_l
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
p_indices
=
p_indices
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
)
)
...
@@ -83,7 +81,6 @@ def _lu_2(
...
@@ -83,7 +81,6 @@ def _lu_2(
def
_lu_3
(
def
_lu_3
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
Literal
[
False
],
permute_l
:
Literal
[
False
],
check_finite
:
bool
,
p_indices
:
Literal
[
False
],
p_indices
:
Literal
[
False
],
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
...
@@ -96,7 +93,7 @@ def _lu_3(
...
@@ -96,7 +93,7 @@ def _lu_3(
return
linalg
.
lu
(
# type: ignore[no-any-return]
return
linalg
.
lu
(
# type: ignore[no-any-return]
a
,
a
,
permute_l
=
permute_l
,
permute_l
=
permute_l
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
p_indices
=
p_indices
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
)
)
...
@@ -106,11 +103,10 @@ def _lu_3(
...
@@ -106,11 +103,10 @@ def _lu_3(
def
lu_impl_1
(
def
lu_impl_1
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
bool
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
Callable
[
)
->
Callable
[
[
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]
[
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]
]:
]:
"""
"""
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
...
@@ -123,7 +119,6 @@ def lu_impl_1(
...
@@ -123,7 +119,6 @@ def lu_impl_1(
def
impl
(
def
impl
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
bool
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
...
@@ -137,10 +132,9 @@ def lu_impl_1(
...
@@ -137,10 +132,9 @@ def lu_impl_1(
def
lu_impl_2
(
def
lu_impl_2
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
bool
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
]]:
)
->
Callable
[[
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
]]:
"""
"""
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is False and p_indices is
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is False and p_indices is
True. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L.
True. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L.
...
@@ -153,7 +147,6 @@ def lu_impl_2(
...
@@ -153,7 +147,6 @@ def lu_impl_2(
def
impl
(
def
impl
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
bool
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
...
@@ -169,11 +162,10 @@ def lu_impl_2(
...
@@ -169,11 +162,10 @@ def lu_impl_2(
def
lu_impl_3
(
def
lu_impl_3
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
bool
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
Callable
[
)
->
Callable
[
[
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]
[
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]
]:
]:
"""
"""
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
...
@@ -186,7 +178,6 @@ def lu_impl_3(
...
@@ -186,7 +178,6 @@ def lu_impl_3(
def
impl
(
def
impl
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
bool
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
...
...
pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py
浏览文件 @
672a4829
...
@@ -79,11 +79,12 @@ def lu_factor_impl(
...
@@ -79,11 +79,12 @@ def lu_factor_impl(
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"lu_factor"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"lu_factor"
)
def
impl
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
def
impl
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
A_copy
,
IPIV
,
INFO
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
A_copy
,
IPIV
,
info
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
IPIV
-=
1
# LAPACK uses 1-based indexing, convert to 0-based
IPIV
-=
1
# LAPACK uses 1-based indexing, convert to 0-based
if
INFO
!=
0
:
if
info
!=
0
:
raise
np
.
linalg
.
LinAlgError
(
"LU decomposition failed"
)
A_copy
=
np
.
full_like
(
A_copy
,
np
.
nan
)
return
A_copy
,
IPIV
return
A_copy
,
IPIV
return
impl
return
impl
pytensor/link/numba/dispatch/linalg/decomposition/qr.py
浏览文件 @
672a4829
...
@@ -228,7 +228,6 @@ def _qr_full_pivot(
...
@@ -228,7 +228,6 @@ def _qr_full_pivot(
mode
:
Literal
[
"full"
,
"economic"
]
=
"full"
,
mode
:
Literal
[
"full"
,
"economic"
]
=
"full"
,
pivoting
:
Literal
[
True
]
=
True
,
pivoting
:
Literal
[
True
]
=
True
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
):
):
"""
"""
...
@@ -243,7 +242,7 @@ def _qr_full_pivot(
...
@@ -243,7 +242,7 @@ def _qr_full_pivot(
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
lwork
=
lwork
,
)
)
...
@@ -253,7 +252,6 @@ def _qr_full_no_pivot(
...
@@ -253,7 +252,6 @@ def _qr_full_no_pivot(
mode
:
Literal
[
"full"
,
"economic"
]
=
"full"
,
mode
:
Literal
[
"full"
,
"economic"
]
=
"full"
,
pivoting
:
Literal
[
False
]
=
False
,
pivoting
:
Literal
[
False
]
=
False
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
):
):
"""
"""
...
@@ -267,7 +265,7 @@ def _qr_full_no_pivot(
...
@@ -267,7 +265,7 @@ def _qr_full_no_pivot(
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
lwork
=
lwork
,
)
)
...
@@ -277,7 +275,6 @@ def _qr_r_pivot(
...
@@ -277,7 +275,6 @@ def _qr_r_pivot(
mode
:
Literal
[
"r"
,
"raw"
]
=
"r"
,
mode
:
Literal
[
"r"
,
"raw"
]
=
"r"
,
pivoting
:
Literal
[
True
]
=
True
,
pivoting
:
Literal
[
True
]
=
True
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
):
):
"""
"""
...
@@ -291,7 +288,7 @@ def _qr_r_pivot(
...
@@ -291,7 +288,7 @@ def _qr_r_pivot(
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
lwork
=
lwork
,
)
)
...
@@ -301,7 +298,6 @@ def _qr_r_no_pivot(
...
@@ -301,7 +298,6 @@ def _qr_r_no_pivot(
mode
:
Literal
[
"r"
,
"raw"
]
=
"r"
,
mode
:
Literal
[
"r"
,
"raw"
]
=
"r"
,
pivoting
:
Literal
[
False
]
=
False
,
pivoting
:
Literal
[
False
]
=
False
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
):
):
"""
"""
...
@@ -315,7 +311,7 @@ def _qr_r_no_pivot(
...
@@ -315,7 +311,7 @@ def _qr_r_no_pivot(
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
lwork
=
lwork
,
)
)
...
@@ -325,7 +321,6 @@ def _qr_raw_no_pivot(
...
@@ -325,7 +321,6 @@ def _qr_raw_no_pivot(
mode
:
Literal
[
"raw"
]
=
"raw"
,
mode
:
Literal
[
"raw"
]
=
"raw"
,
pivoting
:
Literal
[
False
]
=
False
,
pivoting
:
Literal
[
False
]
=
False
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
):
):
"""
"""
...
@@ -339,7 +334,7 @@ def _qr_raw_no_pivot(
...
@@ -339,7 +334,7 @@ def _qr_raw_no_pivot(
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
lwork
=
lwork
,
)
)
...
@@ -351,7 +346,6 @@ def _qr_raw_pivot(
...
@@ -351,7 +346,6 @@ def _qr_raw_pivot(
mode
:
Literal
[
"raw"
]
=
"raw"
,
mode
:
Literal
[
"raw"
]
=
"raw"
,
pivoting
:
Literal
[
True
]
=
True
,
pivoting
:
Literal
[
True
]
=
True
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
):
):
"""
"""
...
@@ -365,7 +359,7 @@ def _qr_raw_pivot(
...
@@ -365,7 +359,7 @@ def _qr_raw_pivot(
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
lwork
=
lwork
,
)
)
...
@@ -373,9 +367,7 @@ def _qr_raw_pivot(
...
@@ -373,9 +367,7 @@ def _qr_raw_pivot(
@overload
(
_qr_full_pivot
)
@overload
(
_qr_full_pivot
)
def
qr_full_pivot_impl
(
def
qr_full_pivot_impl
(
x
,
mode
=
"full"
,
pivoting
=
True
,
overwrite_a
=
False
,
lwork
=
None
):
x
,
mode
=
"full"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
dtype
=
x
.
dtype
dtype
=
x
.
dtype
...
@@ -395,7 +387,6 @@ def qr_full_pivot_impl(
...
@@ -395,7 +387,6 @@ def qr_full_pivot_impl(
mode
=
"full"
,
mode
=
"full"
,
pivoting
=
True
,
pivoting
=
True
,
overwrite_a
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
lwork
=
None
,
):
):
M
=
np
.
int32
(
x
.
shape
[
0
])
M
=
np
.
int32
(
x
.
shape
[
0
])
...
@@ -529,7 +520,7 @@ def qr_full_pivot_impl(
...
@@ -529,7 +520,7 @@ def qr_full_pivot_impl(
@overload
(
_qr_full_no_pivot
)
@overload
(
_qr_full_no_pivot
)
def
qr_full_no_pivot_impl
(
def
qr_full_no_pivot_impl
(
x
,
mode
=
"full"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
x
,
mode
=
"full"
,
pivoting
=
False
,
overwrite_a
=
False
,
lwork
=
None
):
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
...
@@ -546,7 +537,6 @@ def qr_full_no_pivot_impl(
...
@@ -546,7 +537,6 @@ def qr_full_no_pivot_impl(
mode
=
"full"
,
mode
=
"full"
,
pivoting
=
False
,
pivoting
=
False
,
overwrite_a
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
lwork
=
None
,
):
):
M
=
np
.
int32
(
x
.
shape
[
0
])
M
=
np
.
int32
(
x
.
shape
[
0
])
...
@@ -645,9 +635,7 @@ def qr_full_no_pivot_impl(
...
@@ -645,9 +635,7 @@ def qr_full_no_pivot_impl(
@overload
(
_qr_r_pivot
)
@overload
(
_qr_r_pivot
)
def
qr_r_pivot_impl
(
def
qr_r_pivot_impl
(
x
,
mode
=
"r"
,
pivoting
=
True
,
overwrite_a
=
False
,
lwork
=
None
):
x
,
mode
=
"r"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
dtype
=
x
.
dtype
dtype
=
x
.
dtype
...
@@ -658,7 +646,6 @@ def qr_r_pivot_impl(
...
@@ -658,7 +646,6 @@ def qr_r_pivot_impl(
mode
=
"r"
,
mode
=
"r"
,
pivoting
=
True
,
pivoting
=
True
,
overwrite_a
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
lwork
=
None
,
):
):
M
=
np
.
int32
(
x
.
shape
[
0
])
M
=
np
.
int32
(
x
.
shape
[
0
])
...
@@ -720,9 +707,7 @@ def qr_r_pivot_impl(
...
@@ -720,9 +707,7 @@ def qr_r_pivot_impl(
@overload
(
_qr_r_no_pivot
)
@overload
(
_qr_r_no_pivot
)
def
qr_r_no_pivot_impl
(
def
qr_r_no_pivot_impl
(
x
,
mode
=
"r"
,
pivoting
=
False
,
overwrite_a
=
False
,
lwork
=
None
):
x
,
mode
=
"r"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
dtype
=
x
.
dtype
dtype
=
x
.
dtype
...
@@ -733,7 +718,6 @@ def qr_r_no_pivot_impl(
...
@@ -733,7 +718,6 @@ def qr_r_no_pivot_impl(
mode
=
"r"
,
mode
=
"r"
,
pivoting
=
False
,
pivoting
=
False
,
overwrite_a
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
lwork
=
None
,
):
):
M
=
np
.
int32
(
x
.
shape
[
0
])
M
=
np
.
int32
(
x
.
shape
[
0
])
...
@@ -792,9 +776,7 @@ def qr_r_no_pivot_impl(
...
@@ -792,9 +776,7 @@ def qr_r_no_pivot_impl(
@overload
(
_qr_raw_no_pivot
)
@overload
(
_qr_raw_no_pivot
)
def
qr_raw_no_pivot_impl
(
def
qr_raw_no_pivot_impl
(
x
,
mode
=
"raw"
,
pivoting
=
False
,
overwrite_a
=
False
,
lwork
=
None
):
x
,
mode
=
"raw"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
dtype
=
x
.
dtype
dtype
=
x
.
dtype
...
@@ -805,7 +787,6 @@ def qr_raw_no_pivot_impl(
...
@@ -805,7 +787,6 @@ def qr_raw_no_pivot_impl(
mode
=
"raw"
,
mode
=
"raw"
,
pivoting
=
False
,
pivoting
=
False
,
overwrite_a
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
lwork
=
None
,
):
):
M
=
np
.
int32
(
x
.
shape
[
0
])
M
=
np
.
int32
(
x
.
shape
[
0
])
...
@@ -863,9 +844,7 @@ def qr_raw_no_pivot_impl(
...
@@ -863,9 +844,7 @@ def qr_raw_no_pivot_impl(
@overload
(
_qr_raw_pivot
)
@overload
(
_qr_raw_pivot
)
def
qr_raw_pivot_impl
(
def
qr_raw_pivot_impl
(
x
,
mode
=
"raw"
,
pivoting
=
True
,
overwrite_a
=
False
,
lwork
=
None
):
x
,
mode
=
"raw"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
...
@@ -880,7 +859,6 @@ def qr_raw_pivot_impl(
...
@@ -880,7 +859,6 @@ def qr_raw_pivot_impl(
mode
=
"raw"
,
mode
=
"raw"
,
pivoting
=
True
,
pivoting
=
True
,
overwrite_a
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
lwork
=
None
,
):
):
M
=
np
.
int32
(
x
.
shape
[
0
])
M
=
np
.
int32
(
x
.
shape
[
0
])
...
...
pytensor/link/numba/dispatch/linalg/solve/cholesky.py
浏览文件 @
672a4829
...
@@ -14,23 +14,23 @@ from pytensor.link.numba.dispatch.linalg.utils import (
...
@@ -14,23 +14,23 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match
,
_check_dtypes_match
,
_check_linalg_matrix
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
)
def
_cho_solve
(
def
_cho_solve
(
C
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_b
:
bool
):
C
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
):
"""
"""
Solve a positive-definite linear system using the Cholesky decomposition.
Solve a positive-definite linear system using the Cholesky decomposition.
"""
"""
return
linalg
.
cho_solve
(
return
linalg
.
cho_solve
(
(
C
,
lower
),
b
=
B
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
(
C
,
lower
),
b
=
B
,
overwrite_b
=
overwrite_b
,
check_finite
=
False
,
)
)
@overload
(
_cho_solve
)
@overload
(
_cho_solve
)
def
cho_solve_impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
,
check_finite
=
True
):
def
cho_solve_impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
C
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"cho_solve"
)
_check_linalg_matrix
(
C
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"cho_solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"cho_solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"cho_solve"
)
...
@@ -38,7 +38,7 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
...
@@ -38,7 +38,7 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
dtype
=
C
.
dtype
dtype
=
C
.
dtype
numba_potrs
=
_LAPACK
()
.
numba_xpotrs
(
dtype
)
numba_potrs
=
_LAPACK
()
.
numba_xpotrs
(
dtype
)
def
impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
,
check_finite
=
True
):
def
impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
):
_solve_check_input_shapes
(
C
,
B
)
_solve_check_input_shapes
(
C
,
B
)
_N
=
np
.
int32
(
C
.
shape
[
-
1
])
_N
=
np
.
int32
(
C
.
shape
[
-
1
])
...
@@ -79,7 +79,8 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
...
@@ -79,7 +79,8 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
INFO
,
INFO
,
)
)
_solve_check
(
_N
,
int_ptr_to_val
(
INFO
))
if
int_ptr_to_val
(
INFO
)
!=
0
:
B_copy
=
np
.
full_like
(
B_copy
,
np
.
nan
)
if
B_is_1d
:
if
B_is_1d
:
return
B_copy
[
...
,
0
]
return
B_copy
[
...
,
0
]
...
...
pytensor/link/numba/dispatch/linalg/solve/general.py
浏览文件 @
672a4829
...
@@ -3,82 +3,24 @@ from collections.abc import Callable
...
@@ -3,82 +3,24 @@ from collections.abc import Callable
import
numpy
as
np
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
from
numba.core.types
import
Float
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
numba.np.linalg
import
ensure_lapack
from
scipy
import
linalg
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.decomposition.lu_factor
import
_getrf
from
pytensor.link.numba.dispatch.linalg.decomposition.lu_factor
import
_getrf
from
pytensor.link.numba.dispatch.linalg.solve.lu_solve
import
_getrs
from
pytensor.link.numba.dispatch.linalg.solve.lu_solve
import
_getrs
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_dtypes_match
,
_check_dtypes_match
,
_check_linalg_matrix
,
_check_linalg_matrix
,
_solve_check
,
)
)
def
_xgecon
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify
graphs.
"""
return
# type: ignore
@overload
(
_xgecon
)
def
xgecon_impl
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
Callable
[[
np
.
ndarray
,
float
,
str
],
tuple
[
np
.
ndarray
,
int
]]:
"""
Compute the condition number of a matrix A.
"""
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"gecon"
)
dtype
=
A
.
dtype
numba_gecon
=
_LAPACK
()
.
numba_xgecon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
A_NORM
=
np
.
array
(
A_norm
,
dtype
=
dtype
)
NORM
=
val_to_int_ptr
(
ord
(
norm
))
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
4
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
1
)
numba_gecon
(
NORM
,
N
,
A_copy
.
ctypes
,
LDA
,
A_NORM
.
ctypes
,
RCOND
.
ctypes
,
WORK
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_gen
(
def
_solve_gen
(
A
:
np
.
ndarray
,
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
):
):
"""Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects
"""Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects
...
@@ -89,7 +31,7 @@ def _solve_gen(
...
@@ -89,7 +31,7 @@ def _solve_gen(
lower
=
lower
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
assume_a
=
"gen"
,
assume_a
=
"gen"
,
transposed
=
transposed
,
transposed
=
transposed
,
)
)
...
@@ -102,9 +44,8 @@ def solve_gen_impl(
...
@@ -102,9 +44,8 @@ def solve_gen_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
...
@@ -116,7 +57,6 @@ def solve_gen_impl(
...
@@ -116,7 +57,6 @@ def solve_gen_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
)
->
np
.
ndarray
:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
...
@@ -127,20 +67,18 @@ def solve_gen_impl(
...
@@ -127,20 +67,18 @@ def solve_gen_impl(
A
=
A
.
T
A
=
A
.
T
transposed
=
not
transposed
transposed
=
not
transposed
order
=
"I"
if
transposed
else
"1"
LU
,
IPIV
,
INFO1
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
norm
=
_xlange
(
A
,
order
=
order
)
N
=
A
.
shape
[
1
]
LU
,
IPIV
,
INFO
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
_solve_check
(
N
,
INFO
)
X
,
INFO
=
_getrs
(
X
,
INFO2
=
_getrs
(
LU
=
LU
,
B
=
B
,
IPIV
=
IPIV
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
LU
=
LU
,
B
=
B
,
IPIV
=
IPIV
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
,
)
)
_solve_check
(
N
,
INFO
)
RCOND
,
INFO
=
_xgecon
(
LU
,
norm
,
"1"
)
if
INFO1
!=
0
or
INFO2
!=
0
:
_solve_check
(
N
,
INFO
,
True
,
RCOND
)
X
=
np
.
full_like
(
X
,
np
.
nan
)
return
X
return
X
...
...
pytensor/link/numba/dispatch/linalg/solve/lu_solve.py
浏览文件 @
672a4829
...
@@ -17,7 +17,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
...
@@ -17,7 +17,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match
,
_check_dtypes_match
,
_check_linalg_matrix
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
_trans_char_to_int
,
)
)
...
@@ -107,14 +106,11 @@ def _lu_solve(
...
@@ -107,14 +106,11 @@ def _lu_solve(
b
:
np
.
ndarray
,
b
:
np
.
ndarray
,
trans
:
_Trans
,
trans
:
_Trans
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
):
):
"""
"""
Thin wrapper around scipy.lu_solve, used to avoid side effects from numba overloads on users who import Pytensor.
Thin wrapper around scipy.lu_solve, used to avoid side effects from numba overloads on users who import Pytensor.
"""
"""
return
linalg
.
lu_solve
(
return
linalg
.
lu_solve
(
lu_and_piv
,
b
,
trans
=
trans
,
overwrite_b
=
overwrite_b
)
lu_and_piv
,
b
,
trans
=
trans
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
)
@overload
(
_lu_solve
)
@overload
(
_lu_solve
)
...
@@ -123,8 +119,7 @@ def lu_solve_impl(
...
@@ -123,8 +119,7 @@ def lu_solve_impl(
b
:
np
.
ndarray
,
b
:
np
.
ndarray
,
trans
:
_Trans
,
trans
:
_Trans
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
_Trans
,
bool
],
np
.
ndarray
]:
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
_Trans
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
ensure_lapack
()
lu
,
_piv
=
lu_and_piv
lu
,
_piv
=
lu_and_piv
_check_linalg_matrix
(
lu
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"lu_solve"
)
_check_linalg_matrix
(
lu
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"lu_solve"
)
...
@@ -137,13 +132,11 @@ def lu_solve_impl(
...
@@ -137,13 +132,11 @@ def lu_solve_impl(
b
:
np
.
ndarray
,
b
:
np
.
ndarray
,
trans
:
_Trans
,
trans
:
_Trans
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
)
->
np
.
ndarray
:
)
->
np
.
ndarray
:
n
=
np
.
int32
(
lu
.
shape
[
0
]
)
X
,
info
=
_getrs
(
LU
=
lu
,
B
=
b
,
IPIV
=
piv
,
trans
=
trans
,
overwrite_b
=
overwrite_b
)
X
,
INFO
=
_getrs
(
LU
=
lu
,
B
=
b
,
IPIV
=
piv
,
trans
=
trans
,
overwrite_b
=
overwrite_b
)
if
info
!=
0
:
X
=
np
.
full_like
(
X
,
np
.
nan
)
_solve_check
(
n
,
INFO
)
return
X
return
X
...
...
pytensor/link/numba/dispatch/linalg/solve/norm.py
deleted
100644 → 0
浏览文件 @
b2d8bc24
from
collections.abc
import
Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_linalg_matrix
def
_xlange
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
)
->
float
:
"""
Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode.
"""
return
# type: ignore
@overload
(
_xlange
)
def
xlange_impl
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
)
->
Callable
[[
np
.
ndarray
,
str
],
float
]:
"""
xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity norm, or the element of
largest absolute value of a matrix A.
"""
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"norm"
)
dtype
=
A
.
dtype
numba_lange
=
_LAPACK
()
.
numba_xlange
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
):
_M
,
_N
=
np
.
int32
(
A
.
shape
[
-
2
:])
# type: ignore
A_copy
=
_copy_to_fortran_order
(
A
)
M
=
val_to_int_ptr
(
_M
)
# type: ignore
N
=
val_to_int_ptr
(
_N
)
# type: ignore
LDA
=
val_to_int_ptr
(
_M
)
# type: ignore
NORM
=
(
val_to_int_ptr
(
ord
(
order
))
if
order
is
not
None
else
val_to_int_ptr
(
ord
(
"1"
))
)
WORK
=
np
.
empty
(
_M
,
dtype
=
dtype
)
# type: ignore
result
=
numba_lange
(
NORM
,
M
,
N
,
A_copy
.
ctypes
,
LDA
,
WORK
.
ctypes
)
return
result
return
impl
pytensor/link/numba/dispatch/linalg/solve/posdef.py
浏览文件 @
672a4829
...
@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
...
@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val
,
int_ptr_to_val
,
val_to_int_ptr
,
val_to_int_ptr
,
)
)
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_dtypes_match
,
_check_dtypes_match
,
_check_linalg_matrix
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
)
...
@@ -27,8 +25,6 @@ def _posv(
...
@@ -27,8 +25,6 @@ def _posv(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
"""
"""
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
...
@@ -43,10 +39,8 @@ def posv_impl(
...
@@ -43,10 +39,8 @@ def posv_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[
)
->
Callable
[
[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
],
]:
]:
ensure_lapack
()
ensure_lapack
()
...
@@ -62,8 +56,6 @@ def posv_impl(
...
@@ -62,8 +56,6 @@ def posv_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
_solve_check_input_shapes
(
A
,
B
)
_solve_check_input_shapes
(
A
,
B
)
...
@@ -115,60 +107,12 @@ def posv_impl(
...
@@ -115,60 +107,12 @@ def posv_impl(
return
impl
return
impl
def
_pocon
(
A
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by
linalg.solve when assume_a = "pos".
"""
return
# type: ignore
@overload
(
_pocon
)
def
pocon_impl
(
A
:
np
.
ndarray
,
anorm
:
float
)
->
Callable
[[
np
.
ndarray
,
float
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"pocon"
)
dtype
=
A
.
dtype
numba_pocon
=
_LAPACK
()
.
numba_xpocon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
anorm
:
float
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
UPLO
=
val_to_int_ptr
(
ord
(
"L"
))
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
ANORM
=
np
.
array
(
anorm
,
dtype
=
dtype
)
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
3
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
0
)
numba_pocon
(
UPLO
,
N
,
A_copy
.
ctypes
,
LDA
,
ANORM
.
ctypes
,
RCOND
.
ctypes
,
WORK
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_psd
(
def
_solve_psd
(
A
:
np
.
ndarray
,
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
):
):
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
...
@@ -179,7 +123,7 @@ def _solve_psd(
...
@@ -179,7 +123,7 @@ def _solve_psd(
lower
=
lower
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
transposed
=
transposed
,
transposed
=
transposed
,
assume_a
=
"pos"
,
assume_a
=
"pos"
,
)
)
...
@@ -192,9 +136,8 @@ def solve_psd_impl(
...
@@ -192,9 +136,8 @@ def solve_psd_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
...
@@ -206,18 +149,14 @@ def solve_psd_impl(
...
@@ -206,18 +149,14 @@ def solve_psd_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
)
->
np
.
ndarray
:
_solve_check_input_shapes
(
A
,
B
)
_solve_check_input_shapes
(
A
,
B
)
C
,
x
,
info
=
_posv
(
_C
,
x
,
info
=
_posv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
)
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
,
check_finite
,
transposed
)
_solve_check
(
A
.
shape
[
-
1
],
info
)
rcond
,
info
=
_pocon
(
C
,
_xlange
(
A
))
if
info
!=
0
:
_solve_check
(
A
.
shape
[
-
1
],
info
=
info
,
lamch
=
True
,
rcond
=
rcond
)
x
=
np
.
full_like
(
x
,
np
.
nan
)
return
x
return
x
...
...
pytensor/link/numba/dispatch/linalg/solve/symmetric.py
浏览文件 @
672a4829
...
@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
...
@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val
,
int_ptr_to_val
,
val_to_int_ptr
,
val_to_int_ptr
,
)
)
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_dtypes_match
,
_check_dtypes_match
,
_check_linalg_matrix
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
)
...
@@ -121,61 +119,12 @@ def sysv_impl(
...
@@ -121,61 +119,12 @@ def sysv_impl(
return
impl
return
impl
def
_sycon
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in
python mode.
"""
return
# type: ignore
@overload
(
_sycon
)
def
sycon_impl
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
float
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"sycon"
)
dtype
=
A
.
dtype
numba_sycon
=
_LAPACK
()
.
numba_xsycon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
UPLO
=
val_to_int_ptr
(
ord
(
"U"
))
ANORM
=
np
.
array
(
anorm
,
dtype
=
dtype
)
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
2
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
0
)
numba_sycon
(
UPLO
,
N
,
A_copy
.
ctypes
,
LDA
,
ipiv
.
ctypes
,
ANORM
.
ctypes
,
RCOND
.
ctypes
,
WORK
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_symmetric
(
def
_solve_symmetric
(
A
:
np
.
ndarray
,
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
):
):
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
...
@@ -186,7 +135,7 @@ def _solve_symmetric(
...
@@ -186,7 +135,7 @@ def _solve_symmetric(
lower
=
lower
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
assume_a
=
"sym"
,
assume_a
=
"sym"
,
transposed
=
transposed
,
transposed
=
transposed
,
)
)
...
@@ -199,9 +148,8 @@ def solve_symmetric_impl(
...
@@ -199,9 +148,8 @@ def solve_symmetric_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
...
@@ -213,16 +161,14 @@ def solve_symmetric_impl(
...
@@ -213,16 +161,14 @@ def solve_symmetric_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
)
->
np
.
ndarray
:
_solve_check_input_shapes
(
A
,
B
)
_solve_check_input_shapes
(
A
,
B
)
lu
,
x
,
ipiv
,
info
=
_sysv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
)
_lu
,
x
,
_ipiv
,
info
=
_sysv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
)
_solve_check
(
A
.
shape
[
-
1
],
info
)
rcond
,
info
=
_sycon
(
lu
,
ipiv
,
_xlange
(
A
,
order
=
"I"
))
if
info
!=
0
:
_solve_check
(
A
.
shape
[
-
1
],
info
,
True
,
rcond
)
x
=
np
.
full_like
(
x
,
np
.
nan
)
return
x
return
x
...
...
pytensor/link/numba/dispatch/linalg/solve/triangular.py
浏览文件 @
672a4829
...
@@ -15,13 +15,12 @@ from pytensor.link.numba.dispatch.linalg.utils import (
...
@@ -15,13 +15,12 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match
,
_check_dtypes_match
,
_check_linalg_matrix
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
_trans_char_to_int
,
)
)
def
_solve_triangular
(
def
_solve_triangular
(
A
,
B
,
trans
=
0
,
lower
=
False
,
unit_diagonal
=
False
,
b_ndim
=
1
,
overwrite_b
=
False
A
,
B
,
trans
=
0
,
lower
=
False
,
unit_diagonal
=
False
,
overwrite_b
=
False
):
):
"""
"""
Thin wrapper around scipy.linalg.solve_triangular.
Thin wrapper around scipy.linalg.solve_triangular.
...
@@ -39,11 +38,12 @@ def _solve_triangular(
...
@@ -39,11 +38,12 @@ def _solve_triangular(
lower
=
lower
,
lower
=
lower
,
unit_diagonal
=
unit_diagonal
,
unit_diagonal
=
unit_diagonal
,
overwrite_b
=
overwrite_b
,
overwrite_b
=
overwrite_b
,
check_finite
=
False
,
)
)
@overload
(
_solve_triangular
)
@overload
(
_solve_triangular
)
def
solve_triangular_impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
b_ndim
,
overwrite_b
):
def
solve_triangular_impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
overwrite_b
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve_triangular"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve_triangular"
)
...
@@ -57,12 +57,10 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
...
@@ -57,12 +57,10 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
"This function is not expected to work with complex numbers yet"
"This function is not expected to work with complex numbers yet"
)
)
def
impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
b_ndim
,
overwrite_b
):
def
impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
overwrite_b
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
_solve_check_input_shapes
(
A
,
B
)
_solve_check_input_shapes
(
A
,
B
)
# Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type
# could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
B_is_1d
=
B
.
ndim
==
1
B_is_1d
=
B
.
ndim
==
1
if
A
.
flags
.
f_contiguous
or
(
A
.
flags
.
c_contiguous
and
trans
in
(
0
,
1
)):
if
A
.
flags
.
f_contiguous
or
(
A
.
flags
.
c_contiguous
and
trans
in
(
0
,
1
)):
...
@@ -106,8 +104,8 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
...
@@ -106,8 +104,8 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
LDB
,
LDB
,
INFO
,
INFO
,
)
)
if
int_ptr_to_val
(
INFO
)
!=
0
:
_solve_check
(
int_ptr_to_val
(
LDA
),
int_ptr_to_val
(
INFO
)
)
B_copy
=
np
.
full_like
(
B_copy
,
np
.
nan
)
if
B_is_1d
:
if
B_is_1d
:
return
B_copy
[
...
,
0
]
return
B_copy
[
...
,
0
]
...
...
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
浏览文件 @
672a4829
...
@@ -23,7 +23,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
...
@@ -23,7 +23,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match
,
_check_dtypes_match
,
_check_linalg_matrix
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
_trans_char_to_int
,
)
)
from
pytensor.tensor._linalg.solve.tridiagonal
import
(
from
pytensor.tensor._linalg.solve.tridiagonal
import
(
...
@@ -202,83 +201,12 @@ def gttrs_impl(
...
@@ -202,83 +201,12 @@ def gttrs_impl(
return
impl
return
impl
def
_gtcon
(
dl
:
ndarray
,
d
:
ndarray
,
du
:
ndarray
,
du2
:
ndarray
,
ipiv
:
ndarray
,
anorm
:
float
,
norm
:
str
,
)
->
tuple
[
ndarray
,
int
]:
"""Placeholder for computing the condition number of a tridiagonal system."""
return
# type: ignore
@overload
(
_gtcon
)
def
gtcon_impl
(
dl
:
ndarray
,
d
:
ndarray
,
du
:
ndarray
,
du2
:
ndarray
,
ipiv
:
ndarray
,
anorm
:
float
,
norm
:
str
,
)
->
Callable
[
[
ndarray
,
ndarray
,
ndarray
,
ndarray
,
ndarray
,
float
,
str
],
tuple
[
ndarray
,
int
]
]:
ensure_lapack
()
_check_linalg_matrix
(
dl
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_linalg_matrix
(
d
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_linalg_matrix
(
du
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_linalg_matrix
(
du2
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_dtypes_match
((
dl
,
d
,
du
,
du2
),
func_name
=
"gtcon"
)
_check_linalg_matrix
(
ipiv
,
ndim
=
1
,
dtype
=
int32
,
func_name
=
"gtcon"
)
dtype
=
d
.
dtype
numba_gtcon
=
_LAPACK
()
.
numba_xgtcon
(
dtype
)
def
impl
(
dl
:
ndarray
,
d
:
ndarray
,
du
:
ndarray
,
du2
:
ndarray
,
ipiv
:
ndarray
,
anorm
:
float
,
norm
:
str
,
)
->
tuple
[
ndarray
,
int
]:
n
=
np
.
int32
(
d
.
shape
[
-
1
])
rcond
=
np
.
empty
(
1
,
dtype
=
dtype
)
work
=
np
.
empty
(
2
*
n
,
dtype
=
dtype
)
iwork
=
np
.
empty
(
n
,
dtype
=
np
.
int32
)
info
=
val_to_int_ptr
(
0
)
numba_gtcon
(
val_to_int_ptr
(
ord
(
norm
)),
val_to_int_ptr
(
n
),
dl
.
ctypes
,
d
.
ctypes
,
du
.
ctypes
,
du2
.
ctypes
,
ipiv
.
ctypes
,
np
.
array
(
anorm
,
dtype
=
dtype
)
.
ctypes
,
rcond
.
ctypes
,
work
.
ctypes
,
iwork
.
ctypes
,
info
,
)
return
rcond
,
int_ptr_to_val
(
info
)
return
impl
def
_solve_tridiagonal
(
def
_solve_tridiagonal
(
a
:
ndarray
,
a
:
ndarray
,
b
:
ndarray
,
b
:
ndarray
,
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
):
):
"""
"""
...
@@ -290,7 +218,7 @@ def _solve_tridiagonal(
...
@@ -290,7 +218,7 @@ def _solve_tridiagonal(
lower
=
lower
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
transposed
=
transposed
,
transposed
=
transposed
,
assume_a
=
"tridiagonal"
,
assume_a
=
"tridiagonal"
,
)
)
...
@@ -303,9 +231,8 @@ def _tridiagonal_solve_impl(
...
@@ -303,9 +231,8 @@ def _tridiagonal_solve_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
ndarray
,
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
ndarray
]:
)
->
Callable
[[
ndarray
,
ndarray
,
bool
,
bool
,
bool
,
bool
],
ndarray
]:
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
...
@@ -317,31 +244,24 @@ def _tridiagonal_solve_impl(
...
@@ -317,31 +244,24 @@ def _tridiagonal_solve_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
)
->
ndarray
:
)
->
ndarray
:
n
=
np
.
int32
(
A
.
shape
[
-
1
])
_solve_check_input_shapes
(
A
,
B
)
_solve_check_input_shapes
(
A
,
B
)
norm
=
"1"
if
transposed
:
if
transposed
:
A
=
A
.
T
A
=
A
.
T
dl
,
d
,
du
=
np
.
diag
(
A
,
-
1
),
np
.
diag
(
A
,
0
),
np
.
diag
(
A
,
1
)
dl
,
d
,
du
=
np
.
diag
(
A
,
-
1
),
np
.
diag
(
A
,
0
),
np
.
diag
(
A
,
1
)
anorm
=
tridiagonal_norm
(
du
,
d
,
dl
)
dl
,
d
,
du
,
du2
,
ipiv
,
info1
=
_gttrf
(
dl
,
d
,
du
,
du2
,
IPIV
,
INFO
=
_gttrf
(
dl
,
d
,
du
,
overwrite_dl
=
True
,
overwrite_d
=
True
,
overwrite_du
=
True
dl
,
d
,
du
,
overwrite_dl
=
True
,
overwrite_d
=
True
,
overwrite_du
=
True
)
)
_solve_check
(
n
,
INFO
)
X
,
INFO
=
_gttrs
(
X
,
info2
=
_gttrs
(
dl
,
d
,
du
,
du2
,
IPIV
,
B
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
dl
,
d
,
du
,
du2
,
ipiv
,
B
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
)
)
_solve_check
(
n
,
INFO
)
RCOND
,
INFO
=
_gtcon
(
dl
,
d
,
du
,
du2
,
IPIV
,
anorm
,
norm
)
if
info1
!=
0
or
info2
!=
0
:
_solve_check
(
n
,
INFO
,
True
,
RCOND
)
X
=
np
.
full_like
(
X
,
np
.
nan
)
return
X
return
X
...
@@ -391,8 +311,8 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
...
@@ -391,8 +311,8 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
)
)
return
dl
,
d
,
du
,
du2
,
ipiv
return
dl
,
d
,
du
,
du2
,
ipiv
cache_
key
=
1
cache_
version
=
2
return
lu_factor_tridiagonal
,
cache_
key
return
lu_factor_tridiagonal
,
cache_
version
@register_funcify_default_op_cache_key
(
SolveLUFactorTridiagonal
)
@register_funcify_default_op_cache_key
(
SolveLUFactorTridiagonal
)
...
@@ -434,7 +354,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
...
@@ -434,7 +354,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
ipiv
=
ipiv
.
astype
(
np
.
int32
)
ipiv
=
ipiv
.
astype
(
np
.
int32
)
if
cast_b
:
if
cast_b
:
b
=
b
.
astype
(
out_dtype
)
b
=
b
.
astype
(
out_dtype
)
x
,
_
=
_gttrs
(
x
,
info
=
_gttrs
(
dl
,
dl
,
d
,
d
,
du
,
du
,
...
@@ -444,7 +364,11 @@ def numba_funcify_SolveLUFactorTridiagonal(
...
@@ -444,7 +364,11 @@ def numba_funcify_SolveLUFactorTridiagonal(
overwrite_b
=
overwrite_b
,
overwrite_b
=
overwrite_b
,
trans
=
transposed
,
trans
=
transposed
,
)
)
if
info
!=
0
:
x
=
np
.
full_like
(
x
,
np
.
nan
)
return
x
return
x
cache_
key
=
1
cache_
version
=
2
return
solve_lu_factor_tridiagonal
,
cache_
key
return
solve_lu_factor_tridiagonal
,
cache_
version
pytensor/link/numba/dispatch/linalg/utils.py
浏览文件 @
672a4829
from
collections.abc
import
Callable
,
Sequence
from
collections.abc
import
Sequence
import
numba
import
numba
from
numba.core
import
types
from
numba.core
import
types
from
numba.core.extending
import
overload
from
numba.np.linalg
import
_copy_to_fortran_order
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
numpy.linalg
import
LinAlgError
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
val_to_int_ptr
,
)
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
(
inline
=
"always"
)
...
@@ -61,58 +54,3 @@ def _check_dtypes_match(arrays: Sequence, func_name="cho_solve"):
...
@@ -61,58 +54,3 @@ def _check_dtypes_match(arrays: Sequence, func_name="cho_solve"):
if
first_dtype
!=
other_dtype
:
if
first_dtype
!=
other_dtype
:
msg
=
f
"{func_name} only supported for matching dtypes, got {dtypes}"
msg
=
f
"{func_name} only supported for matching dtypes, got {dtypes}"
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
@numba_basic.numba_njit
(
inline
=
"always"
)
def
_solve_check
(
n
,
info
,
lamch
=
False
,
rcond
=
None
):
"""
Check arguments during the different steps of the solution phase
Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38
"""
if
info
<
0
:
# TODO: figure out how to do an fstring here
msg
=
"LAPACK reported an illegal value in input"
raise
ValueError
(
msg
)
elif
0
<
info
:
raise
LinAlgError
(
"Matrix is singular."
)
if
lamch
:
E
=
_xlamch
(
"E"
)
if
rcond
<
E
:
# TODO: This should be a warning, but we can't raise warnings in numba mode
print
(
# noqa: T201
"Ill-conditioned matrix, rcond="
,
rcond
,
", result may not be accurate."
)
def
_xlamch
(
kind
:
str
=
"E"
):
"""
Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs.
"""
pass
@overload
(
_xlamch
)
def
xlamch_impl
(
kind
:
str
=
"E"
)
->
Callable
[[
str
],
float
]:
"""
Compute the machine precision for a given floating point type.
"""
from
pytensor
import
config
ensure_lapack
()
w_type
=
_get_underlying_float
(
config
.
floatX
)
if
w_type
==
"float32"
:
dtype
=
types
.
float32
elif
w_type
==
"float64"
:
dtype
=
types
.
float64
else
:
raise
NotImplementedError
(
"Unsupported dtype"
)
numba_lamch
=
_LAPACK
()
.
numba_xlamch
(
dtype
)
def
impl
(
kind
:
str
=
"E"
)
->
float
:
KIND
=
val_to_int_ptr
(
ord
(
kind
))
return
numba_lamch
(
KIND
)
# type: ignore
return
impl
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
672a4829
...
@@ -58,8 +58,6 @@ def numba_funcify_Cholesky(op, node, **kwargs):
...
@@ -58,8 +58,6 @@ def numba_funcify_Cholesky(op, node, **kwargs):
"""
"""
lower
=
op
.
lower
lower
=
op
.
lower
overwrite_a
=
op
.
overwrite_a
overwrite_a
=
op
.
overwrite_a
check_finite
=
op
.
check_finite
on_error
=
op
.
on_error
inp_dtype
=
node
.
inputs
[
0
]
.
type
.
numpy_dtype
inp_dtype
=
node
.
inputs
[
0
]
.
type
.
numpy_dtype
if
inp_dtype
.
kind
==
"c"
:
if
inp_dtype
.
kind
==
"c"
:
...
@@ -77,30 +75,11 @@ def numba_funcify_Cholesky(op, node, **kwargs):
...
@@ -77,30 +75,11 @@ def numba_funcify_Cholesky(op, node, **kwargs):
if
discrete_inp
:
if
discrete_inp
:
a
=
a
.
astype
(
out_dtype
)
a
=
a
.
astype
(
out_dtype
)
elif
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) found in input to cholesky"
)
res
,
info
=
_cholesky
(
a
,
lower
,
overwrite_a
,
check_finite
)
if
on_error
==
"raise"
:
if
info
>
0
:
raise
np
.
linalg
.
LinAlgError
(
"Input to cholesky is not positive definite"
)
if
info
<
0
:
raise
ValueError
(
'LAPACK reported an illegal value in input on entry to "POTRF."'
)
else
:
if
info
!=
0
:
res
=
np
.
full_like
(
res
,
np
.
nan
)
return
res
return
_cholesky
(
a
,
lower
,
overwrite_a
)
cache_
key
=
1
cache_
version
=
2
return
cholesky
,
cache_
key
return
cholesky
,
cache_
version
@register_funcify_default_op_cache_key
(
PivotToPermutations
)
@register_funcify_default_op_cache_key
(
PivotToPermutations
)
...
@@ -116,8 +95,8 @@ def pivot_to_permutation(op, node, **kwargs):
...
@@ -116,8 +95,8 @@ def pivot_to_permutation(op, node, **kwargs):
return
np
.
argsort
(
p_inv
)
return
np
.
argsort
(
p_inv
)
cache_
key
=
1
cache_
version
=
2
return
numba_pivot_to_permutation
,
cache_
key
return
numba_pivot_to_permutation
,
cache_
version
@register_funcify_default_op_cache_key
(
LU
)
@register_funcify_default_op_cache_key
(
LU
)
...
@@ -131,7 +110,6 @@ def numba_funcify_LU(op, node, **kwargs):
...
@@ -131,7 +110,6 @@ def numba_funcify_LU(op, node, **kwargs):
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
permute_l
=
op
.
permute_l
permute_l
=
op
.
permute_l
check_finite
=
op
.
check_finite
p_indices
=
op
.
p_indices
p_indices
=
op
.
p_indices
overwrite_a
=
op
.
overwrite_a
overwrite_a
=
op
.
overwrite_a
...
@@ -151,17 +129,11 @@ def numba_funcify_LU(op, node, **kwargs):
...
@@ -151,17 +129,11 @@ def numba_funcify_LU(op, node, **kwargs):
if
discrete_inp
:
if
discrete_inp
:
a
=
a
.
astype
(
out_dtype
)
a
=
a
.
astype
(
out_dtype
)
elif
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) found in input to lu"
)
if
p_indices
:
if
p_indices
:
res
=
_lu_1
(
res
=
_lu_1
(
a
,
a
,
permute_l
=
permute_l
,
permute_l
=
permute_l
,
check_finite
=
check_finite
,
p_indices
=
p_indices
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
)
)
...
@@ -169,7 +141,6 @@ def numba_funcify_LU(op, node, **kwargs):
...
@@ -169,7 +141,6 @@ def numba_funcify_LU(op, node, **kwargs):
res
=
_lu_2
(
res
=
_lu_2
(
a
,
a
,
permute_l
=
permute_l
,
permute_l
=
permute_l
,
check_finite
=
check_finite
,
p_indices
=
p_indices
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
)
)
...
@@ -177,15 +148,14 @@ def numba_funcify_LU(op, node, **kwargs):
...
@@ -177,15 +148,14 @@ def numba_funcify_LU(op, node, **kwargs):
res
=
_lu_3
(
res
=
_lu_3
(
a
,
a
,
permute_l
=
permute_l
,
permute_l
=
permute_l
,
check_finite
=
check_finite
,
p_indices
=
p_indices
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
)
)
return
res
return
res
cache_
key
=
1
cache_
version
=
2
return
lu
,
cache_
key
return
lu
,
cache_
version
@register_funcify_default_op_cache_key
(
LUFactor
)
@register_funcify_default_op_cache_key
(
LUFactor
)
...
@@ -198,7 +168,6 @@ def numba_funcify_LUFactor(op, node, **kwargs):
...
@@ -198,7 +168,6 @@ def numba_funcify_LUFactor(op, node, **kwargs):
print
(
"LUFactor requires casting discrete input to float"
)
# noqa: T201
print
(
"LUFactor requires casting discrete input to float"
)
# noqa: T201
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
check_finite
=
op
.
check_finite
overwrite_a
=
op
.
overwrite_a
overwrite_a
=
op
.
overwrite_a
@numba_basic.numba_njit
@numba_basic.numba_njit
...
@@ -211,18 +180,13 @@ def numba_funcify_LUFactor(op, node, **kwargs):
...
@@ -211,18 +180,13 @@ def numba_funcify_LUFactor(op, node, **kwargs):
if
discrete_inp
:
if
discrete_inp
:
a
=
a
.
astype
(
out_dtype
)
a
=
a
.
astype
(
out_dtype
)
elif
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) found in input to cholesky"
)
LU
,
piv
=
_lu_factor
(
a
,
overwrite_a
)
LU
,
piv
=
_lu_factor
(
a
,
overwrite_a
)
return
LU
,
piv
return
LU
,
piv
cache_
key
=
1
cache_
version
=
2
return
lu_factor
,
cache_
key
return
lu_factor
,
cache_
version
@register_funcify_default_op_cache_key
(
BlockDiagonal
)
@register_funcify_default_op_cache_key
(
BlockDiagonal
)
...
@@ -288,8 +252,8 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
...
@@ -288,8 +252,8 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
globals
()
|
{
"np"
:
np
},
globals
()
|
{
"np"
:
np
},
)
)
cache_
key
=
1
cache_
version
=
1
return
numba_basic
.
numba_njit
(
block_diag
),
cache_
key
return
numba_basic
.
numba_njit
(
block_diag
),
cache_
version
@register_funcify_default_op_cache_key
(
Solve
)
@register_funcify_default_op_cache_key
(
Solve
)
...
@@ -306,12 +270,9 @@ def numba_funcify_Solve(op, node, **kwargs):
...
@@ -306,12 +270,9 @@ def numba_funcify_Solve(op, node, **kwargs):
if
must_cast_B
and
config
.
compiler_verbose
:
if
must_cast_B
and
config
.
compiler_verbose
:
print
(
"Solve requires casting second input `b`"
)
# noqa: T201
print
(
"Solve requires casting second input `b`"
)
# noqa: T201
check_finite
=
op
.
check_finite
overwrite_a
=
op
.
overwrite_a
overwrite_a
=
op
.
overwrite_a
assume_a
=
op
.
assume_a
assume_a
=
op
.
assume_a
lower
=
op
.
lower
lower
=
op
.
lower
check_finite
=
op
.
check_finite
overwrite_a
=
op
.
overwrite_a
overwrite_a
=
op
.
overwrite_a
overwrite_b
=
op
.
overwrite_b
overwrite_b
=
op
.
overwrite_b
transposed
=
False
# TODO: Solve doesnt currently allow the transposed argument
transposed
=
False
# TODO: Solve doesnt currently allow the transposed argument
...
@@ -344,30 +305,18 @@ def numba_funcify_Solve(op, node, **kwargs):
...
@@ -344,30 +305,18 @@ def numba_funcify_Solve(op, node, **kwargs):
a
=
a
.
astype
(
out_dtype
)
a
=
a
.
astype
(
out_dtype
)
if
must_cast_B
:
if
must_cast_B
:
b
=
b
.
astype
(
out_dtype
)
b
=
b
.
astype
(
out_dtype
)
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input A to solve"
)
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
b
),
np
.
isnan
(
b
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input b to solve"
)
res
=
solve_fn
(
a
,
b
,
lower
,
overwrite_a
,
overwrite_b
,
check_finite
,
transposed
)
return
res
cache_key
=
1
return
solve_fn
(
a
,
b
,
lower
,
overwrite_a
,
overwrite_b
,
transposed
)
return
solve
,
cache_key
cache_version
=
2
return
solve
,
cache_version
@register_funcify_default_op_cache_key
(
SolveTriangular
)
@register_funcify_default_op_cache_key
(
SolveTriangular
)
def
numba_funcify_SolveTriangular
(
op
,
node
,
**
kwargs
):
def
numba_funcify_SolveTriangular
(
op
,
node
,
**
kwargs
):
lower
=
op
.
lower
lower
=
op
.
lower
unit_diagonal
=
op
.
unit_diagonal
unit_diagonal
=
op
.
unit_diagonal
check_finite
=
op
.
check_finite
overwrite_b
=
op
.
overwrite_b
overwrite_b
=
op
.
overwrite_b
b_ndim
=
op
.
b_ndim
A_dtype
,
b_dtype
=
(
i
.
type
.
numpy_dtype
for
i
in
node
.
inputs
)
A_dtype
,
b_dtype
=
(
i
.
type
.
numpy_dtype
for
i
in
node
.
inputs
)
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
...
@@ -389,37 +338,24 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
...
@@ -389,37 +338,24 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
a
=
a
.
astype
(
out_dtype
)
a
=
a
.
astype
(
out_dtype
)
if
must_cast_B
:
if
must_cast_B
:
b
=
b
.
astype
(
out_dtype
)
b
=
b
.
astype
(
out_dtype
)
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
return
_solve_triangular
(
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input A to solve_triangular"
)
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
b
),
np
.
isnan
(
b
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input b to solve_triangular"
)
res
=
_solve_triangular
(
a
,
a
,
b
,
b
,
trans
=
0
,
# transposing is handled explicitly on the graph, so we never use this argument
trans
=
0
,
# transposing is handled explicitly on the graph, so we never use this argument
lower
=
lower
,
lower
=
lower
,
unit_diagonal
=
unit_diagonal
,
unit_diagonal
=
unit_diagonal
,
overwrite_b
=
overwrite_b
,
overwrite_b
=
overwrite_b
,
b_ndim
=
b_ndim
,
)
)
return
res
cache_version
=
2
return
solve_triangular
,
cache_version
cache_key
=
1
return
solve_triangular
,
cache_key
@register_funcify_default_op_cache_key
(
CholeskySolve
)
@register_funcify_default_op_cache_key
(
CholeskySolve
)
def
numba_funcify_CholeskySolve
(
op
,
node
,
**
kwargs
):
def
numba_funcify_CholeskySolve
(
op
,
node
,
**
kwargs
):
lower
=
op
.
lower
lower
=
op
.
lower
overwrite_b
=
op
.
overwrite_b
overwrite_b
=
op
.
overwrite_b
check_finite
=
op
.
check_finite
c_dtype
,
b_dtype
=
(
i
.
type
.
numpy_dtype
for
i
in
node
.
inputs
)
c_dtype
,
b_dtype
=
(
i
.
type
.
numpy_dtype
for
i
in
node
.
inputs
)
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
...
@@ -439,36 +375,24 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
...
@@ -439,36 +375,24 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
return
np
.
zeros
(
b
.
shape
,
dtype
=
out_dtype
)
return
np
.
zeros
(
b
.
shape
,
dtype
=
out_dtype
)
if
must_cast_c
:
if
must_cast_c
:
c
=
c
.
astype
(
out_dtype
)
c
=
c
.
astype
(
out_dtype
)
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
c
),
np
.
isnan
(
c
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input A to cho_solve"
)
if
must_cast_b
:
if
must_cast_b
:
b
=
b
.
astype
(
out_dtype
)
b
=
b
.
astype
(
out_dtype
)
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
b
),
np
.
isnan
(
b
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input b to cho_solve"
)
return
_cho_solve
(
return
_cho_solve
(
c
,
c
,
b
,
b
,
lower
=
lower
,
lower
=
lower
,
overwrite_b
=
overwrite_b
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
,
)
)
cache_
key
=
1
cache_
version
=
2
return
cho_solve
,
cache_
key
return
cho_solve
,
cache_
version
@register_funcify_default_op_cache_key
(
QR
)
@register_funcify_default_op_cache_key
(
QR
)
def
numba_funcify_QR
(
op
,
node
,
**
kwargs
):
def
numba_funcify_QR
(
op
,
node
,
**
kwargs
):
mode
=
op
.
mode
mode
=
op
.
mode
check_finite
=
op
.
check_finite
pivoting
=
op
.
pivoting
pivoting
=
op
.
pivoting
overwrite_a
=
op
.
overwrite_a
overwrite_a
=
op
.
overwrite_a
...
@@ -481,12 +405,6 @@ def numba_funcify_QR(op, node, **kwargs):
...
@@ -481,12 +405,6 @@ def numba_funcify_QR(op, node, **kwargs):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
qr
(
a
):
def
qr
(
a
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) found in input to qr"
)
if
integer_input
:
if
integer_input
:
a
=
a
.
astype
(
out_dtype
)
a
=
a
.
astype
(
out_dtype
)
...
@@ -496,7 +414,6 @@ def numba_funcify_QR(op, node, **kwargs):
...
@@ -496,7 +414,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
)
return
Q
,
R
,
P
return
Q
,
R
,
P
...
@@ -506,7 +423,6 @@ def numba_funcify_QR(op, node, **kwargs):
...
@@ -506,7 +423,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
)
return
Q
,
R
return
Q
,
R
...
@@ -516,7 +432,6 @@ def numba_funcify_QR(op, node, **kwargs):
...
@@ -516,7 +432,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
)
return
R
,
P
return
R
,
P
...
@@ -526,7 +441,6 @@ def numba_funcify_QR(op, node, **kwargs):
...
@@ -526,7 +441,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
)
return
R
return
R
...
@@ -536,7 +450,6 @@ def numba_funcify_QR(op, node, **kwargs):
...
@@ -536,7 +450,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
)
return
H
,
tau
,
R
,
P
return
H
,
tau
,
R
,
P
...
@@ -546,7 +459,6 @@ def numba_funcify_QR(op, node, **kwargs):
...
@@ -546,7 +459,6 @@ def numba_funcify_QR(op, node, **kwargs):
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
)
return
H
,
tau
,
R
return
H
,
tau
,
R
...
@@ -555,5 +467,5 @@ def numba_funcify_QR(op, node, **kwargs):
...
@@ -555,5 +467,5 @@ def numba_funcify_QR(op, node, **kwargs):
f
"QR mode={mode}, pivoting={pivoting} not supported in numba mode."
f
"QR mode={mode}, pivoting={pivoting} not supported in numba mode."
)
)
cache_
key
=
1
cache_
version
=
2
return
qr
,
cache_
key
return
qr
,
cache_
version
pytensor/link/numba/dispatch/tensor_basic.py
浏览文件 @
672a4829
...
@@ -181,8 +181,8 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
...
@@ -181,8 +181,8 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
out
[
...
,
i
]
=
new_entry
out
[
...
,
i
]
=
new_entry
return
out
return
out
cache_
key
=
1
cache_
version
=
1
return
extract_diag
,
cache_
key
return
extract_diag
,
cache_
version
@register_funcify_default_op_cache_key
(
Eye
)
@register_funcify_default_op_cache_key
(
Eye
)
...
...
pytensor/tensor/_linalg/solve/rewriting.py
浏览文件 @
672a4829
...
@@ -20,14 +20,13 @@ from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_so
...
@@ -20,14 +20,13 @@ from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_so
from
pytensor.tensor.variable
import
TensorVariable
from
pytensor.tensor.variable
import
TensorVariable
def
decompose_A
(
A
,
assume_a
,
check_finite
,
lower
):
def
decompose_A
(
A
,
assume_a
,
lower
):
if
assume_a
==
"gen"
:
if
assume_a
==
"gen"
:
return
lu_factor
(
A
,
check_finite
=
check_finite
)
return
lu_factor
(
A
)
elif
assume_a
==
"tridiagonal"
:
elif
assume_a
==
"tridiagonal"
:
# We didn't implement check_finite for tridiagonal LU factorization
return
tridiagonal_lu_factor
(
A
)
return
tridiagonal_lu_factor
(
A
)
elif
assume_a
==
"pos"
:
elif
assume_a
==
"pos"
:
return
cholesky
(
A
,
lower
=
lower
,
check_finite
=
check_finite
)
return
cholesky
(
A
,
lower
=
lower
)
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -36,7 +35,6 @@ def solve_decomposed_system(
...
@@ -36,7 +35,6 @@ def solve_decomposed_system(
A_decomp
,
b
,
transposed
=
False
,
lower
=
False
,
*
,
core_solve_op
:
Solve
A_decomp
,
b
,
transposed
=
False
,
lower
=
False
,
*
,
core_solve_op
:
Solve
):
):
b_ndim
=
core_solve_op
.
b_ndim
b_ndim
=
core_solve_op
.
b_ndim
check_finite
=
core_solve_op
.
check_finite
assume_a
=
core_solve_op
.
assume_a
assume_a
=
core_solve_op
.
assume_a
if
assume_a
==
"gen"
:
if
assume_a
==
"gen"
:
...
@@ -45,10 +43,8 @@ def solve_decomposed_system(
...
@@ -45,10 +43,8 @@ def solve_decomposed_system(
b
,
b
,
b_ndim
=
b_ndim
,
b_ndim
=
b_ndim
,
trans
=
transposed
,
trans
=
transposed
,
check_finite
=
check_finite
,
)
)
elif
assume_a
==
"tridiagonal"
:
elif
assume_a
==
"tridiagonal"
:
# We didn't implement check_finite for tridiagonal LU solve
return
tridiagonal_lu_solve
(
return
tridiagonal_lu_solve
(
A_decomp
,
A_decomp
,
b
,
b
,
...
@@ -61,7 +57,6 @@ def solve_decomposed_system(
...
@@ -61,7 +57,6 @@ def solve_decomposed_system(
(
A_decomp
,
lower
),
(
A_decomp
,
lower
),
b
,
b
,
b_ndim
=
b_ndim
,
b_ndim
=
b_ndim
,
check_finite
=
check_finite
,
)
)
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -141,17 +136,8 @@ def _split_decomp_and_solve_steps(
...
@@ -141,17 +136,8 @@ def _split_decomp_and_solve_steps(
):
):
return
None
return
None
# If any Op had check_finite=True, we also do it for the LU decomposition
check_finite_decomp
=
False
for
client
,
_
in
A_solve_clients_and_transpose
:
if
client
.
op
.
core_op
.
check_finite
:
check_finite_decomp
=
True
break
lower
=
node
.
op
.
core_op
.
lower
lower
=
node
.
op
.
core_op
.
lower
A_decomp
=
decompose_A
(
A_decomp
=
decompose_A
(
A
,
assume_a
=
assume_a
,
lower
=
lower
)
A
,
assume_a
=
assume_a
,
check_finite
=
check_finite_decomp
,
lower
=
lower
)
replacements
=
{}
replacements
=
{}
for
client
,
transposed
in
A_solve_clients_and_transpose
:
for
client
,
transposed
in
A_solve_clients_and_transpose
:
...
...
pytensor/tensor/slinalg.py
浏览文件 @
672a4829
...
@@ -6,7 +6,7 @@ from typing import Literal, cast
...
@@ -6,7 +6,7 @@ from typing import Literal, cast
import
numpy
as
np
import
numpy
as
np
import
scipy.linalg
as
scipy_linalg
import
scipy.linalg
as
scipy_linalg
from
scipy.linalg
import
LinAlgError
,
LinAlgWarning
,
get_lapack_funcs
from
scipy.linalg
import
get_lapack_funcs
import
pytensor
import
pytensor
from
pytensor
import
ifelse
from
pytensor
import
ifelse
...
@@ -14,7 +14,7 @@ from pytensor import tensor as pt
...
@@ -14,7 +14,7 @@ from pytensor import tensor as pt
from
pytensor.gradient
import
DisconnectedType
from
pytensor.gradient
import
DisconnectedType
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.raise_op
import
Assert
from
pytensor.raise_op
import
Assert
,
CheckAndRaise
from
pytensor.tensor
import
TensorLike
from
pytensor.tensor
import
TensorLike
from
pytensor.tensor
import
basic
as
ptb
from
pytensor.tensor
import
basic
as
ptb
from
pytensor.tensor
import
math
as
ptm
from
pytensor.tensor
import
math
as
ptm
...
@@ -32,22 +32,16 @@ logger = logging.getLogger(__name__)
...
@@ -32,22 +32,16 @@ logger = logging.getLogger(__name__)
class
Cholesky
(
Op
):
class
Cholesky
(
Op
):
# TODO: LAPACK wrapper with in-place behavior, for solve also
# TODO: LAPACK wrapper with in-place behavior, for solve also
__props__
=
(
"lower"
,
"
check_finite"
,
"on_error"
,
"
overwrite_a"
)
__props__
=
(
"lower"
,
"overwrite_a"
)
gufunc_signature
=
"(m,m)->(m,m)"
gufunc_signature
=
"(m,m)->(m,m)"
def
__init__
(
def
__init__
(
self
,
self
,
*
,
*
,
lower
:
bool
=
True
,
lower
:
bool
=
True
,
check_finite
:
bool
=
False
,
on_error
:
Literal
[
"raise"
,
"nan"
]
=
"raise"
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
):
):
self
.
lower
=
lower
self
.
lower
=
lower
self
.
check_finite
=
check_finite
if
on_error
not
in
(
"raise"
,
"nan"
):
raise
ValueError
(
'on_error must be one of "raise" or ""nan"'
)
self
.
on_error
=
on_error
self
.
overwrite_a
=
overwrite_a
self
.
overwrite_a
=
overwrite_a
if
self
.
overwrite_a
:
if
self
.
overwrite_a
:
...
@@ -77,13 +71,6 @@ class Cholesky(Op):
...
@@ -77,13 +71,6 @@ class Cholesky(Op):
out
[
0
]
=
np
.
empty_like
(
x
,
dtype
=
potrf
.
dtype
)
out
[
0
]
=
np
.
empty_like
(
x
,
dtype
=
potrf
.
dtype
)
return
return
if
self
.
check_finite
and
not
np
.
isfinite
(
x
)
.
all
():
if
self
.
on_error
==
"nan"
:
out
[
0
]
=
np
.
full
(
x
.
shape
,
np
.
nan
,
dtype
=
potrf
.
dtype
)
return
else
:
raise
ValueError
(
"array must not contain infs or NaNs"
)
# Squareness check
# Squareness check
if
x
.
shape
[
0
]
!=
x
.
shape
[
1
]:
if
x
.
shape
[
0
]
!=
x
.
shape
[
1
]:
raise
ValueError
(
raise
ValueError
(
...
@@ -104,17 +91,8 @@ class Cholesky(Op):
...
@@ -104,17 +91,8 @@ class Cholesky(Op):
c
,
info
=
potrf
(
x
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
clean
=
True
)
c
,
info
=
potrf
(
x
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
clean
=
True
)
if
info
!=
0
:
if
info
!=
0
:
if
self
.
on_error
==
"nan"
:
c
[
...
]
=
np
.
nan
out
[
0
]
=
np
.
full
(
x
.
shape
,
np
.
nan
,
dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
)
out
[
0
]
=
c
elif
info
>
0
:
raise
scipy_linalg
.
LinAlgError
(
f
"{info}-th leading minor of the array is not positive definite"
)
elif
info
<
0
:
raise
ValueError
(
f
"LAPACK reported an illegal value in {-info}-th argument "
f
'on entry to "POTRF".'
)
else
:
else
:
# Transpose result if input was transposed
# Transpose result if input was transposed
out
[
0
]
=
c
.
T
if
c_contiguous_input
else
c
out
[
0
]
=
c
.
T
if
c_contiguous_input
else
c
...
@@ -135,13 +113,6 @@ class Cholesky(Op):
...
@@ -135,13 +113,6 @@ class Cholesky(Op):
dz
=
gradients
[
0
]
dz
=
gradients
[
0
]
chol_x
=
outputs
[
0
]
chol_x
=
outputs
[
0
]
# Replace the cholesky decomposition with 1 if there are nans
# or solve_upper_triangular will throw a ValueError.
if
self
.
on_error
==
"nan"
:
ok
=
~
ptm
.
any
(
ptm
.
isnan
(
chol_x
))
chol_x
=
ptb
.
switch
(
ok
,
chol_x
,
1
)
dz
=
ptb
.
switch
(
ok
,
dz
,
1
)
# deal with upper triangular by converting to lower triangular
# deal with upper triangular by converting to lower triangular
if
not
self
.
lower
:
if
not
self
.
lower
:
chol_x
=
chol_x
.
T
chol_x
=
chol_x
.
T
...
@@ -165,10 +136,7 @@ class Cholesky(Op):
...
@@ -165,10 +136,7 @@ class Cholesky(Op):
else
:
else
:
grad
=
ptb
.
triu
(
s
+
s
.
T
)
-
ptb
.
diag
(
ptb
.
diagonal
(
s
))
grad
=
ptb
.
triu
(
s
+
s
.
T
)
-
ptb
.
diag
(
ptb
.
diagonal
(
s
))
if
self
.
on_error
==
"nan"
:
return
[
grad
]
return
[
ptb
.
switch
(
ok
,
grad
,
np
.
nan
)]
else
:
return
[
grad
]
def
inplace_on_inputs
(
self
,
allowed_inplace_inputs
:
list
[
int
])
->
"Op"
:
def
inplace_on_inputs
(
self
,
allowed_inplace_inputs
:
list
[
int
])
->
"Op"
:
if
not
allowed_inplace_inputs
:
if
not
allowed_inplace_inputs
:
...
@@ -182,9 +150,9 @@ def cholesky(
...
@@ -182,9 +150,9 @@ def cholesky(
x
:
"TensorLike"
,
x
:
"TensorLike"
,
lower
:
bool
=
True
,
lower
:
bool
=
True
,
*
,
*
,
check_finite
:
bool
=
Fals
e
,
check_finite
:
bool
=
Tru
e
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
on_error
:
Literal
[
"raise"
,
"nan"
]
=
"
raise
"
,
on_error
:
Literal
[
"raise"
,
"nan"
]
=
"
nan
"
,
):
):
"""
"""
Return a triangular matrix square root of positive semi-definite `x`.
Return a triangular matrix square root of positive semi-definite `x`.
...
@@ -196,8 +164,8 @@ def cholesky(
...
@@ -196,8 +164,8 @@ def cholesky(
x: tensor_like
x: tensor_like
lower : bool, default=True
lower : bool, default=True
Whether to return the lower or upper cholesky factor
Whether to return the lower or upper cholesky factor
check_finite : bool
, default=False
check_finite : bool
Whether to check that the input matrix contains only finite number
s.
Unused by PyTensor. PyTensor will return nan if the operation fail
s.
overwrite_a: bool, ignored
overwrite_a: bool, ignored
Whether to use the same memory for the output as `a`. This argument is ignored, and is present here only
Whether to use the same memory for the output as `a`. This argument is ignored, and is present here only
for consistency with scipy.linalg.cholesky.
for consistency with scipy.linalg.cholesky.
...
@@ -228,10 +196,19 @@ def cholesky(
...
@@ -228,10 +196,19 @@ def cholesky(
assert np.allclose(L_value @ L_value.T, x_value)
assert np.allclose(L_value @ L_value.T, x_value)
"""
"""
res
=
Blockwise
(
Cholesky
(
lower
=
lower
))(
x
)
return
Blockwise
(
if
on_error
==
"raise"
:
Cholesky
(
lower
=
lower
,
on_error
=
on_error
,
check_finite
=
check_finite
)
# For back-compatibility
)(
x
)
warnings
.
warn
(
'Cholesky on_raise == "raise" is deprecated. The operation will return nan when in fails. Setting this argument will fail in the future'
,
FutureWarning
,
)
res
=
CheckAndRaise
(
np
.
linalg
.
LinAlgError
,
"Matrix is not positive definite"
)(
res
,
~
ptm
.
isnan
(
res
)
.
any
()
)
return
res
class
SolveBase
(
Op
):
class
SolveBase
(
Op
):
...
@@ -239,7 +216,6 @@ class SolveBase(Op):
...
@@ -239,7 +216,6 @@ class SolveBase(Op):
__props__
:
tuple
[
str
,
...
]
=
(
__props__
:
tuple
[
str
,
...
]
=
(
"lower"
,
"lower"
,
"check_finite"
,
"b_ndim"
,
"b_ndim"
,
"overwrite_a"
,
"overwrite_a"
,
"overwrite_b"
,
"overwrite_b"
,
...
@@ -249,13 +225,11 @@ class SolveBase(Op):
...
@@ -249,13 +225,11 @@ class SolveBase(Op):
self
,
self
,
*
,
*
,
lower
=
False
,
lower
=
False
,
check_finite
=
True
,
b_ndim
,
b_ndim
,
overwrite_a
=
False
,
overwrite_a
=
False
,
overwrite_b
=
False
,
overwrite_b
=
False
,
):
):
self
.
lower
=
lower
self
.
lower
=
lower
self
.
check_finite
=
check_finite
assert
b_ndim
in
(
1
,
2
)
assert
b_ndim
in
(
1
,
2
)
self
.
b_ndim
=
b_ndim
self
.
b_ndim
=
b_ndim
...
@@ -358,7 +332,6 @@ def _default_b_ndim(b, b_ndim):
...
@@ -358,7 +332,6 @@ def _default_b_ndim(b, b_ndim):
class
CholeskySolve
(
SolveBase
):
class
CholeskySolve
(
SolveBase
):
__props__
=
(
__props__
=
(
"lower"
,
"lower"
,
"check_finite"
,
"b_ndim"
,
"b_ndim"
,
"overwrite_b"
,
"overwrite_b"
,
)
)
...
@@ -366,7 +339,6 @@ class CholeskySolve(SolveBase):
...
@@ -366,7 +339,6 @@ class CholeskySolve(SolveBase):
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
if
kwargs
.
get
(
"overwrite_a"
,
False
):
if
kwargs
.
get
(
"overwrite_a"
,
False
):
raise
ValueError
(
"overwrite_a is not supported for CholeskySolve"
)
raise
ValueError
(
"overwrite_a is not supported for CholeskySolve"
)
kwargs
.
setdefault
(
"lower"
,
True
)
super
()
.
__init__
(
**
kwargs
)
super
()
.
__init__
(
**
kwargs
)
def
make_node
(
self
,
*
inputs
):
def
make_node
(
self
,
*
inputs
):
...
@@ -387,9 +359,6 @@ class CholeskySolve(SolveBase):
...
@@ -387,9 +359,6 @@ class CholeskySolve(SolveBase):
(
potrs
,)
=
get_lapack_funcs
((
"potrs"
,),
(
c
,
b
))
(
potrs
,)
=
get_lapack_funcs
((
"potrs"
,),
(
c
,
b
))
if
self
.
check_finite
and
not
(
np
.
isfinite
(
c
)
.
all
()
and
np
.
isfinite
(
b
)
.
all
()):
raise
ValueError
(
"array must not contain infs or NaNs"
)
if
c
.
shape
[
0
]
!=
c
.
shape
[
1
]:
if
c
.
shape
[
0
]
!=
c
.
shape
[
1
]:
raise
ValueError
(
"The factored matrix c is not square."
)
raise
ValueError
(
"The factored matrix c is not square."
)
if
c
.
shape
[
1
]
!=
b
.
shape
[
0
]:
if
c
.
shape
[
1
]
!=
b
.
shape
[
0
]:
...
@@ -402,7 +371,7 @@ class CholeskySolve(SolveBase):
...
@@ -402,7 +371,7 @@ class CholeskySolve(SolveBase):
x
,
info
=
potrs
(
c
,
b
,
lower
=
self
.
lower
,
overwrite_b
=
self
.
overwrite_b
)
x
,
info
=
potrs
(
c
,
b
,
lower
=
self
.
lower
,
overwrite_b
=
self
.
overwrite_b
)
if
info
!=
0
:
if
info
!=
0
:
raise
ValueError
(
f
"illegal value in {-info}th argument of internal potrs"
)
x
[
...
]
=
np
.
nan
output_storage
[
0
][
0
]
=
x
output_storage
[
0
][
0
]
=
x
...
@@ -423,7 +392,6 @@ def cho_solve(
...
@@ -423,7 +392,6 @@ def cho_solve(
c_and_lower
:
tuple
[
TensorLike
,
bool
],
c_and_lower
:
tuple
[
TensorLike
,
bool
],
b
:
TensorLike
,
b
:
TensorLike
,
*
,
*
,
check_finite
:
bool
=
True
,
b_ndim
:
int
|
None
=
None
,
b_ndim
:
int
|
None
=
None
,
):
):
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
...
@@ -434,33 +402,26 @@ def cho_solve(
...
@@ -434,33 +402,26 @@ def cho_solve(
Cholesky factorization of a, as given by cho_factor
Cholesky factorization of a, as given by cho_factor
b : TensorLike
b : TensorLike
Right-hand side
Right-hand side
check_finite : bool, optional
check_finite : bool
Whether to check that the input matrices contain only finite numbers.
Unused by PyTensor. PyTensor will return nan if the operation fails.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
b_ndim : int
b_ndim : int
Whether the core case of b is a vector (1) or matrix (2).
Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted.
This will influence how batched dimensions are interpreted.
"""
"""
A
,
lower
=
c_and_lower
A
,
lower
=
c_and_lower
b_ndim
=
_default_b_ndim
(
b
,
b_ndim
)
b_ndim
=
_default_b_ndim
(
b
,
b_ndim
)
return
Blockwise
(
return
Blockwise
(
CholeskySolve
(
lower
=
lower
,
b_ndim
=
b_ndim
))(
A
,
b
)
CholeskySolve
(
lower
=
lower
,
check_finite
=
check_finite
,
b_ndim
=
b_ndim
)
)(
A
,
b
)
class
LU
(
Op
):
class
LU
(
Op
):
"""Decompose a matrix into lower and upper triangular matrices."""
"""Decompose a matrix into lower and upper triangular matrices."""
__props__
=
(
"permute_l"
,
"overwrite_a"
,
"
check_finite"
,
"
p_indices"
)
__props__
=
(
"permute_l"
,
"overwrite_a"
,
"p_indices"
)
def
__init__
(
def
__init__
(
self
,
*
,
permute_l
=
False
,
overwrite_a
=
False
,
p_indices
=
False
):
self
,
*
,
permute_l
=
False
,
overwrite_a
=
False
,
check_finite
=
True
,
p_indices
=
False
):
if
permute_l
and
p_indices
:
if
permute_l
and
p_indices
:
raise
ValueError
(
"Only one of permute_l and p_indices can be True"
)
raise
ValueError
(
"Only one of permute_l and p_indices can be True"
)
self
.
permute_l
=
permute_l
self
.
permute_l
=
permute_l
self
.
check_finite
=
check_finite
self
.
p_indices
=
p_indices
self
.
p_indices
=
p_indices
self
.
overwrite_a
=
overwrite_a
self
.
overwrite_a
=
overwrite_a
...
@@ -523,7 +484,6 @@ class LU(Op):
...
@@ -523,7 +484,6 @@ class LU(Op):
A
,
A
,
permute_l
=
self
.
permute_l
,
permute_l
=
self
.
permute_l
,
overwrite_a
=
self
.
overwrite_a
,
overwrite_a
=
self
.
overwrite_a
,
check_finite
=
self
.
check_finite
,
p_indices
=
self
.
p_indices
,
p_indices
=
self
.
p_indices
,
)
)
...
@@ -563,7 +523,7 @@ class LU(Op):
...
@@ -563,7 +523,7 @@ class LU(Op):
# TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
# TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
# We need L, not PL. It's not possible to recover it from PL, though. So we need to do a new forward pass
# We need L, not PL. It's not possible to recover it from PL, though. So we need to do a new forward pass
P_or_indices
,
L
,
U
=
lu
(
# type: ignore
P_or_indices
,
L
,
U
=
lu
(
# type: ignore
A
,
permute_l
=
False
,
check_finite
=
self
.
check_finite
,
p_indices
=
False
A
,
permute_l
=
False
,
p_indices
=
False
)
)
else
:
else
:
...
@@ -621,8 +581,8 @@ def lu(
...
@@ -621,8 +581,8 @@ def lu(
permute_l: bool
permute_l: bool
If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will
If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will
be returned in this case, and PL will not be lower triangular.
be returned in this case, and PL will not be lower triangular.
check_finite: bool
check_finite
: bool
Whether to check that the input matrix contains only finite number
s.
Unused by PyTensor. PyTensor will return nan if the operation fail
s.
p_indices: bool
p_indices: bool
If True, return integer matrix indices for the permutation matrix. Otherwise, return the permutation matrix
If True, return integer matrix indices for the permutation matrix. Otherwise, return the permutation matrix
itself.
itself.
...
@@ -640,9 +600,7 @@ def lu(
...
@@ -640,9 +600,7 @@ def lu(
return
cast
(
return
cast
(
tuple
[
TensorVariable
,
TensorVariable
,
TensorVariable
]
tuple
[
TensorVariable
,
TensorVariable
,
TensorVariable
]
|
tuple
[
TensorVariable
,
TensorVariable
],
|
tuple
[
TensorVariable
,
TensorVariable
],
Blockwise
(
Blockwise
(
LU
(
permute_l
=
permute_l
,
p_indices
=
p_indices
))(
a
),
LU
(
permute_l
=
permute_l
,
p_indices
=
p_indices
,
check_finite
=
check_finite
)
)(
a
),
)
)
...
@@ -680,12 +638,11 @@ def pivot_to_permutation(p: TensorLike, inverse=False):
...
@@ -680,12 +638,11 @@ def pivot_to_permutation(p: TensorLike, inverse=False):
class
LUFactor
(
Op
):
class
LUFactor
(
Op
):
__props__
=
(
"overwrite_a"
,
"check_finite"
)
__props__
=
(
"overwrite_a"
,)
gufunc_signature
=
"(m,m)->(m,m),(m)"
gufunc_signature
=
"(m,m)->(m,m),(m)"
def
__init__
(
self
,
*
,
overwrite_a
=
False
,
check_finite
=
True
):
def
__init__
(
self
,
*
,
overwrite_a
=
False
):
self
.
overwrite_a
=
overwrite_a
self
.
overwrite_a
=
overwrite_a
self
.
check_finite
=
check_finite
if
self
.
overwrite_a
:
if
self
.
overwrite_a
:
self
.
destroy_map
=
{
1
:
[
0
]}
self
.
destroy_map
=
{
1
:
[
0
]}
...
@@ -723,21 +680,10 @@ class LUFactor(Op):
...
@@ -723,21 +680,10 @@ class LUFactor(Op):
outputs
[
1
][
0
]
=
np
.
array
([],
dtype
=
np
.
int32
)
outputs
[
1
][
0
]
=
np
.
array
([],
dtype
=
np
.
int32
)
return
return
if
self
.
check_finite
and
not
np
.
isfinite
(
A
)
.
all
():
raise
ValueError
(
"array must not contain infs or NaNs"
)
(
getrf
,)
=
get_lapack_funcs
((
"getrf"
,),
(
A
,))
(
getrf
,)
=
get_lapack_funcs
((
"getrf"
,),
(
A
,))
LU
,
p
,
info
=
getrf
(
A
,
overwrite_a
=
self
.
overwrite_a
)
LU
,
p
,
info
=
getrf
(
A
,
overwrite_a
=
self
.
overwrite_a
)
if
info
<
0
:
if
info
!=
0
:
raise
ValueError
(
LU
[
...
]
=
np
.
nan
f
"illegal value in {-info}th argument of internal getrf (lu_factor)"
)
if
info
>
0
:
warnings
.
warn
(
f
"Diagonal number {info} is exactly zero. Singular matrix."
,
LinAlgWarning
,
stacklevel
=
2
,
)
outputs
[
0
][
0
]
=
LU
outputs
[
0
][
0
]
=
LU
outputs
[
1
][
0
]
=
p
outputs
[
1
][
0
]
=
p
...
@@ -782,7 +728,7 @@ def lu_factor(
...
@@ -782,7 +728,7 @@ def lu_factor(
a: TensorLike
a: TensorLike
Matrix to be factorized
Matrix to be factorized
check_finite: bool
check_finite: bool
Whether to check that the input matrix contains only finite number
s.
Unused by PyTensor. PyTensor will return nan if the operation fail
s.
overwrite_a: bool
overwrite_a: bool
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
...
@@ -796,7 +742,7 @@ def lu_factor(
...
@@ -796,7 +742,7 @@ def lu_factor(
return
cast
(
return
cast
(
tuple
[
TensorVariable
,
TensorVariable
],
tuple
[
TensorVariable
,
TensorVariable
],
Blockwise
(
LUFactor
(
check_finite
=
check_finite
))(
a
),
Blockwise
(
LUFactor
())(
a
),
)
)
...
@@ -806,7 +752,6 @@ def _lu_solve(
...
@@ -806,7 +752,6 @@ def _lu_solve(
b
:
TensorLike
,
b
:
TensorLike
,
trans
:
bool
=
False
,
trans
:
bool
=
False
,
b_ndim
:
int
|
None
=
None
,
b_ndim
:
int
|
None
=
None
,
check_finite
:
bool
=
True
,
):
):
b_ndim
=
_default_b_ndim
(
b
,
b_ndim
)
b_ndim
=
_default_b_ndim
(
b
,
b_ndim
)
...
@@ -824,7 +769,6 @@ def _lu_solve(
...
@@ -824,7 +769,6 @@ def _lu_solve(
unit_diagonal
=
not
trans
,
unit_diagonal
=
not
trans
,
trans
=
trans
,
trans
=
trans
,
b_ndim
=
b_ndim
,
b_ndim
=
b_ndim
,
check_finite
=
check_finite
,
)
)
x
=
solve_triangular
(
x
=
solve_triangular
(
...
@@ -834,7 +778,6 @@ def _lu_solve(
...
@@ -834,7 +778,6 @@ def _lu_solve(
unit_diagonal
=
trans
,
unit_diagonal
=
trans
,
trans
=
trans
,
trans
=
trans
,
b_ndim
=
b_ndim
,
b_ndim
=
b_ndim
,
check_finite
=
check_finite
,
)
)
# TODO: Use PermuteRows(inverse=True) on x
# TODO: Use PermuteRows(inverse=True) on x
...
@@ -867,7 +810,7 @@ def lu_solve(
...
@@ -867,7 +810,7 @@ def lu_solve(
The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix
The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix
of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input.
of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input.
check_finite: bool
check_finite: bool
If True, check that the input matrices contain only finite numbers. Default is True
.
Unused by PyTensor. PyTensor will return nan if the operation fails
.
overwrite_b: bool
overwrite_b: bool
Ignored by Pytensor. Pytensor will always compute inplace when possible.
Ignored by Pytensor. Pytensor will always compute inplace when possible.
"""
"""
...
@@ -876,9 +819,7 @@ def lu_solve(
...
@@ -876,9 +819,7 @@ def lu_solve(
signature
=
"(m,m),(m),(m)->(m)"
signature
=
"(m,m),(m),(m)->(m)"
else
:
else
:
signature
=
"(m,m),(m),(m,n)->(m,n)"
signature
=
"(m,m),(m),(m,n)->(m,n)"
partialled_func
=
partial
(
partialled_func
=
partial
(
_lu_solve
,
trans
=
trans
,
b_ndim
=
b_ndim
)
_lu_solve
,
trans
=
trans
,
b_ndim
=
b_ndim
,
check_finite
=
check_finite
)
return
pt
.
vectorize
(
partialled_func
,
signature
=
signature
)(
*
LU_and_pivots
,
b
)
return
pt
.
vectorize
(
partialled_func
,
signature
=
signature
)(
*
LU_and_pivots
,
b
)
...
@@ -888,7 +829,6 @@ class SolveTriangular(SolveBase):
...
@@ -888,7 +829,6 @@ class SolveTriangular(SolveBase):
__props__
=
(
__props__
=
(
"unit_diagonal"
,
"unit_diagonal"
,
"lower"
,
"lower"
,
"check_finite"
,
"b_ndim"
,
"b_ndim"
,
"overwrite_b"
,
"overwrite_b"
,
)
)
...
@@ -905,10 +845,7 @@ class SolveTriangular(SolveBase):
...
@@ -905,10 +845,7 @@ class SolveTriangular(SolveBase):
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
A
,
b
=
inputs
A
,
b
=
inputs
if
self
.
check_finite
and
not
(
np
.
isfinite
(
A
)
.
all
()
and
np
.
isfinite
(
b
)
.
all
()):
if
A
.
ndim
!=
2
or
A
.
shape
[
0
]
!=
A
.
shape
[
1
]:
raise
ValueError
(
"array must not contain infs or NaNs"
)
if
len
(
A
.
shape
)
!=
2
or
A
.
shape
[
0
]
!=
A
.
shape
[
1
]:
raise
ValueError
(
"expected square matrix"
)
raise
ValueError
(
"expected square matrix"
)
if
A
.
shape
[
0
]
!=
b
.
shape
[
0
]:
if
A
.
shape
[
0
]
!=
b
.
shape
[
0
]:
...
@@ -941,12 +878,8 @@ class SolveTriangular(SolveBase):
...
@@ -941,12 +878,8 @@ class SolveTriangular(SolveBase):
unitdiag
=
self
.
unit_diagonal
,
unitdiag
=
self
.
unit_diagonal
,
)
)
if
info
>
0
:
if
info
!=
0
:
raise
LinAlgError
(
x
[
...
]
=
np
.
nan
f
"singular matrix: resolution failed at diagonal {info - 1}"
)
elif
info
<
0
:
raise
ValueError
(
f
"illegal value in {-info}-th argument of internal trtrs"
)
outputs
[
0
][
0
]
=
x
outputs
[
0
][
0
]
=
x
...
@@ -998,9 +931,7 @@ def solve_triangular(
...
@@ -998,9 +931,7 @@ def solve_triangular(
unit_diagonal: bool, optional
unit_diagonal: bool, optional
If True, diagonal elements of `a` are assumed to be 1 and will not be referenced.
If True, diagonal elements of `a` are assumed to be 1 and will not be referenced.
check_finite : bool, optional
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Unused by PyTensor. PyTensor will return nan if the operation fails.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
b_ndim : int
b_ndim : int
Whether the core case of b is a vector (1) or matrix (2).
Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted.
This will influence how batched dimensions are interpreted.
...
@@ -1018,7 +949,6 @@ def solve_triangular(
...
@@ -1018,7 +949,6 @@ def solve_triangular(
SolveTriangular
(
SolveTriangular
(
lower
=
lower
,
lower
=
lower
,
unit_diagonal
=
unit_diagonal
,
unit_diagonal
=
unit_diagonal
,
check_finite
=
check_finite
,
b_ndim
=
b_ndim
,
b_ndim
=
b_ndim
,
)
)
)(
a
,
b
)
)(
a
,
b
)
...
@@ -1033,7 +963,6 @@ class Solve(SolveBase):
...
@@ -1033,7 +963,6 @@ class Solve(SolveBase):
__props__
=
(
__props__
=
(
"assume_a"
,
"assume_a"
,
"lower"
,
"lower"
,
"check_finite"
,
"b_ndim"
,
"b_ndim"
,
"overwrite_a"
,
"overwrite_a"
,
"overwrite_b"
,
"overwrite_b"
,
...
@@ -1073,15 +1002,18 @@ class Solve(SolveBase):
...
@@ -1073,15 +1002,18 @@ class Solve(SolveBase):
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
a
,
b
=
inputs
a
,
b
=
inputs
outputs
[
0
][
0
]
=
scipy_linalg
.
solve
(
try
:
a
=
a
,
outputs
[
0
][
0
]
=
scipy_linalg
.
solve
(
b
=
b
,
a
=
a
,
lower
=
self
.
lower
,
b
=
b
,
check_finite
=
self
.
check_finite
,
lower
=
self
.
lower
,
assume_a
=
self
.
assume_a
,
check_finite
=
False
,
overwrite_a
=
self
.
overwrite_a
,
assume_a
=
self
.
assume_a
,
overwrite_b
=
self
.
overwrite_b
,
overwrite_a
=
self
.
overwrite_a
,
)
overwrite_b
=
self
.
overwrite_b
,
)
except
np
.
linalg
.
LinAlgError
:
outputs
[
0
][
0
]
=
np
.
full
(
a
.
shape
,
np
.
nan
,
dtype
=
a
.
dtype
)
def
inplace_on_inputs
(
self
,
allowed_inplace_inputs
:
list
[
int
])
->
"Op"
:
def
inplace_on_inputs
(
self
,
allowed_inplace_inputs
:
list
[
int
])
->
"Op"
:
if
not
allowed_inplace_inputs
:
if
not
allowed_inplace_inputs
:
...
@@ -1152,10 +1084,8 @@ def solve(
...
@@ -1152,10 +1084,8 @@ def solve(
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
overwrite_b : bool
overwrite_b : bool
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
check_finite : bool, optional
check_finite : bool
Whether to check that the input matrices contain only finite numbers.
Unused by PyTensor. PyTensor returns nan if the operation fails.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
assume_a : str, optional
assume_a : str, optional
Valid entries are explained above.
Valid entries are explained above.
transposed: bool, default False
transposed: bool, default False
...
@@ -1174,7 +1104,6 @@ def solve(
...
@@ -1174,7 +1104,6 @@ def solve(
b
,
b
,
lower
=
lower
,
lower
=
lower
,
trans
=
transposed
,
trans
=
transposed
,
check_finite
=
check_finite
,
b_ndim
=
b_ndim
,
b_ndim
=
b_ndim
,
)
)
...
@@ -1195,7 +1124,6 @@ def solve(
...
@@ -1195,7 +1124,6 @@ def solve(
return
Blockwise
(
return
Blockwise
(
Solve
(
Solve
(
lower
=
lower
,
lower
=
lower
,
check_finite
=
check_finite
,
assume_a
=
assume_a
,
assume_a
=
assume_a
,
b_ndim
=
b_ndim
,
b_ndim
=
b_ndim
,
)
)
...
@@ -1779,7 +1707,6 @@ class QR(Op):
...
@@ -1779,7 +1707,6 @@ class QR(Op):
"overwrite_a"
,
"overwrite_a"
,
"mode"
,
"mode"
,
"pivoting"
,
"pivoting"
,
"check_finite"
,
)
)
def
__init__
(
def
__init__
(
...
@@ -1787,12 +1714,10 @@ class QR(Op):
...
@@ -1787,12 +1714,10 @@ class QR(Op):
mode
:
Literal
[
"full"
,
"r"
,
"economic"
,
"raw"
]
=
"full"
,
mode
:
Literal
[
"full"
,
"r"
,
"economic"
,
"raw"
]
=
"full"
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
pivoting
:
bool
=
False
,
pivoting
:
bool
=
False
,
check_finite
:
bool
=
False
,
):
):
self
.
mode
=
mode
self
.
mode
=
mode
self
.
overwrite_a
=
overwrite_a
self
.
overwrite_a
=
overwrite_a
self
.
pivoting
=
pivoting
self
.
pivoting
=
pivoting
self
.
check_finite
=
check_finite
self
.
destroy_map
=
{}
self
.
destroy_map
=
{}
...
...
pytensor/xtensor/linalg.py
浏览文件 @
672a4829
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
typing
import
Literal
from
pytensor.tensor.slinalg
import
Cholesky
,
Solve
from
pytensor.tensor.slinalg
import
Cholesky
,
Solve
from
pytensor.xtensor.type
import
as_xtensor
from
pytensor.xtensor.type
import
as_xtensor
...
@@ -10,8 +9,7 @@ def cholesky(
...
@@ -10,8 +9,7 @@ def cholesky(
x
,
x
,
lower
:
bool
=
True
,
lower
:
bool
=
True
,
*
,
*
,
check_finite
:
bool
=
False
,
check_finite
:
bool
=
True
,
on_error
:
Literal
[
"raise"
,
"nan"
]
=
"raise"
,
dims
:
Sequence
[
str
],
dims
:
Sequence
[
str
],
):
):
"""Compute the Cholesky decomposition of an XTensorVariable.
"""Compute the Cholesky decomposition of an XTensorVariable.
...
@@ -22,22 +20,15 @@ def cholesky(
...
@@ -22,22 +20,15 @@ def cholesky(
The input variable to decompose.
The input variable to decompose.
lower : bool, optional
lower : bool, optional
Whether to return the lower triangular matrix. Default is True.
Whether to return the lower triangular matrix. Default is True.
check_finite : bool, optional
check_finite : bool
Whether to check that the input is finite. Default is False.
Unused by PyTensor. PyTensor will return nan if the operation fails.
on_error : {'raise', 'nan'}, optional
What to do if the input is not positive definite. If 'raise', an error is raised.
If 'nan', the output will contain NaNs. Default is 'raise'.
dims : Sequence[str]
dims : Sequence[str]
The two core dimensions of the input variable, over which the Cholesky decomposition is computed.
The two core dimensions of the input variable, over which the Cholesky decomposition is computed.
"""
"""
if
len
(
dims
)
!=
2
:
if
len
(
dims
)
!=
2
:
raise
ValueError
(
f
"Cholesky needs two dims, got {len(dims)}"
)
raise
ValueError
(
f
"Cholesky needs two dims, got {len(dims)}"
)
core_op
=
Cholesky
(
core_op
=
Cholesky
(
lower
=
lower
)
lower
=
lower
,
check_finite
=
check_finite
,
on_error
=
on_error
,
)
core_dims
=
(
core_dims
=
(
((
dims
[
0
],
dims
[
1
]),),
((
dims
[
0
],
dims
[
1
]),),
((
dims
[
0
],
dims
[
1
]),),
((
dims
[
0
],
dims
[
1
]),),
...
@@ -52,7 +43,7 @@ def solve(
...
@@ -52,7 +43,7 @@ def solve(
dims
:
Sequence
[
str
],
dims
:
Sequence
[
str
],
assume_a
=
"gen"
,
assume_a
=
"gen"
,
lower
:
bool
=
False
,
lower
:
bool
=
False
,
check_finite
:
bool
=
Fals
e
,
check_finite
:
bool
=
Tru
e
,
):
):
"""Solve a system of linear equations using XTensorVariables.
"""Solve a system of linear equations using XTensorVariables.
...
@@ -75,8 +66,8 @@ def solve(
...
@@ -75,8 +66,8 @@ def solve(
Long form options can also be used ["general", "symmetric", "hermitian", "positive_definite"].
Long form options can also be used ["general", "symmetric", "hermitian", "positive_definite"].
lower : bool, optional
lower : bool, optional
Whether `a` is lower triangular. Default is False. Only relevant if `assume_a` is "sym", "her", or "pos".
Whether `a` is lower triangular. Default is False. Only relevant if `assume_a` is "sym", "her", or "pos".
check_finite : bool
, optional
check_finite : bool
Whether to check that the input is finite. Default is False
.
Unused by PyTensor. PyTensor will return nan if the operation fails
.
"""
"""
a
,
b
=
as_xtensor
(
a
),
as_xtensor
(
b
)
a
,
b
=
as_xtensor
(
a
),
as_xtensor
(
b
)
input_core_dims
:
tuple
[
tuple
[
str
,
str
],
tuple
[
str
]
|
tuple
[
str
,
str
]]
input_core_dims
:
tuple
[
tuple
[
str
,
str
],
tuple
[
str
]
|
tuple
[
str
,
str
]]
...
@@ -98,9 +89,7 @@ def solve(
...
@@ -98,9 +89,7 @@ def solve(
else
:
else
:
raise
ValueError
(
"Solve dims must have length 2 or 3"
)
raise
ValueError
(
"Solve dims must have length 2 or 3"
)
core_op
=
Solve
(
core_op
=
Solve
(
b_ndim
=
b_ndim
,
assume_a
=
assume_a
,
lower
=
lower
)
b_ndim
=
b_ndim
,
assume_a
=
assume_a
,
lower
=
lower
,
check_finite
=
check_finite
)
x_op
=
XBlockwise
(
x_op
=
XBlockwise
(
core_op
,
core_op
,
core_dims
=
(
input_core_dims
,
output_core_dims
),
core_dims
=
(
input_core_dims
,
output_core_dims
),
...
...
tests/link/numba/test_slinalg.py
浏览文件 @
672a4829
import
re
from
typing
import
Literal
from
typing
import
Literal
import
numpy
as
np
import
numpy
as
np
...
@@ -36,70 +35,6 @@ floatX = config.floatX
...
@@ -36,70 +35,6 @@ floatX = config.floatX
rng
=
np
.
random
.
default_rng
(
42849
)
rng
=
np
.
random
.
default_rng
(
42849
)
def
test_lamch
():
from
scipy.linalg
import
get_lapack_funcs
from
pytensor.link.numba.dispatch.linalg.utils
import
_xlamch
@numba.njit
()
def
xlamch
(
kind
):
return
_xlamch
(
kind
)
lamch
=
get_lapack_funcs
(
"lamch"
,
(
np
.
array
([
0.0
],
dtype
=
floatX
),))
np
.
testing
.
assert_allclose
(
xlamch
(
"E"
),
lamch
(
"E"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"S"
),
lamch
(
"S"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"P"
),
lamch
(
"P"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"B"
),
lamch
(
"B"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"R"
),
lamch
(
"R"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"M"
),
lamch
(
"M"
))
@pytest.mark.parametrize
(
"ord_numba, ord_scipy"
,
[(
"F"
,
"fro"
),
(
"1"
,
1
),
(
"I"
,
np
.
inf
)]
)
def
test_xlange
(
ord_numba
,
ord_scipy
):
# xlange is called internally only, we don't dispatch pt.linalg.norm to it
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
@numba.njit
()
def
xlange
(
x
,
ord
):
return
_xlange
(
x
,
ord
)
x
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
np
.
testing
.
assert_allclose
(
xlange
(
x
,
ord_numba
),
linalg
.
norm
(
x
,
ord_scipy
))
@pytest.mark.parametrize
(
"ord_numba, ord_scipy"
,
[(
"1"
,
1
),
(
"I"
,
np
.
inf
)])
def
test_xgecon
(
ord_numba
,
ord_scipy
):
# gecon is called internally only, we don't dispatch pt.linalg.norm to it
from
scipy.linalg
import
get_lapack_funcs
from
pytensor.link.numba.dispatch.linalg.solve.general
import
_xgecon
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
@numba.njit
()
def
gecon
(
x
,
norm
):
anorm
=
_xlange
(
x
,
norm
)
cond
,
info
=
_xgecon
(
x
,
anorm
,
norm
)
return
cond
,
info
x
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
rcond
,
info
=
gecon
(
x
,
norm
=
ord_numba
)
# Test against direct call to the underlying LAPACK functions
# Solution does **not** agree with 1 / np.linalg.cond(x) !
lange
,
gecon
=
get_lapack_funcs
((
"lange"
,
"gecon"
),
(
x
,))
norm
=
lange
(
ord_numba
,
x
)
rcond2
,
_
=
gecon
(
x
,
norm
,
norm
=
ord_numba
)
assert
info
==
0
np
.
testing
.
assert_allclose
(
rcond
,
rcond2
)
class
TestSolves
:
class
TestSolves
:
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
],
ids
=
lambda
x
:
f
"lower={x}"
)
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
],
ids
=
lambda
x
:
f
"lower={x}"
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -323,7 +258,7 @@ class TestSolves:
...
@@ -323,7 +258,7 @@ class TestSolves:
np
.
testing
.
assert_allclose
(
b_val_not_contig
,
b_val
)
np
.
testing
.
assert_allclose
(
b_val_not_contig
,
b_val
)
@pytest.mark.parametrize
(
"value"
,
[
np
.
nan
,
np
.
inf
])
@pytest.mark.parametrize
(
"value"
,
[
np
.
nan
,
np
.
inf
])
def
test_solve_triangular_
raises
_on_nan_inf
(
self
,
value
):
def
test_solve_triangular_
does_not_raise
_on_nan_inf
(
self
,
value
):
A
=
pt
.
matrix
(
"A"
)
A
=
pt
.
matrix
(
"A"
)
b
=
pt
.
matrix
(
"b"
)
b
=
pt
.
matrix
(
"b"
)
...
@@ -335,11 +270,8 @@ class TestSolves:
...
@@ -335,11 +270,8 @@ class TestSolves:
A_tri
=
np
.
linalg
.
cholesky
(
A_sym
)
.
astype
(
floatX
)
A_tri
=
np
.
linalg
.
cholesky
(
A_sym
)
.
astype
(
floatX
)
b
=
np
.
full
((
5
,
1
),
value
)
.
astype
(
floatX
)
b
=
np
.
full
((
5
,
1
),
value
)
.
astype
(
floatX
)
with
pytest
.
raises
(
# Not checking everything is nan, because, with inf, LAPACK returns a mix of inf/nan, but does not set info != 0
np
.
linalg
.
LinAlgError
,
assert
not
np
.
isfinite
(
f
(
A_tri
,
b
))
.
any
()
match
=
re
.
escape
(
"Non-numeric values"
),
):
f
(
A_tri
,
b
)
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
],
ids
=
lambda
x
:
f
"lower = {x}"
)
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
],
ids
=
lambda
x
:
f
"lower = {x}"
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -567,10 +499,13 @@ class TestDecompositions:
...
@@ -567,10 +499,13 @@ class TestDecompositions:
x
=
pt
.
tensor
(
dtype
=
floatX
,
shape
=
(
3
,
3
))
x
=
pt
.
tensor
(
dtype
=
floatX
,
shape
=
(
3
,
3
))
x
=
x
.
T
.
dot
(
x
)
x
=
x
.
T
.
dot
(
x
)
g
=
pt
.
linalg
.
cholesky
(
x
,
check_finite
=
True
)
with
pytest
.
warns
(
FutureWarning
):
g
=
pt
.
linalg
.
cholesky
(
x
,
check_finite
=
True
,
on_error
=
"raise"
)
f
=
pytensor
.
function
([
x
],
g
,
mode
=
"NUMBA"
)
f
=
pytensor
.
function
([
x
],
g
,
mode
=
"NUMBA"
)
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
,
match
=
r"Non-numeric values"
):
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
,
match
=
r"Matrix is not positive definite"
):
f
(
test_value
)
f
(
test_value
)
@pytest.mark.parametrize
(
"on_error"
,
[
"nan"
,
"raise"
])
@pytest.mark.parametrize
(
"on_error"
,
[
"nan"
,
"raise"
])
...
@@ -578,13 +513,17 @@ class TestDecompositions:
...
@@ -578,13 +513,17 @@ class TestDecompositions:
test_value
=
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
floatX
)
test_value
=
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
floatX
)
x
=
pt
.
tensor
(
dtype
=
floatX
,
shape
=
(
3
,
3
))
x
=
pt
.
tensor
(
dtype
=
floatX
,
shape
=
(
3
,
3
))
g
=
pt
.
linalg
.
cholesky
(
x
,
on_error
=
on_error
)
if
on_error
==
"raise"
:
with
pytest
.
warns
(
FutureWarning
):
g
=
pt
.
linalg
.
cholesky
(
x
,
on_error
=
on_error
)
else
:
g
=
pt
.
linalg
.
cholesky
(
x
,
on_error
=
on_error
)
f
=
pytensor
.
function
([
x
],
g
,
mode
=
"NUMBA"
)
f
=
pytensor
.
function
([
x
],
g
,
mode
=
"NUMBA"
)
if
on_error
==
"raise"
:
if
on_error
==
"raise"
:
with
pytest
.
raises
(
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
,
np
.
linalg
.
LinAlgError
,
match
=
r"
Input to cholesky
is not positive definite"
,
match
=
r"
Matrix
is not positive definite"
,
):
):
f
(
test_value
)
f
(
test_value
)
else
:
else
:
...
...
tests/tensor/linalg/test_rewriting.py
浏览文件 @
672a4829
...
@@ -213,47 +213,3 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
...
@@ -213,47 +213,3 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
resx1
=
fn_opt
(
A_test
,
x0_test
)
resx1
=
fn_opt
(
A_test
,
x0_test
)
rtol
=
1e-7
if
config
.
floatX
==
"float64"
else
1e-4
rtol
=
1e-7
if
config
.
floatX
==
"float64"
else
1e-4
np
.
testing
.
assert_allclose
(
resx0
,
resx1
,
rtol
=
rtol
)
np
.
testing
.
assert_allclose
(
resx0
,
resx1
,
rtol
=
rtol
)
@pytest.mark.parametrize
(
"assume_a, counter"
,
(
(
"gen"
,
LUOpCounter
),
(
"pos"
,
CholeskyOpCounter
),
),
)
def
test_decomposition_reused_preserves_check_finite
(
assume_a
,
counter
):
# Check that the LU decomposition rewrite preserves the check_finite flag
rewrite_name
=
reuse_decomposition_multiple_solves
.
__name__
A
=
tensor
(
"A"
,
shape
=
(
2
,
2
))
b1
=
tensor
(
"b1"
,
shape
=
(
2
,))
b2
=
tensor
(
"b2"
,
shape
=
(
2
,))
x1
=
solve
(
A
,
b1
,
assume_a
=
assume_a
,
check_finite
=
True
)
x2
=
solve
(
A
,
b2
,
assume_a
=
assume_a
,
check_finite
=
False
)
fn_opt
=
function
(
[
A
,
b1
,
b2
],
[
x1
,
x2
],
mode
=
get_default_mode
()
.
including
(
rewrite_name
)
)
opt_nodes
=
fn_opt
.
maker
.
fgraph
.
apply_nodes
assert
counter
.
count_vanilla_solve_nodes
(
opt_nodes
)
==
0
assert
counter
.
count_decomp_nodes
(
opt_nodes
)
==
1
assert
counter
.
count_solve_nodes
(
opt_nodes
)
==
2
# We should get an error if A or b1 is non finite
A_valid
=
np
.
array
([[
1
,
0
],
[
0
,
1
]],
dtype
=
A
.
type
.
dtype
)
b1_valid
=
np
.
array
([
1
,
1
],
dtype
=
b1
.
type
.
dtype
)
b2_valid
=
np
.
array
([
1
,
1
],
dtype
=
b2
.
type
.
dtype
)
assert
fn_opt
(
A_valid
,
b1_valid
,
b2_valid
)
# Fine
assert
fn_opt
(
A_valid
,
b1_valid
,
b2_valid
*
np
.
nan
)
# Should not raise (also fine on most LAPACK implementations?)
err_msg
=
(
"(array must not contain infs or NaNs"
r"|Non-numeric values \(nan or inf\))"
)
with
pytest
.
raises
((
ValueError
,
np
.
linalg
.
LinAlgError
),
match
=
err_msg
):
assert
fn_opt
(
A_valid
,
b1_valid
*
np
.
nan
,
b2_valid
)
with
pytest
.
raises
((
ValueError
,
np
.
linalg
.
LinAlgError
),
match
=
err_msg
):
assert
fn_opt
(
A_valid
*
np
.
nan
,
b1_valid
,
b2_valid
)
tests/tensor/test_slinalg.py
浏览文件 @
672a4829
...
@@ -74,9 +74,6 @@ def test_cholesky():
...
@@ -74,9 +74,6 @@ def test_cholesky():
chol
=
Cholesky
(
lower
=
False
)(
x
)
chol
=
Cholesky
(
lower
=
False
)(
x
)
ch_f
=
function
([
x
],
chol
)
ch_f
=
function
([
x
],
chol
)
check_upper_triangular
(
pd
,
ch_f
)
check_upper_triangular
(
pd
,
ch_f
)
chol
=
Cholesky
(
lower
=
False
,
on_error
=
"nan"
)(
x
)
ch_f
=
function
([
x
],
chol
)
check_upper_triangular
(
pd
,
ch_f
)
def
test_cholesky_performance
(
benchmark
):
def
test_cholesky_performance
(
benchmark
):
...
@@ -102,12 +99,15 @@ def test_cholesky_empty():
...
@@ -102,12 +99,15 @@ def test_cholesky_empty():
def
test_cholesky_indef
():
def
test_cholesky_indef
():
x
=
matrix
()
x
=
matrix
()
mat
=
np
.
array
([[
1
,
0.2
],
[
0.2
,
-
2
]])
.
astype
(
config
.
floatX
)
mat
=
np
.
array
([[
1
,
0.2
],
[
0.2
,
-
2
]])
.
astype
(
config
.
floatX
)
cholesky
=
Cholesky
(
lower
=
True
,
on_error
=
"raise"
)
chol_f
=
function
([
x
],
cholesky
(
x
))
with
pytest
.
warns
(
FutureWarning
):
out
=
cholesky
(
x
,
lower
=
True
,
on_error
=
"raise"
)
chol_f
=
function
([
x
],
out
)
with
pytest
.
raises
(
scipy
.
linalg
.
LinAlgError
):
with
pytest
.
raises
(
scipy
.
linalg
.
LinAlgError
):
chol_f
(
mat
)
chol_f
(
mat
)
cholesky
=
Cholesky
(
lower
=
True
,
on_error
=
"nan"
)
chol_f
=
function
([
x
],
cholesky
(
x
))
out
=
cholesky
(
x
,
lower
=
True
,
on_error
=
"nan"
)
chol_f
=
function
([
x
],
out
)
assert
np
.
all
(
np
.
isnan
(
chol_f
(
mat
)))
assert
np
.
all
(
np
.
isnan
(
chol_f
(
mat
)))
...
@@ -143,12 +143,16 @@ def test_cholesky_grad():
...
@@ -143,12 +143,16 @@ def test_cholesky_grad():
def
test_cholesky_grad_indef
():
def
test_cholesky_grad_indef
():
x
=
matrix
()
x
=
matrix
()
mat
=
np
.
array
([[
1
,
0.2
],
[
0.2
,
-
2
]])
.
astype
(
config
.
floatX
)
mat
=
np
.
array
([[
1
,
0.2
],
[
0.2
,
-
2
]])
.
astype
(
config
.
floatX
)
cholesky
=
Cholesky
(
lower
=
True
,
on_error
=
"raise"
)
chol_f
=
function
([
x
],
grad
(
cholesky
(
x
)
.
sum
(),
[
x
]))
with
pytest
.
warns
(
FutureWarning
):
with
pytest
.
raises
(
scipy
.
linalg
.
LinAlgError
):
out
=
cholesky
(
x
,
lower
=
True
,
on_error
=
"raise"
)
chol_f
(
mat
)
chol_f
=
function
([
x
],
grad
(
out
.
sum
(),
[
x
]),
mode
=
"FAST_RUN"
)
cholesky
=
Cholesky
(
lower
=
True
,
on_error
=
"nan"
)
chol_f
=
function
([
x
],
grad
(
cholesky
(
x
)
.
sum
(),
[
x
]))
# original cholesky doesn't show up in the grad (if mode="FAST_RUN"), so it does not raise
assert
np
.
all
(
np
.
isnan
(
chol_f
(
mat
)))
out
=
cholesky
(
x
,
lower
=
True
,
on_error
=
"nan"
)
chol_f
=
function
([
x
],
grad
(
out
.
sum
(),
[
x
]))
assert
np
.
all
(
np
.
isnan
(
chol_f
(
mat
)))
assert
np
.
all
(
np
.
isnan
(
chol_f
(
mat
)))
...
@@ -237,7 +241,7 @@ class TestSolveBase:
...
@@ -237,7 +241,7 @@ class TestSolveBase:
y
=
self
.
SolveTest
(
b_ndim
=
2
)(
A
,
b
)
y
=
self
.
SolveTest
(
b_ndim
=
2
)(
A
,
b
)
assert
(
assert
(
y
.
__repr__
()
y
.
__repr__
()
==
"SolveTest{lower=False,
check_finite=True,
b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
==
"SolveTest{lower=False, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
)
)
...
@@ -549,7 +553,7 @@ class TestCholeskySolve(utt.InferShapeTester):
...
@@ -549,7 +553,7 @@ class TestCholeskySolve(utt.InferShapeTester):
def
test_repr
(
self
):
def
test_repr
(
self
):
assert
(
assert
(
repr
(
CholeskySolve
(
lower
=
True
,
b_ndim
=
1
))
repr
(
CholeskySolve
(
lower
=
True
,
b_ndim
=
1
))
==
"CholeskySolve(lower=True,
check_finite=True,
b_ndim=1,overwrite_b=False)"
==
"CholeskySolve(lower=True,b_ndim=1,overwrite_b=False)"
)
)
def
test_infer_shape
(
self
):
def
test_infer_shape
(
self
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论