Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
672a4829
提交
672a4829
authored
1月 08, 2026
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
1月 11, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Do not raise in linalg Ops
上级
b2d8bc24
全部展开
显示空白字符变更
内嵌
并排
正在显示
23 个修改的文件
包含
144 行增加
和
887 行删除
+144
-887
slinalg.py
pytensor/link/jax/dispatch/slinalg.py
+4
-10
_LAPACK.py
pytensor/link/numba/dispatch/linalg/_LAPACK.py
+0
-201
cholesky.py
...nsor/link/numba/dispatch/linalg/decomposition/cholesky.py
+11
-13
lu.py
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
+6
-15
lu_factor.py
...sor/link/numba/dispatch/linalg/decomposition/lu_factor.py
+4
-3
qr.py
pytensor/link/numba/dispatch/linalg/decomposition/qr.py
+12
-34
cholesky.py
pytensor/link/numba/dispatch/linalg/solve/cholesky.py
+9
-8
general.py
pytensor/link/numba/dispatch/linalg/solve/general.py
+12
-74
lu_solve.py
pytensor/link/numba/dispatch/linalg/solve/lu_solve.py
+5
-12
norm.py
pytensor/link/numba/dispatch/linalg/solve/norm.py
+0
-55
posdef.py
pytensor/link/numba/dispatch/linalg/solve/posdef.py
+6
-67
symmetric.py
pytensor/link/numba/dispatch/linalg/solve/symmetric.py
+5
-59
triangular.py
pytensor/link/numba/dispatch/linalg/solve/triangular.py
+6
-8
tridiagonal.py
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
+16
-92
utils.py
pytensor/link/numba/dispatch/linalg/utils.py
+2
-64
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+0
-0
tensor_basic.py
pytensor/link/numba/dispatch/tensor_basic.py
+2
-2
rewriting.py
pytensor/tensor/_linalg/solve/rewriting.py
+4
-18
slinalg.py
pytensor/tensor/slinalg.py
+0
-0
linalg.py
pytensor/xtensor/linalg.py
+8
-19
test_slinalg.py
tests/link/numba/test_slinalg.py
+13
-74
test_rewriting.py
tests/tensor/linalg/test_rewriting.py
+0
-44
test_slinalg.py
tests/tensor/test_slinalg.py
+19
-15
没有找到文件。
pytensor/link/jax/dispatch/slinalg.py
浏览文件 @
672a4829
...
@@ -92,7 +92,6 @@ def jax_funcify_Solve(op, **kwargs):
...
@@ -92,7 +92,6 @@ def jax_funcify_Solve(op, **kwargs):
def
jax_funcify_SolveTriangular
(
op
,
**
kwargs
):
def
jax_funcify_SolveTriangular
(
op
,
**
kwargs
):
lower
=
op
.
lower
lower
=
op
.
lower
unit_diagonal
=
op
.
unit_diagonal
unit_diagonal
=
op
.
unit_diagonal
check_finite
=
op
.
check_finite
def
solve_triangular
(
A
,
b
):
def
solve_triangular
(
A
,
b
):
return
jax
.
scipy
.
linalg
.
solve_triangular
(
return
jax
.
scipy
.
linalg
.
solve_triangular
(
...
@@ -101,7 +100,7 @@ def jax_funcify_SolveTriangular(op, **kwargs):
...
@@ -101,7 +100,7 @@ def jax_funcify_SolveTriangular(op, **kwargs):
lower
=
lower
,
lower
=
lower
,
trans
=
0
,
# this is handled by explicitly transposing A, so it will always be 0 when we get to here.
trans
=
0
,
# this is handled by explicitly transposing A, so it will always be 0 when we get to here.
unit_diagonal
=
unit_diagonal
,
unit_diagonal
=
unit_diagonal
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
)
)
return
solve_triangular
return
solve_triangular
...
@@ -132,27 +131,23 @@ def jax_funcify_PivotToPermutation(op, **kwargs):
...
@@ -132,27 +131,23 @@ def jax_funcify_PivotToPermutation(op, **kwargs):
def
jax_funcify_LU
(
op
,
**
kwargs
):
def
jax_funcify_LU
(
op
,
**
kwargs
):
permute_l
=
op
.
permute_l
permute_l
=
op
.
permute_l
p_indices
=
op
.
p_indices
p_indices
=
op
.
p_indices
check_finite
=
op
.
check_finite
if
p_indices
:
if
p_indices
:
raise
ValueError
(
"JAX does not support the p_indices argument"
)
raise
ValueError
(
"JAX does not support the p_indices argument"
)
def
lu
(
*
inputs
):
def
lu
(
*
inputs
):
return
jax
.
scipy
.
linalg
.
lu
(
return
jax
.
scipy
.
linalg
.
lu
(
*
inputs
,
permute_l
=
permute_l
,
check_finite
=
False
)
*
inputs
,
permute_l
=
permute_l
,
check_finite
=
check_finite
)
return
lu
return
lu
@jax_funcify.register
(
LUFactor
)
@jax_funcify.register
(
LUFactor
)
def
jax_funcify_LUFactor
(
op
,
**
kwargs
):
def
jax_funcify_LUFactor
(
op
,
**
kwargs
):
check_finite
=
op
.
check_finite
overwrite_a
=
op
.
overwrite_a
overwrite_a
=
op
.
overwrite_a
def
lu_factor
(
a
):
def
lu_factor
(
a
):
return
jax
.
scipy
.
linalg
.
lu_factor
(
return
jax
.
scipy
.
linalg
.
lu_factor
(
a
,
check_finite
=
check_finit
e
,
overwrite_a
=
overwrite_a
a
,
check_finite
=
Fals
e
,
overwrite_a
=
overwrite_a
)
)
return
lu_factor
return
lu_factor
...
@@ -161,12 +156,11 @@ def jax_funcify_LUFactor(op, **kwargs):
...
@@ -161,12 +156,11 @@ def jax_funcify_LUFactor(op, **kwargs):
@jax_funcify.register
(
CholeskySolve
)
@jax_funcify.register
(
CholeskySolve
)
def
jax_funcify_ChoSolve
(
op
,
**
kwargs
):
def
jax_funcify_ChoSolve
(
op
,
**
kwargs
):
lower
=
op
.
lower
lower
=
op
.
lower
check_finite
=
op
.
check_finite
overwrite_b
=
op
.
overwrite_b
overwrite_b
=
op
.
overwrite_b
def
cho_solve
(
c
,
b
):
def
cho_solve
(
c
,
b
):
return
jax
.
scipy
.
linalg
.
cho_solve
(
return
jax
.
scipy
.
linalg
.
cho_solve
(
(
c
,
lower
),
b
,
check_finite
=
check_finit
e
,
overwrite_b
=
overwrite_b
(
c
,
lower
),
b
,
check_finite
=
Fals
e
,
overwrite_b
=
overwrite_b
)
)
return
cho_solve
return
cho_solve
...
...
pytensor/link/numba/dispatch/linalg/_LAPACK.py
浏览文件 @
672a4829
...
@@ -263,122 +263,6 @@ class _LAPACK:
...
@@ -263,122 +263,6 @@ class _LAPACK:
return
potrs
return
potrs
@classmethod
def
numba_xlange
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of
a general M-by-N matrix A.
Called by scipy.linalg.solve, but doesn't correspond to any Op in pytensor.
"""
kind
=
get_blas_kind
(
dtype
)
float_type
=
_get_nb_float_from_dtype
(
kind
,
return_pointer
=
False
)
float_pointer
=
_get_nb_float_from_dtype
(
kind
,
return_pointer
=
True
)
unique_func_name
=
f
"scipy.lapack.{kind}lange"
@numba_basic.numba_njit
def
get_lange_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"lange"
)
return
ptr
lange_function_type
=
types
.
FunctionType
(
float_type
(
nb_i32p
,
# NORM
nb_i32p
,
# M
nb_i32p
,
# N
float_pointer
,
# A
nb_i32p
,
# LDA
float_pointer
,
# WORK
)
)
@numba_basic.numba_njit
def
lange
(
NORM
,
M
,
N
,
A
,
LDA
,
WORK
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_lange_pointer
,
func_type_ref
=
lange_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
return
fn
(
NORM
,
M
,
N
,
A
,
LDA
,
WORK
)
return
lange
@classmethod
def
numba_xlamch
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Determine machine precision for floating point arithmetic.
"""
kind
=
get_blas_kind
(
dtype
)
float_type
=
_get_nb_float_from_dtype
(
kind
,
return_pointer
=
False
)
unique_func_name
=
f
"scipy.lapack.{kind}lamch"
@numba_basic.numba_njit
def
get_lamch_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"lamch"
)
return
ptr
lamch_function_type
=
types
.
FunctionType
(
float_type
(
# Return type
nb_i32p
,
# CMACH
)
)
@numba_basic.numba_njit
def
lamch
(
CMACH
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_lamch_pointer
,
func_type_ref
=
lamch_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
res
=
fn
(
CMACH
)
return
res
return
lamch
@classmethod
def
numba_xgecon
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Estimates the condition number of a matrix A, using the LU factorization computed by numba_getrf.
Called by scipy.linalg.solve when assume_a == "gen"
"""
kind
=
get_blas_kind
(
dtype
)
float_pointer
=
_get_nb_float_from_dtype
(
kind
)
unique_func_name
=
f
"scipy.lapack.{kind}gecon"
@numba_basic.numba_njit
def
get_gecon_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"gecon"
)
return
ptr
gecon_function_type
=
types
.
FunctionType
(
types
.
void
(
nb_i32p
,
# NORM
nb_i32p
,
# N
float_pointer
,
# A
nb_i32p
,
# LDA
float_pointer
,
# ANORM
float_pointer
,
# RCOND
float_pointer
,
# WORK
nb_i32p
,
# IWORK
nb_i32p
,
# INFO
)
)
@numba_basic.numba_njit
def
gecon
(
NORM
,
N
,
A
,
LDA
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_gecon_pointer
,
func_type_ref
=
gecon_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
fn
(
NORM
,
N
,
A
,
LDA
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
)
return
gecon
@classmethod
@classmethod
def
numba_xgetrf
(
cls
,
dtype
)
->
CPUDispatcher
:
def
numba_xgetrf
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
"""
...
@@ -506,91 +390,6 @@ class _LAPACK:
...
@@ -506,91 +390,6 @@ class _LAPACK:
return
sysv
return
sysv
@classmethod
def
numba_xsycon
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization
computed by xSYTRF.
"""
kind
=
get_blas_kind
(
dtype
)
float_pointer
=
_get_nb_float_from_dtype
(
kind
)
unique_func_name
=
f
"scipy.lapack.{kind}sycon"
@numba_basic.numba_njit
def
get_sycon_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"sycon"
)
return
ptr
sycon_function_type
=
types
.
FunctionType
(
types
.
void
(
nb_i32p
,
# UPLO
nb_i32p
,
# N
float_pointer
,
# A
nb_i32p
,
# LDA
nb_i32p
,
# IPIV
float_pointer
,
# ANORM
float_pointer
,
# RCOND
float_pointer
,
# WORK
nb_i32p
,
# IWORK
nb_i32p
,
# INFO
)
)
@numba_basic.numba_njit
def
sycon
(
UPLO
,
N
,
A
,
LDA
,
IPIV
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_sycon_pointer
,
func_type_ref
=
sycon_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
fn
(
UPLO
,
N
,
A
,
LDA
,
IPIV
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
)
return
sycon
@classmethod
def
numba_xpocon
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization
computed by potrf.
Called by scipy.linalg.solve when assume_a == "pos"
"""
kind
=
get_blas_kind
(
dtype
)
float_pointer
=
_get_nb_float_from_dtype
(
kind
)
unique_func_name
=
f
"scipy.lapack.{kind}pocon"
@numba_basic.numba_njit
def
get_pocon_pointer
():
with
numba
.
objmode
(
ptr
=
types
.
intp
):
ptr
=
get_lapack_ptr
(
dtype
,
"pocon"
)
return
ptr
pocon_function_type
=
types
.
FunctionType
(
types
.
void
(
nb_i32p
,
# UPLO
nb_i32p
,
# N
float_pointer
,
# A
nb_i32p
,
# LDA
float_pointer
,
# ANORM
float_pointer
,
# RCOND
float_pointer
,
# WORK
nb_i32p
,
# IWORK
nb_i32p
,
# INFO
)
)
@numba_basic.numba_njit
def
pocon
(
UPLO
,
N
,
A
,
LDA
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
):
fn
=
_call_cached_ptr
(
get_ptr_func
=
get_pocon_pointer
,
func_type_ref
=
pocon_function_type
,
unique_func_name_lit
=
unique_func_name
,
)
fn
(
UPLO
,
N
,
A
,
LDA
,
ANORM
,
RCOND
,
WORK
,
IWORK
,
INFO
)
return
pocon
@classmethod
@classmethod
def
numba_xposv
(
cls
,
dtype
)
->
CPUDispatcher
:
def
numba_xposv
(
cls
,
dtype
)
->
CPUDispatcher
:
"""
"""
...
...
pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py
浏览文件 @
672a4829
...
@@ -12,24 +12,19 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
...
@@ -12,24 +12,19 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_linalg_matrix
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_linalg_matrix
def
_cholesky
(
a
,
lower
=
False
,
overwrite_a
=
False
,
check_finite
=
True
):
def
_cholesky
(
a
,
lower
=
False
,
overwrite_a
=
False
):
return
(
return
linalg
.
cholesky
(
a
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
check_finite
=
False
)
linalg
.
cholesky
(
a
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
),
0
,
)
@overload
(
_cholesky
)
@overload
(
_cholesky
)
def
cholesky_impl
(
A
,
lower
=
0
,
overwrite_a
=
False
,
check_finite
=
True
):
def
cholesky_impl
(
A
,
lower
=
0
,
overwrite_a
=
False
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"cholesky"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"cholesky"
)
dtype
=
A
.
dtype
dtype
=
A
.
dtype
numba_potrf
=
_LAPACK
()
.
numba_xpotrf
(
dtype
)
numba_potrf
=
_LAPACK
()
.
numba_xpotrf
(
dtype
)
def
impl
(
A
,
lower
=
False
,
overwrite_a
=
False
,
check_finite
=
True
):
def
impl
(
A
,
lower
=
False
,
overwrite_a
=
False
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
if
A
.
shape
[
-
2
]
!=
_N
:
if
A
.
shape
[
-
2
]
!=
_N
:
raise
linalg
.
LinAlgError
(
"Last 2 dimensions of A must be square"
)
raise
linalg
.
LinAlgError
(
"Last 2 dimensions of A must be square"
)
...
@@ -58,6 +53,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
...
@@ -58,6 +53,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
INFO
,
INFO
,
)
)
if
int_ptr_to_val
(
INFO
)
!=
0
:
A_copy
=
np
.
full_like
(
A_copy
,
np
.
nan
)
return
A_copy
if
lower
:
if
lower
:
for
j
in
range
(
1
,
_N
):
for
j
in
range
(
1
,
_N
):
for
i
in
range
(
j
):
for
i
in
range
(
j
):
...
@@ -67,10 +66,9 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
...
@@ -67,10 +66,9 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
for
i
in
range
(
j
+
1
,
_N
):
for
i
in
range
(
j
+
1
,
_N
):
A_copy
[
i
,
j
]
=
0.0
A_copy
[
i
,
j
]
=
0.0
info_int
=
int_ptr_to_val
(
INFO
)
if
transposed
:
if
transposed
:
return
A_copy
.
T
,
info_int
return
A_copy
.
T
return
A_copy
,
info_int
else
:
return
A_copy
return
impl
return
impl
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
浏览文件 @
672a4829
...
@@ -39,7 +39,6 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
...
@@ -39,7 +39,6 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
def
_lu_1
(
def
_lu_1
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
Literal
[
True
],
permute_l
:
Literal
[
True
],
check_finite
:
bool
,
p_indices
:
Literal
[
False
],
p_indices
:
Literal
[
False
],
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
...
@@ -52,7 +51,7 @@ def _lu_1(
...
@@ -52,7 +51,7 @@ def _lu_1(
return
linalg
.
lu
(
# type: ignore[no-any-return]
return
linalg
.
lu
(
# type: ignore[no-any-return]
a
,
a
,
permute_l
=
permute_l
,
permute_l
=
permute_l
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
p_indices
=
p_indices
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
)
)
...
@@ -61,7 +60,6 @@ def _lu_1(
...
@@ -61,7 +60,6 @@ def _lu_1(
def
_lu_2
(
def
_lu_2
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
Literal
[
False
],
permute_l
:
Literal
[
False
],
check_finite
:
bool
,
p_indices
:
Literal
[
True
],
p_indices
:
Literal
[
True
],
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
...
@@ -74,7 +72,7 @@ def _lu_2(
...
@@ -74,7 +72,7 @@ def _lu_2(
return
linalg
.
lu
(
# type: ignore[no-any-return]
return
linalg
.
lu
(
# type: ignore[no-any-return]
a
,
a
,
permute_l
=
permute_l
,
permute_l
=
permute_l
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
p_indices
=
p_indices
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
)
)
...
@@ -83,7 +81,6 @@ def _lu_2(
...
@@ -83,7 +81,6 @@ def _lu_2(
def
_lu_3
(
def
_lu_3
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
Literal
[
False
],
permute_l
:
Literal
[
False
],
check_finite
:
bool
,
p_indices
:
Literal
[
False
],
p_indices
:
Literal
[
False
],
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
...
@@ -96,7 +93,7 @@ def _lu_3(
...
@@ -96,7 +93,7 @@ def _lu_3(
return
linalg
.
lu
(
# type: ignore[no-any-return]
return
linalg
.
lu
(
# type: ignore[no-any-return]
a
,
a
,
permute_l
=
permute_l
,
permute_l
=
permute_l
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
p_indices
=
p_indices
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
)
)
...
@@ -106,11 +103,10 @@ def _lu_3(
...
@@ -106,11 +103,10 @@ def _lu_3(
def
lu_impl_1
(
def
lu_impl_1
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
bool
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
Callable
[
)
->
Callable
[
[
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]
[
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]
]:
]:
"""
"""
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
...
@@ -123,7 +119,6 @@ def lu_impl_1(
...
@@ -123,7 +119,6 @@ def lu_impl_1(
def
impl
(
def
impl
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
bool
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
...
@@ -137,10 +132,9 @@ def lu_impl_1(
...
@@ -137,10 +132,9 @@ def lu_impl_1(
def
lu_impl_2
(
def
lu_impl_2
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
bool
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
]]:
)
->
Callable
[[
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
]]:
"""
"""
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is False and p_indices is
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is False and p_indices is
True. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L.
True. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L.
...
@@ -153,7 +147,6 @@ def lu_impl_2(
...
@@ -153,7 +147,6 @@ def lu_impl_2(
def
impl
(
def
impl
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
bool
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
...
@@ -169,11 +162,10 @@ def lu_impl_2(
...
@@ -169,11 +162,10 @@ def lu_impl_2(
def
lu_impl_3
(
def
lu_impl_3
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
bool
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
Callable
[
)
->
Callable
[
[
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]
[
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]
]:
]:
"""
"""
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
...
@@ -186,7 +178,6 @@ def lu_impl_3(
...
@@ -186,7 +178,6 @@ def lu_impl_3(
def
impl
(
def
impl
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
bool
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
...
...
pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py
浏览文件 @
672a4829
...
@@ -79,11 +79,12 @@ def lu_factor_impl(
...
@@ -79,11 +79,12 @@ def lu_factor_impl(
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"lu_factor"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"lu_factor"
)
def
impl
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
def
impl
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
A_copy
,
IPIV
,
INFO
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
A_copy
,
IPIV
,
info
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
IPIV
-=
1
# LAPACK uses 1-based indexing, convert to 0-based
IPIV
-=
1
# LAPACK uses 1-based indexing, convert to 0-based
if
INFO
!=
0
:
if
info
!=
0
:
raise
np
.
linalg
.
LinAlgError
(
"LU decomposition failed"
)
A_copy
=
np
.
full_like
(
A_copy
,
np
.
nan
)
return
A_copy
,
IPIV
return
A_copy
,
IPIV
return
impl
return
impl
pytensor/link/numba/dispatch/linalg/decomposition/qr.py
浏览文件 @
672a4829
...
@@ -228,7 +228,6 @@ def _qr_full_pivot(
...
@@ -228,7 +228,6 @@ def _qr_full_pivot(
mode
:
Literal
[
"full"
,
"economic"
]
=
"full"
,
mode
:
Literal
[
"full"
,
"economic"
]
=
"full"
,
pivoting
:
Literal
[
True
]
=
True
,
pivoting
:
Literal
[
True
]
=
True
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
):
):
"""
"""
...
@@ -243,7 +242,7 @@ def _qr_full_pivot(
...
@@ -243,7 +242,7 @@ def _qr_full_pivot(
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
lwork
=
lwork
,
)
)
...
@@ -253,7 +252,6 @@ def _qr_full_no_pivot(
...
@@ -253,7 +252,6 @@ def _qr_full_no_pivot(
mode
:
Literal
[
"full"
,
"economic"
]
=
"full"
,
mode
:
Literal
[
"full"
,
"economic"
]
=
"full"
,
pivoting
:
Literal
[
False
]
=
False
,
pivoting
:
Literal
[
False
]
=
False
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
):
):
"""
"""
...
@@ -267,7 +265,7 @@ def _qr_full_no_pivot(
...
@@ -267,7 +265,7 @@ def _qr_full_no_pivot(
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
lwork
=
lwork
,
)
)
...
@@ -277,7 +275,6 @@ def _qr_r_pivot(
...
@@ -277,7 +275,6 @@ def _qr_r_pivot(
mode
:
Literal
[
"r"
,
"raw"
]
=
"r"
,
mode
:
Literal
[
"r"
,
"raw"
]
=
"r"
,
pivoting
:
Literal
[
True
]
=
True
,
pivoting
:
Literal
[
True
]
=
True
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
):
):
"""
"""
...
@@ -291,7 +288,7 @@ def _qr_r_pivot(
...
@@ -291,7 +288,7 @@ def _qr_r_pivot(
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
lwork
=
lwork
,
)
)
...
@@ -301,7 +298,6 @@ def _qr_r_no_pivot(
...
@@ -301,7 +298,6 @@ def _qr_r_no_pivot(
mode
:
Literal
[
"r"
,
"raw"
]
=
"r"
,
mode
:
Literal
[
"r"
,
"raw"
]
=
"r"
,
pivoting
:
Literal
[
False
]
=
False
,
pivoting
:
Literal
[
False
]
=
False
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
):
):
"""
"""
...
@@ -315,7 +311,7 @@ def _qr_r_no_pivot(
...
@@ -315,7 +311,7 @@ def _qr_r_no_pivot(
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
lwork
=
lwork
,
)
)
...
@@ -325,7 +321,6 @@ def _qr_raw_no_pivot(
...
@@ -325,7 +321,6 @@ def _qr_raw_no_pivot(
mode
:
Literal
[
"raw"
]
=
"raw"
,
mode
:
Literal
[
"raw"
]
=
"raw"
,
pivoting
:
Literal
[
False
]
=
False
,
pivoting
:
Literal
[
False
]
=
False
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
):
):
"""
"""
...
@@ -339,7 +334,7 @@ def _qr_raw_no_pivot(
...
@@ -339,7 +334,7 @@ def _qr_raw_no_pivot(
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
lwork
=
lwork
,
)
)
...
@@ -351,7 +346,6 @@ def _qr_raw_pivot(
...
@@ -351,7 +346,6 @@ def _qr_raw_pivot(
mode
:
Literal
[
"raw"
]
=
"raw"
,
mode
:
Literal
[
"raw"
]
=
"raw"
,
pivoting
:
Literal
[
True
]
=
True
,
pivoting
:
Literal
[
True
]
=
True
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
):
):
"""
"""
...
@@ -365,7 +359,7 @@ def _qr_raw_pivot(
...
@@ -365,7 +359,7 @@ def _qr_raw_pivot(
mode
=
mode
,
mode
=
mode
,
pivoting
=
pivoting
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
lwork
=
lwork
,
lwork
=
lwork
,
)
)
...
@@ -373,9 +367,7 @@ def _qr_raw_pivot(
...
@@ -373,9 +367,7 @@ def _qr_raw_pivot(
@overload
(
_qr_full_pivot
)
@overload
(
_qr_full_pivot
)
def
qr_full_pivot_impl
(
def
qr_full_pivot_impl
(
x
,
mode
=
"full"
,
pivoting
=
True
,
overwrite_a
=
False
,
lwork
=
None
):
x
,
mode
=
"full"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
dtype
=
x
.
dtype
dtype
=
x
.
dtype
...
@@ -395,7 +387,6 @@ def qr_full_pivot_impl(
...
@@ -395,7 +387,6 @@ def qr_full_pivot_impl(
mode
=
"full"
,
mode
=
"full"
,
pivoting
=
True
,
pivoting
=
True
,
overwrite_a
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
lwork
=
None
,
):
):
M
=
np
.
int32
(
x
.
shape
[
0
])
M
=
np
.
int32
(
x
.
shape
[
0
])
...
@@ -529,7 +520,7 @@ def qr_full_pivot_impl(
...
@@ -529,7 +520,7 @@ def qr_full_pivot_impl(
@overload
(
_qr_full_no_pivot
)
@overload
(
_qr_full_no_pivot
)
def
qr_full_no_pivot_impl
(
def
qr_full_no_pivot_impl
(
x
,
mode
=
"full"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
x
,
mode
=
"full"
,
pivoting
=
False
,
overwrite_a
=
False
,
lwork
=
None
):
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
...
@@ -546,7 +537,6 @@ def qr_full_no_pivot_impl(
...
@@ -546,7 +537,6 @@ def qr_full_no_pivot_impl(
mode
=
"full"
,
mode
=
"full"
,
pivoting
=
False
,
pivoting
=
False
,
overwrite_a
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
lwork
=
None
,
):
):
M
=
np
.
int32
(
x
.
shape
[
0
])
M
=
np
.
int32
(
x
.
shape
[
0
])
...
@@ -645,9 +635,7 @@ def qr_full_no_pivot_impl(
...
@@ -645,9 +635,7 @@ def qr_full_no_pivot_impl(
@overload
(
_qr_r_pivot
)
@overload
(
_qr_r_pivot
)
def
qr_r_pivot_impl
(
def
qr_r_pivot_impl
(
x
,
mode
=
"r"
,
pivoting
=
True
,
overwrite_a
=
False
,
lwork
=
None
):
x
,
mode
=
"r"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
dtype
=
x
.
dtype
dtype
=
x
.
dtype
...
@@ -658,7 +646,6 @@ def qr_r_pivot_impl(
...
@@ -658,7 +646,6 @@ def qr_r_pivot_impl(
mode
=
"r"
,
mode
=
"r"
,
pivoting
=
True
,
pivoting
=
True
,
overwrite_a
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
lwork
=
None
,
):
):
M
=
np
.
int32
(
x
.
shape
[
0
])
M
=
np
.
int32
(
x
.
shape
[
0
])
...
@@ -720,9 +707,7 @@ def qr_r_pivot_impl(
...
@@ -720,9 +707,7 @@ def qr_r_pivot_impl(
@overload
(
_qr_r_no_pivot
)
@overload
(
_qr_r_no_pivot
)
def
qr_r_no_pivot_impl
(
def
qr_r_no_pivot_impl
(
x
,
mode
=
"r"
,
pivoting
=
False
,
overwrite_a
=
False
,
lwork
=
None
):
x
,
mode
=
"r"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
dtype
=
x
.
dtype
dtype
=
x
.
dtype
...
@@ -733,7 +718,6 @@ def qr_r_no_pivot_impl(
...
@@ -733,7 +718,6 @@ def qr_r_no_pivot_impl(
mode
=
"r"
,
mode
=
"r"
,
pivoting
=
False
,
pivoting
=
False
,
overwrite_a
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
lwork
=
None
,
):
):
M
=
np
.
int32
(
x
.
shape
[
0
])
M
=
np
.
int32
(
x
.
shape
[
0
])
...
@@ -792,9 +776,7 @@ def qr_r_no_pivot_impl(
...
@@ -792,9 +776,7 @@ def qr_r_no_pivot_impl(
@overload
(
_qr_raw_no_pivot
)
@overload
(
_qr_raw_no_pivot
)
def
qr_raw_no_pivot_impl
(
def
qr_raw_no_pivot_impl
(
x
,
mode
=
"raw"
,
pivoting
=
False
,
overwrite_a
=
False
,
lwork
=
None
):
x
,
mode
=
"raw"
,
pivoting
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
dtype
=
x
.
dtype
dtype
=
x
.
dtype
...
@@ -805,7 +787,6 @@ def qr_raw_no_pivot_impl(
...
@@ -805,7 +787,6 @@ def qr_raw_no_pivot_impl(
mode
=
"raw"
,
mode
=
"raw"
,
pivoting
=
False
,
pivoting
=
False
,
overwrite_a
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
lwork
=
None
,
):
):
M
=
np
.
int32
(
x
.
shape
[
0
])
M
=
np
.
int32
(
x
.
shape
[
0
])
...
@@ -863,9 +844,7 @@ def qr_raw_no_pivot_impl(
...
@@ -863,9 +844,7 @@ def qr_raw_no_pivot_impl(
@overload
(
_qr_raw_pivot
)
@overload
(
_qr_raw_pivot
)
def
qr_raw_pivot_impl
(
def
qr_raw_pivot_impl
(
x
,
mode
=
"raw"
,
pivoting
=
True
,
overwrite_a
=
False
,
lwork
=
None
):
x
,
mode
=
"raw"
,
pivoting
=
True
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
_check_linalg_matrix
(
x
,
ndim
=
2
,
dtype
=
(
Float
,
Complex
),
func_name
=
"qr"
)
...
@@ -880,7 +859,6 @@ def qr_raw_pivot_impl(
...
@@ -880,7 +859,6 @@ def qr_raw_pivot_impl(
mode
=
"raw"
,
mode
=
"raw"
,
pivoting
=
True
,
pivoting
=
True
,
overwrite_a
=
False
,
overwrite_a
=
False
,
check_finite
=
False
,
lwork
=
None
,
lwork
=
None
,
):
):
M
=
np
.
int32
(
x
.
shape
[
0
])
M
=
np
.
int32
(
x
.
shape
[
0
])
...
...
pytensor/link/numba/dispatch/linalg/solve/cholesky.py
浏览文件 @
672a4829
...
@@ -14,23 +14,23 @@ from pytensor.link.numba.dispatch.linalg.utils import (
...
@@ -14,23 +14,23 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match
,
_check_dtypes_match
,
_check_linalg_matrix
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
)
def
_cho_solve
(
def
_cho_solve
(
C
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_b
:
bool
):
C
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
):
"""
"""
Solve a positive-definite linear system using the Cholesky decomposition.
Solve a positive-definite linear system using the Cholesky decomposition.
"""
"""
return
linalg
.
cho_solve
(
return
linalg
.
cho_solve
(
(
C
,
lower
),
b
=
B
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
(
C
,
lower
),
b
=
B
,
overwrite_b
=
overwrite_b
,
check_finite
=
False
,
)
)
@overload
(
_cho_solve
)
@overload
(
_cho_solve
)
def
cho_solve_impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
,
check_finite
=
True
):
def
cho_solve_impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
C
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"cho_solve"
)
_check_linalg_matrix
(
C
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"cho_solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"cho_solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"cho_solve"
)
...
@@ -38,7 +38,7 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
...
@@ -38,7 +38,7 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
dtype
=
C
.
dtype
dtype
=
C
.
dtype
numba_potrs
=
_LAPACK
()
.
numba_xpotrs
(
dtype
)
numba_potrs
=
_LAPACK
()
.
numba_xpotrs
(
dtype
)
def
impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
,
check_finite
=
True
):
def
impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
):
_solve_check_input_shapes
(
C
,
B
)
_solve_check_input_shapes
(
C
,
B
)
_N
=
np
.
int32
(
C
.
shape
[
-
1
])
_N
=
np
.
int32
(
C
.
shape
[
-
1
])
...
@@ -79,7 +79,8 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
...
@@ -79,7 +79,8 @@ def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
INFO
,
INFO
,
)
)
_solve_check
(
_N
,
int_ptr_to_val
(
INFO
))
if
int_ptr_to_val
(
INFO
)
!=
0
:
B_copy
=
np
.
full_like
(
B_copy
,
np
.
nan
)
if
B_is_1d
:
if
B_is_1d
:
return
B_copy
[
...
,
0
]
return
B_copy
[
...
,
0
]
...
...
pytensor/link/numba/dispatch/linalg/solve/general.py
浏览文件 @
672a4829
...
@@ -3,82 +3,24 @@ from collections.abc import Callable
...
@@ -3,82 +3,24 @@ from collections.abc import Callable
import
numpy
as
np
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
from
numba.core.types
import
Float
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
numba.np.linalg
import
ensure_lapack
from
scipy
import
linalg
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.decomposition.lu_factor
import
_getrf
from
pytensor.link.numba.dispatch.linalg.decomposition.lu_factor
import
_getrf
from
pytensor.link.numba.dispatch.linalg.solve.lu_solve
import
_getrs
from
pytensor.link.numba.dispatch.linalg.solve.lu_solve
import
_getrs
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_dtypes_match
,
_check_dtypes_match
,
_check_linalg_matrix
,
_check_linalg_matrix
,
_solve_check
,
)
)
def
_xgecon
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify
graphs.
"""
return
# type: ignore
@overload
(
_xgecon
)
def
xgecon_impl
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
Callable
[[
np
.
ndarray
,
float
,
str
],
tuple
[
np
.
ndarray
,
int
]]:
"""
Compute the condition number of a matrix A.
"""
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"gecon"
)
dtype
=
A
.
dtype
numba_gecon
=
_LAPACK
()
.
numba_xgecon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
A_NORM
=
np
.
array
(
A_norm
,
dtype
=
dtype
)
NORM
=
val_to_int_ptr
(
ord
(
norm
))
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
4
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
1
)
numba_gecon
(
NORM
,
N
,
A_copy
.
ctypes
,
LDA
,
A_NORM
.
ctypes
,
RCOND
.
ctypes
,
WORK
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_gen
(
def
_solve_gen
(
A
:
np
.
ndarray
,
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
):
):
"""Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects
"""Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects
...
@@ -89,7 +31,7 @@ def _solve_gen(
...
@@ -89,7 +31,7 @@ def _solve_gen(
lower
=
lower
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
assume_a
=
"gen"
,
assume_a
=
"gen"
,
transposed
=
transposed
,
transposed
=
transposed
,
)
)
...
@@ -102,9 +44,8 @@ def solve_gen_impl(
...
@@ -102,9 +44,8 @@ def solve_gen_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
...
@@ -116,7 +57,6 @@ def solve_gen_impl(
...
@@ -116,7 +57,6 @@ def solve_gen_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
)
->
np
.
ndarray
:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
...
@@ -127,20 +67,18 @@ def solve_gen_impl(
...
@@ -127,20 +67,18 @@ def solve_gen_impl(
A
=
A
.
T
A
=
A
.
T
transposed
=
not
transposed
transposed
=
not
transposed
order
=
"I"
if
transposed
else
"1"
LU
,
IPIV
,
INFO1
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
norm
=
_xlange
(
A
,
order
=
order
)
N
=
A
.
shape
[
1
]
X
,
INFO2
=
_getrs
(
LU
,
IPIV
,
INFO
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
LU
=
LU
,
_solve_check
(
N
,
INFO
)
B
=
B
,
IPIV
=
IPIV
,
X
,
INFO
=
_getrs
(
trans
=
transposed
,
LU
=
LU
,
B
=
B
,
IPIV
=
IPIV
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
overwrite_b
=
overwrite_b
,
)
)
_solve_check
(
N
,
INFO
)
RCOND
,
INFO
=
_xgecon
(
LU
,
norm
,
"1"
)
if
INFO1
!=
0
or
INFO2
!=
0
:
_solve_check
(
N
,
INFO
,
True
,
RCOND
)
X
=
np
.
full_like
(
X
,
np
.
nan
)
return
X
return
X
...
...
pytensor/link/numba/dispatch/linalg/solve/lu_solve.py
浏览文件 @
672a4829
...
@@ -17,7 +17,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
...
@@ -17,7 +17,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match
,
_check_dtypes_match
,
_check_linalg_matrix
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
_trans_char_to_int
,
)
)
...
@@ -107,14 +106,11 @@ def _lu_solve(
...
@@ -107,14 +106,11 @@ def _lu_solve(
b
:
np
.
ndarray
,
b
:
np
.
ndarray
,
trans
:
_Trans
,
trans
:
_Trans
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
):
):
"""
"""
Thin wrapper around scipy.lu_solve, used to avoid side effects from numba overloads on users who import Pytensor.
Thin wrapper around scipy.lu_solve, used to avoid side effects from numba overloads on users who import Pytensor.
"""
"""
return
linalg
.
lu_solve
(
return
linalg
.
lu_solve
(
lu_and_piv
,
b
,
trans
=
trans
,
overwrite_b
=
overwrite_b
)
lu_and_piv
,
b
,
trans
=
trans
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
)
@overload
(
_lu_solve
)
@overload
(
_lu_solve
)
...
@@ -123,8 +119,7 @@ def lu_solve_impl(
...
@@ -123,8 +119,7 @@ def lu_solve_impl(
b
:
np
.
ndarray
,
b
:
np
.
ndarray
,
trans
:
_Trans
,
trans
:
_Trans
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
_Trans
,
bool
],
np
.
ndarray
]:
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
_Trans
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
ensure_lapack
()
lu
,
_piv
=
lu_and_piv
lu
,
_piv
=
lu_and_piv
_check_linalg_matrix
(
lu
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"lu_solve"
)
_check_linalg_matrix
(
lu
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"lu_solve"
)
...
@@ -137,13 +132,11 @@ def lu_solve_impl(
...
@@ -137,13 +132,11 @@ def lu_solve_impl(
b
:
np
.
ndarray
,
b
:
np
.
ndarray
,
trans
:
_Trans
,
trans
:
_Trans
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
)
->
np
.
ndarray
:
)
->
np
.
ndarray
:
n
=
np
.
int32
(
lu
.
shape
[
0
])
X
,
info
=
_getrs
(
LU
=
lu
,
B
=
b
,
IPIV
=
piv
,
trans
=
trans
,
overwrite_b
=
overwrite_b
)
X
,
INFO
=
_getrs
(
LU
=
lu
,
B
=
b
,
IPIV
=
piv
,
trans
=
trans
,
overwrite_b
=
overwrite_b
)
_solve_check
(
n
,
INFO
)
if
info
!=
0
:
X
=
np
.
full_like
(
X
,
np
.
nan
)
return
X
return
X
...
...
pytensor/link/numba/dispatch/linalg/solve/norm.py
deleted
100644 → 0
浏览文件 @
b2d8bc24
from
collections.abc
import
Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.types
import
Float
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_linalg_matrix
def
_xlange
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
)
->
float
:
"""
Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode.
"""
return
# type: ignore
@overload
(
_xlange
)
def
xlange_impl
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
)
->
Callable
[[
np
.
ndarray
,
str
],
float
]:
"""
xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity norm, or the element of
largest absolute value of a matrix A.
"""
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"norm"
)
dtype
=
A
.
dtype
numba_lange
=
_LAPACK
()
.
numba_xlange
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
):
_M
,
_N
=
np
.
int32
(
A
.
shape
[
-
2
:])
# type: ignore
A_copy
=
_copy_to_fortran_order
(
A
)
M
=
val_to_int_ptr
(
_M
)
# type: ignore
N
=
val_to_int_ptr
(
_N
)
# type: ignore
LDA
=
val_to_int_ptr
(
_M
)
# type: ignore
NORM
=
(
val_to_int_ptr
(
ord
(
order
))
if
order
is
not
None
else
val_to_int_ptr
(
ord
(
"1"
))
)
WORK
=
np
.
empty
(
_M
,
dtype
=
dtype
)
# type: ignore
result
=
numba_lange
(
NORM
,
M
,
N
,
A_copy
.
ctypes
,
LDA
,
WORK
.
ctypes
)
return
result
return
impl
pytensor/link/numba/dispatch/linalg/solve/posdef.py
浏览文件 @
672a4829
...
@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
...
@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val
,
int_ptr_to_val
,
val_to_int_ptr
,
val_to_int_ptr
,
)
)
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_dtypes_match
,
_check_dtypes_match
,
_check_linalg_matrix
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
)
...
@@ -27,8 +25,6 @@ def _posv(
...
@@ -27,8 +25,6 @@ def _posv(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
"""
"""
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
...
@@ -43,10 +39,8 @@ def posv_impl(
...
@@ -43,10 +39,8 @@ def posv_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[
)
->
Callable
[
[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
],
]:
]:
ensure_lapack
()
ensure_lapack
()
...
@@ -62,8 +56,6 @@ def posv_impl(
...
@@ -62,8 +56,6 @@ def posv_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
_solve_check_input_shapes
(
A
,
B
)
_solve_check_input_shapes
(
A
,
B
)
...
@@ -115,60 +107,12 @@ def posv_impl(
...
@@ -115,60 +107,12 @@ def posv_impl(
return
impl
return
impl
def
_pocon
(
A
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by
linalg.solve when assume_a = "pos".
"""
return
# type: ignore
@overload
(
_pocon
)
def
pocon_impl
(
A
:
np
.
ndarray
,
anorm
:
float
)
->
Callable
[[
np
.
ndarray
,
float
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"pocon"
)
dtype
=
A
.
dtype
numba_pocon
=
_LAPACK
()
.
numba_xpocon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
anorm
:
float
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
UPLO
=
val_to_int_ptr
(
ord
(
"L"
))
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
ANORM
=
np
.
array
(
anorm
,
dtype
=
dtype
)
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
3
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
0
)
numba_pocon
(
UPLO
,
N
,
A_copy
.
ctypes
,
LDA
,
ANORM
.
ctypes
,
RCOND
.
ctypes
,
WORK
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_psd
(
def
_solve_psd
(
A
:
np
.
ndarray
,
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
):
):
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
...
@@ -179,7 +123,7 @@ def _solve_psd(
...
@@ -179,7 +123,7 @@ def _solve_psd(
lower
=
lower
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
transposed
=
transposed
,
transposed
=
transposed
,
assume_a
=
"pos"
,
assume_a
=
"pos"
,
)
)
...
@@ -192,9 +136,8 @@ def solve_psd_impl(
...
@@ -192,9 +136,8 @@ def solve_psd_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
...
@@ -206,18 +149,14 @@ def solve_psd_impl(
...
@@ -206,18 +149,14 @@ def solve_psd_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
)
->
np
.
ndarray
:
_solve_check_input_shapes
(
A
,
B
)
_solve_check_input_shapes
(
A
,
B
)
C
,
x
,
info
=
_posv
(
_C
,
x
,
info
=
_posv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
)
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
,
check_finite
,
transposed
)
_solve_check
(
A
.
shape
[
-
1
],
info
)
rcond
,
info
=
_pocon
(
C
,
_xlange
(
A
))
if
info
!=
0
:
_solve_check
(
A
.
shape
[
-
1
],
info
=
info
,
lamch
=
True
,
rcond
=
rcond
)
x
=
np
.
full_like
(
x
,
np
.
nan
)
return
x
return
x
...
...
pytensor/link/numba/dispatch/linalg/solve/symmetric.py
浏览文件 @
672a4829
...
@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
...
@@ -11,13 +11,11 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val
,
int_ptr_to_val
,
val_to_int_ptr
,
val_to_int_ptr
,
)
)
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_dtypes_match
,
_check_dtypes_match
,
_check_linalg_matrix
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
)
)
...
@@ -121,61 +119,12 @@ def sysv_impl(
...
@@ -121,61 +119,12 @@ def sysv_impl(
return
impl
return
impl
def
_sycon
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in
python mode.
"""
return
# type: ignore
@overload
(
_sycon
)
def
sycon_impl
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
float
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"sycon"
)
dtype
=
A
.
dtype
numba_sycon
=
_LAPACK
()
.
numba_xsycon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
UPLO
=
val_to_int_ptr
(
ord
(
"U"
))
ANORM
=
np
.
array
(
anorm
,
dtype
=
dtype
)
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
2
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
0
)
numba_sycon
(
UPLO
,
N
,
A_copy
.
ctypes
,
LDA
,
ipiv
.
ctypes
,
ANORM
.
ctypes
,
RCOND
.
ctypes
,
WORK
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_symmetric
(
def
_solve_symmetric
(
A
:
np
.
ndarray
,
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
):
):
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
...
@@ -186,7 +135,7 @@ def _solve_symmetric(
...
@@ -186,7 +135,7 @@ def _solve_symmetric(
lower
=
lower
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
assume_a
=
"sym"
,
assume_a
=
"sym"
,
transposed
=
transposed
,
transposed
=
transposed
,
)
)
...
@@ -199,9 +148,8 @@ def solve_symmetric_impl(
...
@@ -199,9 +148,8 @@ def solve_symmetric_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
...
@@ -213,16 +161,14 @@ def solve_symmetric_impl(
...
@@ -213,16 +161,14 @@ def solve_symmetric_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
)
->
np
.
ndarray
:
_solve_check_input_shapes
(
A
,
B
)
_solve_check_input_shapes
(
A
,
B
)
lu
,
x
,
ipiv
,
info
=
_sysv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
)
_lu
,
x
,
_ipiv
,
info
=
_sysv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
)
_solve_check
(
A
.
shape
[
-
1
],
info
)
rcond
,
info
=
_sycon
(
lu
,
ipiv
,
_xlange
(
A
,
order
=
"I"
))
if
info
!=
0
:
_solve_check
(
A
.
shape
[
-
1
],
info
,
True
,
rcond
)
x
=
np
.
full_like
(
x
,
np
.
nan
)
return
x
return
x
...
...
pytensor/link/numba/dispatch/linalg/solve/triangular.py
浏览文件 @
672a4829
...
@@ -15,13 +15,12 @@ from pytensor.link.numba.dispatch.linalg.utils import (
...
@@ -15,13 +15,12 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match
,
_check_dtypes_match
,
_check_linalg_matrix
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
_trans_char_to_int
,
)
)
def
_solve_triangular
(
def
_solve_triangular
(
A
,
B
,
trans
=
0
,
lower
=
False
,
unit_diagonal
=
False
,
b_ndim
=
1
,
overwrite_b
=
False
A
,
B
,
trans
=
0
,
lower
=
False
,
unit_diagonal
=
False
,
overwrite_b
=
False
):
):
"""
"""
Thin wrapper around scipy.linalg.solve_triangular.
Thin wrapper around scipy.linalg.solve_triangular.
...
@@ -39,11 +38,12 @@ def _solve_triangular(
...
@@ -39,11 +38,12 @@ def _solve_triangular(
lower
=
lower
,
lower
=
lower
,
unit_diagonal
=
unit_diagonal
,
unit_diagonal
=
unit_diagonal
,
overwrite_b
=
overwrite_b
,
overwrite_b
=
overwrite_b
,
check_finite
=
False
,
)
)
@overload
(
_solve_triangular
)
@overload
(
_solve_triangular
)
def
solve_triangular_impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
b_ndim
,
overwrite_b
):
def
solve_triangular_impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
overwrite_b
):
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve_triangular"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve_triangular"
)
...
@@ -57,12 +57,10 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
...
@@ -57,12 +57,10 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
"This function is not expected to work with complex numbers yet"
"This function is not expected to work with complex numbers yet"
)
)
def
impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
b_ndim
,
overwrite_b
):
def
impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
overwrite_b
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
_solve_check_input_shapes
(
A
,
B
)
_solve_check_input_shapes
(
A
,
B
)
# Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type
# could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
B_is_1d
=
B
.
ndim
==
1
B_is_1d
=
B
.
ndim
==
1
if
A
.
flags
.
f_contiguous
or
(
A
.
flags
.
c_contiguous
and
trans
in
(
0
,
1
)):
if
A
.
flags
.
f_contiguous
or
(
A
.
flags
.
c_contiguous
and
trans
in
(
0
,
1
)):
...
@@ -106,8 +104,8 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
...
@@ -106,8 +104,8 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
LDB
,
LDB
,
INFO
,
INFO
,
)
)
if
int_ptr_to_val
(
INFO
)
!=
0
:
_solve_check
(
int_ptr_to_val
(
LDA
),
int_ptr_to_val
(
INFO
)
)
B_copy
=
np
.
full_like
(
B_copy
,
np
.
nan
)
if
B_is_1d
:
if
B_is_1d
:
return
B_copy
[
...
,
0
]
return
B_copy
[
...
,
0
]
...
...
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
浏览文件 @
672a4829
...
@@ -23,7 +23,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
...
@@ -23,7 +23,6 @@ from pytensor.link.numba.dispatch.linalg.utils import (
_check_dtypes_match
,
_check_dtypes_match
,
_check_linalg_matrix
,
_check_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
_trans_char_to_int
,
)
)
from
pytensor.tensor._linalg.solve.tridiagonal
import
(
from
pytensor.tensor._linalg.solve.tridiagonal
import
(
...
@@ -202,83 +201,12 @@ def gttrs_impl(
...
@@ -202,83 +201,12 @@ def gttrs_impl(
return
impl
return
impl
def
_gtcon
(
dl
:
ndarray
,
d
:
ndarray
,
du
:
ndarray
,
du2
:
ndarray
,
ipiv
:
ndarray
,
anorm
:
float
,
norm
:
str
,
)
->
tuple
[
ndarray
,
int
]:
"""Placeholder for computing the condition number of a tridiagonal system."""
return
# type: ignore
@overload
(
_gtcon
)
def
gtcon_impl
(
dl
:
ndarray
,
d
:
ndarray
,
du
:
ndarray
,
du2
:
ndarray
,
ipiv
:
ndarray
,
anorm
:
float
,
norm
:
str
,
)
->
Callable
[
[
ndarray
,
ndarray
,
ndarray
,
ndarray
,
ndarray
,
float
,
str
],
tuple
[
ndarray
,
int
]
]:
ensure_lapack
()
_check_linalg_matrix
(
dl
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_linalg_matrix
(
d
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_linalg_matrix
(
du
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_linalg_matrix
(
du2
,
ndim
=
1
,
dtype
=
Float
,
func_name
=
"gtcon"
)
_check_dtypes_match
((
dl
,
d
,
du
,
du2
),
func_name
=
"gtcon"
)
_check_linalg_matrix
(
ipiv
,
ndim
=
1
,
dtype
=
int32
,
func_name
=
"gtcon"
)
dtype
=
d
.
dtype
numba_gtcon
=
_LAPACK
()
.
numba_xgtcon
(
dtype
)
def
impl
(
dl
:
ndarray
,
d
:
ndarray
,
du
:
ndarray
,
du2
:
ndarray
,
ipiv
:
ndarray
,
anorm
:
float
,
norm
:
str
,
)
->
tuple
[
ndarray
,
int
]:
n
=
np
.
int32
(
d
.
shape
[
-
1
])
rcond
=
np
.
empty
(
1
,
dtype
=
dtype
)
work
=
np
.
empty
(
2
*
n
,
dtype
=
dtype
)
iwork
=
np
.
empty
(
n
,
dtype
=
np
.
int32
)
info
=
val_to_int_ptr
(
0
)
numba_gtcon
(
val_to_int_ptr
(
ord
(
norm
)),
val_to_int_ptr
(
n
),
dl
.
ctypes
,
d
.
ctypes
,
du
.
ctypes
,
du2
.
ctypes
,
ipiv
.
ctypes
,
np
.
array
(
anorm
,
dtype
=
dtype
)
.
ctypes
,
rcond
.
ctypes
,
work
.
ctypes
,
iwork
.
ctypes
,
info
,
)
return
rcond
,
int_ptr_to_val
(
info
)
return
impl
def
_solve_tridiagonal
(
def
_solve_tridiagonal
(
a
:
ndarray
,
a
:
ndarray
,
b
:
ndarray
,
b
:
ndarray
,
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
):
):
"""
"""
...
@@ -290,7 +218,7 @@ def _solve_tridiagonal(
...
@@ -290,7 +218,7 @@ def _solve_tridiagonal(
lower
=
lower
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finit
e
,
check_finite
=
Fals
e
,
transposed
=
transposed
,
transposed
=
transposed
,
assume_a
=
"tridiagonal"
,
assume_a
=
"tridiagonal"
,
)
)
...
@@ -303,9 +231,8 @@ def _tridiagonal_solve_impl(
...
@@ -303,9 +231,8 @@ def _tridiagonal_solve_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
ndarray
,
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
ndarray
]:
)
->
Callable
[[
ndarray
,
ndarray
,
bool
,
bool
,
bool
,
bool
],
ndarray
]:
ensure_lapack
()
ensure_lapack
()
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
A
,
ndim
=
2
,
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
_check_linalg_matrix
(
B
,
ndim
=
(
1
,
2
),
dtype
=
Float
,
func_name
=
"solve"
)
...
@@ -317,31 +244,24 @@ def _tridiagonal_solve_impl(
...
@@ -317,31 +244,24 @@ def _tridiagonal_solve_impl(
lower
:
bool
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
transposed
:
bool
,
)
->
ndarray
:
)
->
ndarray
:
n
=
np
.
int32
(
A
.
shape
[
-
1
])
_solve_check_input_shapes
(
A
,
B
)
_solve_check_input_shapes
(
A
,
B
)
norm
=
"1"
if
transposed
:
if
transposed
:
A
=
A
.
T
A
=
A
.
T
dl
,
d
,
du
=
np
.
diag
(
A
,
-
1
),
np
.
diag
(
A
,
0
),
np
.
diag
(
A
,
1
)
dl
,
d
,
du
=
np
.
diag
(
A
,
-
1
),
np
.
diag
(
A
,
0
),
np
.
diag
(
A
,
1
)
anorm
=
tridiagonal_norm
(
du
,
d
,
dl
)
dl
,
d
,
du
,
du2
,
ipiv
,
info1
=
_gttrf
(
dl
,
d
,
du
,
du2
,
IPIV
,
INFO
=
_gttrf
(
dl
,
d
,
du
,
overwrite_dl
=
True
,
overwrite_d
=
True
,
overwrite_du
=
True
dl
,
d
,
du
,
overwrite_dl
=
True
,
overwrite_d
=
True
,
overwrite_du
=
True
)
)
_solve_check
(
n
,
INFO
)
X
,
INFO
=
_gttrs
(
X
,
info2
=
_gttrs
(
dl
,
d
,
du
,
du2
,
IPIV
,
B
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
dl
,
d
,
du
,
du2
,
ipiv
,
B
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
)
)
_solve_check
(
n
,
INFO
)
RCOND
,
INFO
=
_gtcon
(
dl
,
d
,
du
,
du2
,
IPIV
,
anorm
,
norm
)
if
info1
!=
0
or
info2
!=
0
:
_solve_check
(
n
,
INFO
,
True
,
RCOND
)
X
=
np
.
full_like
(
X
,
np
.
nan
)
return
X
return
X
...
@@ -391,8 +311,8 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
...
@@ -391,8 +311,8 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
)
)
return
dl
,
d
,
du
,
du2
,
ipiv
return
dl
,
d
,
du
,
du2
,
ipiv
cache_
key
=
1
cache_
version
=
2
return
lu_factor_tridiagonal
,
cache_
key
return
lu_factor_tridiagonal
,
cache_
version
@register_funcify_default_op_cache_key
(
SolveLUFactorTridiagonal
)
@register_funcify_default_op_cache_key
(
SolveLUFactorTridiagonal
)
...
@@ -434,7 +354,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
...
@@ -434,7 +354,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
ipiv
=
ipiv
.
astype
(
np
.
int32
)
ipiv
=
ipiv
.
astype
(
np
.
int32
)
if
cast_b
:
if
cast_b
:
b
=
b
.
astype
(
out_dtype
)
b
=
b
.
astype
(
out_dtype
)
x
,
_
=
_gttrs
(
x
,
info
=
_gttrs
(
dl
,
dl
,
d
,
d
,
du
,
du
,
...
@@ -444,7 +364,11 @@ def numba_funcify_SolveLUFactorTridiagonal(
...
@@ -444,7 +364,11 @@ def numba_funcify_SolveLUFactorTridiagonal(
overwrite_b
=
overwrite_b
,
overwrite_b
=
overwrite_b
,
trans
=
transposed
,
trans
=
transposed
,
)
)
if
info
!=
0
:
x
=
np
.
full_like
(
x
,
np
.
nan
)
return
x
return
x
cache_
key
=
1
cache_
version
=
2
return
solve_lu_factor_tridiagonal
,
cache_
key
return
solve_lu_factor_tridiagonal
,
cache_
version
pytensor/link/numba/dispatch/linalg/utils.py
浏览文件 @
672a4829
from
collections.abc
import
Callable
,
Sequence
from
collections.abc
import
Sequence
import
numba
import
numba
from
numba.core
import
types
from
numba.core
import
types
from
numba.core.extending
import
overload
from
numba.np.linalg
import
_copy_to_fortran_order
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
numpy.linalg
import
LinAlgError
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
val_to_int_ptr
,
)
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
(
inline
=
"always"
)
...
@@ -61,58 +54,3 @@ def _check_dtypes_match(arrays: Sequence, func_name="cho_solve"):
...
@@ -61,58 +54,3 @@ def _check_dtypes_match(arrays: Sequence, func_name="cho_solve"):
if
first_dtype
!=
other_dtype
:
if
first_dtype
!=
other_dtype
:
msg
=
f
"{func_name} only supported for matching dtypes, got {dtypes}"
msg
=
f
"{func_name} only supported for matching dtypes, got {dtypes}"
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
@numba_basic.numba_njit
(
inline
=
"always"
)
def
_solve_check
(
n
,
info
,
lamch
=
False
,
rcond
=
None
):
"""
Check arguments during the different steps of the solution phase
Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38
"""
if
info
<
0
:
# TODO: figure out how to do an fstring here
msg
=
"LAPACK reported an illegal value in input"
raise
ValueError
(
msg
)
elif
0
<
info
:
raise
LinAlgError
(
"Matrix is singular."
)
if
lamch
:
E
=
_xlamch
(
"E"
)
if
rcond
<
E
:
# TODO: This should be a warning, but we can't raise warnings in numba mode
print
(
# noqa: T201
"Ill-conditioned matrix, rcond="
,
rcond
,
", result may not be accurate."
)
def
_xlamch
(
kind
:
str
=
"E"
):
"""
Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs.
"""
pass
@overload
(
_xlamch
)
def
xlamch_impl
(
kind
:
str
=
"E"
)
->
Callable
[[
str
],
float
]:
"""
Compute the machine precision for a given floating point type.
"""
from
pytensor
import
config
ensure_lapack
()
w_type
=
_get_underlying_float
(
config
.
floatX
)
if
w_type
==
"float32"
:
dtype
=
types
.
float32
elif
w_type
==
"float64"
:
dtype
=
types
.
float64
else
:
raise
NotImplementedError
(
"Unsupported dtype"
)
numba_lamch
=
_LAPACK
()
.
numba_xlamch
(
dtype
)
def
impl
(
kind
:
str
=
"E"
)
->
float
:
KIND
=
val_to_int_ptr
(
ord
(
kind
))
return
numba_lamch
(
KIND
)
# type: ignore
return
impl
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
672a4829
差异被折叠。
点击展开。
pytensor/link/numba/dispatch/tensor_basic.py
浏览文件 @
672a4829
...
@@ -181,8 +181,8 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
...
@@ -181,8 +181,8 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
out
[
...
,
i
]
=
new_entry
out
[
...
,
i
]
=
new_entry
return
out
return
out
cache_
key
=
1
cache_
version
=
1
return
extract_diag
,
cache_
key
return
extract_diag
,
cache_
version
@register_funcify_default_op_cache_key
(
Eye
)
@register_funcify_default_op_cache_key
(
Eye
)
...
...
pytensor/tensor/_linalg/solve/rewriting.py
浏览文件 @
672a4829
...
@@ -20,14 +20,13 @@ from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_so
...
@@ -20,14 +20,13 @@ from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_so
from
pytensor.tensor.variable
import
TensorVariable
from
pytensor.tensor.variable
import
TensorVariable
def
decompose_A
(
A
,
assume_a
,
check_finite
,
lower
):
def
decompose_A
(
A
,
assume_a
,
lower
):
if
assume_a
==
"gen"
:
if
assume_a
==
"gen"
:
return
lu_factor
(
A
,
check_finite
=
check_finite
)
return
lu_factor
(
A
)
elif
assume_a
==
"tridiagonal"
:
elif
assume_a
==
"tridiagonal"
:
# We didn't implement check_finite for tridiagonal LU factorization
return
tridiagonal_lu_factor
(
A
)
return
tridiagonal_lu_factor
(
A
)
elif
assume_a
==
"pos"
:
elif
assume_a
==
"pos"
:
return
cholesky
(
A
,
lower
=
lower
,
check_finite
=
check_finite
)
return
cholesky
(
A
,
lower
=
lower
)
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -36,7 +35,6 @@ def solve_decomposed_system(
...
@@ -36,7 +35,6 @@ def solve_decomposed_system(
A_decomp
,
b
,
transposed
=
False
,
lower
=
False
,
*
,
core_solve_op
:
Solve
A_decomp
,
b
,
transposed
=
False
,
lower
=
False
,
*
,
core_solve_op
:
Solve
):
):
b_ndim
=
core_solve_op
.
b_ndim
b_ndim
=
core_solve_op
.
b_ndim
check_finite
=
core_solve_op
.
check_finite
assume_a
=
core_solve_op
.
assume_a
assume_a
=
core_solve_op
.
assume_a
if
assume_a
==
"gen"
:
if
assume_a
==
"gen"
:
...
@@ -45,10 +43,8 @@ def solve_decomposed_system(
...
@@ -45,10 +43,8 @@ def solve_decomposed_system(
b
,
b
,
b_ndim
=
b_ndim
,
b_ndim
=
b_ndim
,
trans
=
transposed
,
trans
=
transposed
,
check_finite
=
check_finite
,
)
)
elif
assume_a
==
"tridiagonal"
:
elif
assume_a
==
"tridiagonal"
:
# We didn't implement check_finite for tridiagonal LU solve
return
tridiagonal_lu_solve
(
return
tridiagonal_lu_solve
(
A_decomp
,
A_decomp
,
b
,
b
,
...
@@ -61,7 +57,6 @@ def solve_decomposed_system(
...
@@ -61,7 +57,6 @@ def solve_decomposed_system(
(
A_decomp
,
lower
),
(
A_decomp
,
lower
),
b
,
b
,
b_ndim
=
b_ndim
,
b_ndim
=
b_ndim
,
check_finite
=
check_finite
,
)
)
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -141,17 +136,8 @@ def _split_decomp_and_solve_steps(
...
@@ -141,17 +136,8 @@ def _split_decomp_and_solve_steps(
):
):
return
None
return
None
# If any Op had check_finite=True, we also do it for the LU decomposition
check_finite_decomp
=
False
for
client
,
_
in
A_solve_clients_and_transpose
:
if
client
.
op
.
core_op
.
check_finite
:
check_finite_decomp
=
True
break
lower
=
node
.
op
.
core_op
.
lower
lower
=
node
.
op
.
core_op
.
lower
A_decomp
=
decompose_A
(
A_decomp
=
decompose_A
(
A
,
assume_a
=
assume_a
,
lower
=
lower
)
A
,
assume_a
=
assume_a
,
check_finite
=
check_finite_decomp
,
lower
=
lower
)
replacements
=
{}
replacements
=
{}
for
client
,
transposed
in
A_solve_clients_and_transpose
:
for
client
,
transposed
in
A_solve_clients_and_transpose
:
...
...
pytensor/tensor/slinalg.py
浏览文件 @
672a4829
差异被折叠。
点击展开。
pytensor/xtensor/linalg.py
浏览文件 @
672a4829
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
typing
import
Literal
from
pytensor.tensor.slinalg
import
Cholesky
,
Solve
from
pytensor.tensor.slinalg
import
Cholesky
,
Solve
from
pytensor.xtensor.type
import
as_xtensor
from
pytensor.xtensor.type
import
as_xtensor
...
@@ -10,8 +9,7 @@ def cholesky(
...
@@ -10,8 +9,7 @@ def cholesky(
x
,
x
,
lower
:
bool
=
True
,
lower
:
bool
=
True
,
*
,
*
,
check_finite
:
bool
=
False
,
check_finite
:
bool
=
True
,
on_error
:
Literal
[
"raise"
,
"nan"
]
=
"raise"
,
dims
:
Sequence
[
str
],
dims
:
Sequence
[
str
],
):
):
"""Compute the Cholesky decomposition of an XTensorVariable.
"""Compute the Cholesky decomposition of an XTensorVariable.
...
@@ -22,22 +20,15 @@ def cholesky(
...
@@ -22,22 +20,15 @@ def cholesky(
The input variable to decompose.
The input variable to decompose.
lower : bool, optional
lower : bool, optional
Whether to return the lower triangular matrix. Default is True.
Whether to return the lower triangular matrix. Default is True.
check_finite : bool, optional
check_finite : bool
Whether to check that the input is finite. Default is False.
Unused by PyTensor. PyTensor will return nan if the operation fails.
on_error : {'raise', 'nan'}, optional
What to do if the input is not positive definite. If 'raise', an error is raised.
If 'nan', the output will contain NaNs. Default is 'raise'.
dims : Sequence[str]
dims : Sequence[str]
The two core dimensions of the input variable, over which the Cholesky decomposition is computed.
The two core dimensions of the input variable, over which the Cholesky decomposition is computed.
"""
"""
if
len
(
dims
)
!=
2
:
if
len
(
dims
)
!=
2
:
raise
ValueError
(
f
"Cholesky needs two dims, got {len(dims)}"
)
raise
ValueError
(
f
"Cholesky needs two dims, got {len(dims)}"
)
core_op
=
Cholesky
(
core_op
=
Cholesky
(
lower
=
lower
)
lower
=
lower
,
check_finite
=
check_finite
,
on_error
=
on_error
,
)
core_dims
=
(
core_dims
=
(
((
dims
[
0
],
dims
[
1
]),),
((
dims
[
0
],
dims
[
1
]),),
((
dims
[
0
],
dims
[
1
]),),
((
dims
[
0
],
dims
[
1
]),),
...
@@ -52,7 +43,7 @@ def solve(
...
@@ -52,7 +43,7 @@ def solve(
dims
:
Sequence
[
str
],
dims
:
Sequence
[
str
],
assume_a
=
"gen"
,
assume_a
=
"gen"
,
lower
:
bool
=
False
,
lower
:
bool
=
False
,
check_finite
:
bool
=
Fals
e
,
check_finite
:
bool
=
Tru
e
,
):
):
"""Solve a system of linear equations using XTensorVariables.
"""Solve a system of linear equations using XTensorVariables.
...
@@ -75,8 +66,8 @@ def solve(
...
@@ -75,8 +66,8 @@ def solve(
Long form options can also be used ["general", "symmetric", "hermitian", "positive_definite"].
Long form options can also be used ["general", "symmetric", "hermitian", "positive_definite"].
lower : bool, optional
lower : bool, optional
Whether `a` is lower triangular. Default is False. Only relevant if `assume_a` is "sym", "her", or "pos".
Whether `a` is lower triangular. Default is False. Only relevant if `assume_a` is "sym", "her", or "pos".
check_finite : bool
, optional
check_finite : bool
Whether to check that the input is finite. Default is False
.
Unused by PyTensor. PyTensor will return nan if the operation fails
.
"""
"""
a
,
b
=
as_xtensor
(
a
),
as_xtensor
(
b
)
a
,
b
=
as_xtensor
(
a
),
as_xtensor
(
b
)
input_core_dims
:
tuple
[
tuple
[
str
,
str
],
tuple
[
str
]
|
tuple
[
str
,
str
]]
input_core_dims
:
tuple
[
tuple
[
str
,
str
],
tuple
[
str
]
|
tuple
[
str
,
str
]]
...
@@ -98,9 +89,7 @@ def solve(
...
@@ -98,9 +89,7 @@ def solve(
else
:
else
:
raise
ValueError
(
"Solve dims must have length 2 or 3"
)
raise
ValueError
(
"Solve dims must have length 2 or 3"
)
core_op
=
Solve
(
core_op
=
Solve
(
b_ndim
=
b_ndim
,
assume_a
=
assume_a
,
lower
=
lower
)
b_ndim
=
b_ndim
,
assume_a
=
assume_a
,
lower
=
lower
,
check_finite
=
check_finite
)
x_op
=
XBlockwise
(
x_op
=
XBlockwise
(
core_op
,
core_op
,
core_dims
=
(
input_core_dims
,
output_core_dims
),
core_dims
=
(
input_core_dims
,
output_core_dims
),
...
...
tests/link/numba/test_slinalg.py
浏览文件 @
672a4829
import
re
from
typing
import
Literal
from
typing
import
Literal
import
numpy
as
np
import
numpy
as
np
...
@@ -36,70 +35,6 @@ floatX = config.floatX
...
@@ -36,70 +35,6 @@ floatX = config.floatX
rng
=
np
.
random
.
default_rng
(
42849
)
rng
=
np
.
random
.
default_rng
(
42849
)
def
test_lamch
():
from
scipy.linalg
import
get_lapack_funcs
from
pytensor.link.numba.dispatch.linalg.utils
import
_xlamch
@numba.njit
()
def
xlamch
(
kind
):
return
_xlamch
(
kind
)
lamch
=
get_lapack_funcs
(
"lamch"
,
(
np
.
array
([
0.0
],
dtype
=
floatX
),))
np
.
testing
.
assert_allclose
(
xlamch
(
"E"
),
lamch
(
"E"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"S"
),
lamch
(
"S"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"P"
),
lamch
(
"P"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"B"
),
lamch
(
"B"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"R"
),
lamch
(
"R"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"M"
),
lamch
(
"M"
))
@pytest.mark.parametrize
(
"ord_numba, ord_scipy"
,
[(
"F"
,
"fro"
),
(
"1"
,
1
),
(
"I"
,
np
.
inf
)]
)
def
test_xlange
(
ord_numba
,
ord_scipy
):
# xlange is called internally only, we don't dispatch pt.linalg.norm to it
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
@numba.njit
()
def
xlange
(
x
,
ord
):
return
_xlange
(
x
,
ord
)
x
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
np
.
testing
.
assert_allclose
(
xlange
(
x
,
ord_numba
),
linalg
.
norm
(
x
,
ord_scipy
))
@pytest.mark.parametrize
(
"ord_numba, ord_scipy"
,
[(
"1"
,
1
),
(
"I"
,
np
.
inf
)])
def
test_xgecon
(
ord_numba
,
ord_scipy
):
# gecon is called internally only, we don't dispatch pt.linalg.norm to it
from
scipy.linalg
import
get_lapack_funcs
from
pytensor.link.numba.dispatch.linalg.solve.general
import
_xgecon
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
@numba.njit
()
def
gecon
(
x
,
norm
):
anorm
=
_xlange
(
x
,
norm
)
cond
,
info
=
_xgecon
(
x
,
anorm
,
norm
)
return
cond
,
info
x
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
rcond
,
info
=
gecon
(
x
,
norm
=
ord_numba
)
# Test against direct call to the underlying LAPACK functions
# Solution does **not** agree with 1 / np.linalg.cond(x) !
lange
,
gecon
=
get_lapack_funcs
((
"lange"
,
"gecon"
),
(
x
,))
norm
=
lange
(
ord_numba
,
x
)
rcond2
,
_
=
gecon
(
x
,
norm
,
norm
=
ord_numba
)
assert
info
==
0
np
.
testing
.
assert_allclose
(
rcond
,
rcond2
)
class
TestSolves
:
class
TestSolves
:
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
],
ids
=
lambda
x
:
f
"lower={x}"
)
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
],
ids
=
lambda
x
:
f
"lower={x}"
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -323,7 +258,7 @@ class TestSolves:
...
@@ -323,7 +258,7 @@ class TestSolves:
np
.
testing
.
assert_allclose
(
b_val_not_contig
,
b_val
)
np
.
testing
.
assert_allclose
(
b_val_not_contig
,
b_val
)
@pytest.mark.parametrize
(
"value"
,
[
np
.
nan
,
np
.
inf
])
@pytest.mark.parametrize
(
"value"
,
[
np
.
nan
,
np
.
inf
])
def
test_solve_triangular_
raises
_on_nan_inf
(
self
,
value
):
def
test_solve_triangular_
does_not_raise
_on_nan_inf
(
self
,
value
):
A
=
pt
.
matrix
(
"A"
)
A
=
pt
.
matrix
(
"A"
)
b
=
pt
.
matrix
(
"b"
)
b
=
pt
.
matrix
(
"b"
)
...
@@ -335,11 +270,8 @@ class TestSolves:
...
@@ -335,11 +270,8 @@ class TestSolves:
A_tri
=
np
.
linalg
.
cholesky
(
A_sym
)
.
astype
(
floatX
)
A_tri
=
np
.
linalg
.
cholesky
(
A_sym
)
.
astype
(
floatX
)
b
=
np
.
full
((
5
,
1
),
value
)
.
astype
(
floatX
)
b
=
np
.
full
((
5
,
1
),
value
)
.
astype
(
floatX
)
with
pytest
.
raises
(
# Not checking everything is nan, because, with inf, LAPACK returns a mix of inf/nan, but does not set info != 0
np
.
linalg
.
LinAlgError
,
assert
not
np
.
isfinite
(
f
(
A_tri
,
b
))
.
any
()
match
=
re
.
escape
(
"Non-numeric values"
),
):
f
(
A_tri
,
b
)
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
],
ids
=
lambda
x
:
f
"lower = {x}"
)
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
],
ids
=
lambda
x
:
f
"lower = {x}"
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -567,10 +499,13 @@ class TestDecompositions:
...
@@ -567,10 +499,13 @@ class TestDecompositions:
x
=
pt
.
tensor
(
dtype
=
floatX
,
shape
=
(
3
,
3
))
x
=
pt
.
tensor
(
dtype
=
floatX
,
shape
=
(
3
,
3
))
x
=
x
.
T
.
dot
(
x
)
x
=
x
.
T
.
dot
(
x
)
g
=
pt
.
linalg
.
cholesky
(
x
,
check_finite
=
True
)
with
pytest
.
warns
(
FutureWarning
):
g
=
pt
.
linalg
.
cholesky
(
x
,
check_finite
=
True
,
on_error
=
"raise"
)
f
=
pytensor
.
function
([
x
],
g
,
mode
=
"NUMBA"
)
f
=
pytensor
.
function
([
x
],
g
,
mode
=
"NUMBA"
)
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
,
match
=
r"Non-numeric values"
):
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
,
match
=
r"Matrix is not positive definite"
):
f
(
test_value
)
f
(
test_value
)
@pytest.mark.parametrize
(
"on_error"
,
[
"nan"
,
"raise"
])
@pytest.mark.parametrize
(
"on_error"
,
[
"nan"
,
"raise"
])
...
@@ -578,13 +513,17 @@ class TestDecompositions:
...
@@ -578,13 +513,17 @@ class TestDecompositions:
test_value
=
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
floatX
)
test_value
=
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
floatX
)
x
=
pt
.
tensor
(
dtype
=
floatX
,
shape
=
(
3
,
3
))
x
=
pt
.
tensor
(
dtype
=
floatX
,
shape
=
(
3
,
3
))
if
on_error
==
"raise"
:
with
pytest
.
warns
(
FutureWarning
):
g
=
pt
.
linalg
.
cholesky
(
x
,
on_error
=
on_error
)
else
:
g
=
pt
.
linalg
.
cholesky
(
x
,
on_error
=
on_error
)
g
=
pt
.
linalg
.
cholesky
(
x
,
on_error
=
on_error
)
f
=
pytensor
.
function
([
x
],
g
,
mode
=
"NUMBA"
)
f
=
pytensor
.
function
([
x
],
g
,
mode
=
"NUMBA"
)
if
on_error
==
"raise"
:
if
on_error
==
"raise"
:
with
pytest
.
raises
(
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
,
np
.
linalg
.
LinAlgError
,
match
=
r"
Input to cholesky
is not positive definite"
,
match
=
r"
Matrix
is not positive definite"
,
):
):
f
(
test_value
)
f
(
test_value
)
else
:
else
:
...
...
tests/tensor/linalg/test_rewriting.py
浏览文件 @
672a4829
...
@@ -213,47 +213,3 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
...
@@ -213,47 +213,3 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
resx1
=
fn_opt
(
A_test
,
x0_test
)
resx1
=
fn_opt
(
A_test
,
x0_test
)
rtol
=
1e-7
if
config
.
floatX
==
"float64"
else
1e-4
rtol
=
1e-7
if
config
.
floatX
==
"float64"
else
1e-4
np
.
testing
.
assert_allclose
(
resx0
,
resx1
,
rtol
=
rtol
)
np
.
testing
.
assert_allclose
(
resx0
,
resx1
,
rtol
=
rtol
)
@pytest.mark.parametrize
(
"assume_a, counter"
,
(
(
"gen"
,
LUOpCounter
),
(
"pos"
,
CholeskyOpCounter
),
),
)
def
test_decomposition_reused_preserves_check_finite
(
assume_a
,
counter
):
# Check that the LU decomposition rewrite preserves the check_finite flag
rewrite_name
=
reuse_decomposition_multiple_solves
.
__name__
A
=
tensor
(
"A"
,
shape
=
(
2
,
2
))
b1
=
tensor
(
"b1"
,
shape
=
(
2
,))
b2
=
tensor
(
"b2"
,
shape
=
(
2
,))
x1
=
solve
(
A
,
b1
,
assume_a
=
assume_a
,
check_finite
=
True
)
x2
=
solve
(
A
,
b2
,
assume_a
=
assume_a
,
check_finite
=
False
)
fn_opt
=
function
(
[
A
,
b1
,
b2
],
[
x1
,
x2
],
mode
=
get_default_mode
()
.
including
(
rewrite_name
)
)
opt_nodes
=
fn_opt
.
maker
.
fgraph
.
apply_nodes
assert
counter
.
count_vanilla_solve_nodes
(
opt_nodes
)
==
0
assert
counter
.
count_decomp_nodes
(
opt_nodes
)
==
1
assert
counter
.
count_solve_nodes
(
opt_nodes
)
==
2
# We should get an error if A or b1 is non finite
A_valid
=
np
.
array
([[
1
,
0
],
[
0
,
1
]],
dtype
=
A
.
type
.
dtype
)
b1_valid
=
np
.
array
([
1
,
1
],
dtype
=
b1
.
type
.
dtype
)
b2_valid
=
np
.
array
([
1
,
1
],
dtype
=
b2
.
type
.
dtype
)
assert
fn_opt
(
A_valid
,
b1_valid
,
b2_valid
)
# Fine
assert
fn_opt
(
A_valid
,
b1_valid
,
b2_valid
*
np
.
nan
)
# Should not raise (also fine on most LAPACK implementations?)
err_msg
=
(
"(array must not contain infs or NaNs"
r"|Non-numeric values \(nan or inf\))"
)
with
pytest
.
raises
((
ValueError
,
np
.
linalg
.
LinAlgError
),
match
=
err_msg
):
assert
fn_opt
(
A_valid
,
b1_valid
*
np
.
nan
,
b2_valid
)
with
pytest
.
raises
((
ValueError
,
np
.
linalg
.
LinAlgError
),
match
=
err_msg
):
assert
fn_opt
(
A_valid
*
np
.
nan
,
b1_valid
,
b2_valid
)
tests/tensor/test_slinalg.py
浏览文件 @
672a4829
...
@@ -74,9 +74,6 @@ def test_cholesky():
...
@@ -74,9 +74,6 @@ def test_cholesky():
chol
=
Cholesky
(
lower
=
False
)(
x
)
chol
=
Cholesky
(
lower
=
False
)(
x
)
ch_f
=
function
([
x
],
chol
)
ch_f
=
function
([
x
],
chol
)
check_upper_triangular
(
pd
,
ch_f
)
check_upper_triangular
(
pd
,
ch_f
)
chol
=
Cholesky
(
lower
=
False
,
on_error
=
"nan"
)(
x
)
ch_f
=
function
([
x
],
chol
)
check_upper_triangular
(
pd
,
ch_f
)
def
test_cholesky_performance
(
benchmark
):
def
test_cholesky_performance
(
benchmark
):
...
@@ -102,12 +99,15 @@ def test_cholesky_empty():
...
@@ -102,12 +99,15 @@ def test_cholesky_empty():
def
test_cholesky_indef
():
def
test_cholesky_indef
():
x
=
matrix
()
x
=
matrix
()
mat
=
np
.
array
([[
1
,
0.2
],
[
0.2
,
-
2
]])
.
astype
(
config
.
floatX
)
mat
=
np
.
array
([[
1
,
0.2
],
[
0.2
,
-
2
]])
.
astype
(
config
.
floatX
)
cholesky
=
Cholesky
(
lower
=
True
,
on_error
=
"raise"
)
chol_f
=
function
([
x
],
cholesky
(
x
))
with
pytest
.
warns
(
FutureWarning
):
out
=
cholesky
(
x
,
lower
=
True
,
on_error
=
"raise"
)
chol_f
=
function
([
x
],
out
)
with
pytest
.
raises
(
scipy
.
linalg
.
LinAlgError
):
with
pytest
.
raises
(
scipy
.
linalg
.
LinAlgError
):
chol_f
(
mat
)
chol_f
(
mat
)
cholesky
=
Cholesky
(
lower
=
True
,
on_error
=
"nan"
)
chol_f
=
function
([
x
],
cholesky
(
x
))
out
=
cholesky
(
x
,
lower
=
True
,
on_error
=
"nan"
)
chol_f
=
function
([
x
],
out
)
assert
np
.
all
(
np
.
isnan
(
chol_f
(
mat
)))
assert
np
.
all
(
np
.
isnan
(
chol_f
(
mat
)))
...
@@ -143,12 +143,16 @@ def test_cholesky_grad():
...
@@ -143,12 +143,16 @@ def test_cholesky_grad():
def
test_cholesky_grad_indef
():
def
test_cholesky_grad_indef
():
x
=
matrix
()
x
=
matrix
()
mat
=
np
.
array
([[
1
,
0.2
],
[
0.2
,
-
2
]])
.
astype
(
config
.
floatX
)
mat
=
np
.
array
([[
1
,
0.2
],
[
0.2
,
-
2
]])
.
astype
(
config
.
floatX
)
cholesky
=
Cholesky
(
lower
=
True
,
on_error
=
"raise"
)
chol_f
=
function
([
x
],
grad
(
cholesky
(
x
)
.
sum
(),
[
x
]))
with
pytest
.
warns
(
FutureWarning
):
with
pytest
.
raises
(
scipy
.
linalg
.
LinAlgError
):
out
=
cholesky
(
x
,
lower
=
True
,
on_error
=
"raise"
)
chol_f
(
mat
)
chol_f
=
function
([
x
],
grad
(
out
.
sum
(),
[
x
]),
mode
=
"FAST_RUN"
)
cholesky
=
Cholesky
(
lower
=
True
,
on_error
=
"nan"
)
chol_f
=
function
([
x
],
grad
(
cholesky
(
x
)
.
sum
(),
[
x
]))
# original cholesky doesn't show up in the grad (if mode="FAST_RUN"), so it does not raise
assert
np
.
all
(
np
.
isnan
(
chol_f
(
mat
)))
out
=
cholesky
(
x
,
lower
=
True
,
on_error
=
"nan"
)
chol_f
=
function
([
x
],
grad
(
out
.
sum
(),
[
x
]))
assert
np
.
all
(
np
.
isnan
(
chol_f
(
mat
)))
assert
np
.
all
(
np
.
isnan
(
chol_f
(
mat
)))
...
@@ -237,7 +241,7 @@ class TestSolveBase:
...
@@ -237,7 +241,7 @@ class TestSolveBase:
y
=
self
.
SolveTest
(
b_ndim
=
2
)(
A
,
b
)
y
=
self
.
SolveTest
(
b_ndim
=
2
)(
A
,
b
)
assert
(
assert
(
y
.
__repr__
()
y
.
__repr__
()
==
"SolveTest{lower=False,
check_finite=True,
b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
==
"SolveTest{lower=False, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
)
)
...
@@ -549,7 +553,7 @@ class TestCholeskySolve(utt.InferShapeTester):
...
@@ -549,7 +553,7 @@ class TestCholeskySolve(utt.InferShapeTester):
def
test_repr
(
self
):
def
test_repr
(
self
):
assert
(
assert
(
repr
(
CholeskySolve
(
lower
=
True
,
b_ndim
=
1
))
repr
(
CholeskySolve
(
lower
=
True
,
b_ndim
=
1
))
==
"CholeskySolve(lower=True,
check_finite=True,
b_ndim=1,overwrite_b=False)"
==
"CholeskySolve(lower=True,b_ndim=1,overwrite_b=False)"
)
)
def
test_infer_shape
(
self
):
def
test_infer_shape
(
self
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论