Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
bbe663d9
提交
bbe663d9
authored
2月 11, 2025
作者:
jessegrabowski
提交者:
Ricardo Vieira
2月 17, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement numba dispatch for all `linalg.solve` modes
上级
8e5e8a40
显示空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
1749 行增加
和
350 行删除
+1749
-350
_LAPACK.py
pytensor/link/numba/dispatch/_LAPACK.py
+392
-0
basic.py
pytensor/link/numba/dispatch/basic.py
+1
-1
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+947
-195
slinalg.py
pytensor/tensor/slinalg.py
+21
-20
test_nlinalg.py
tests/link/numba/test_nlinalg.py
+1
-46
test_slinalg.py
tests/link/numba/test_slinalg.py
+327
-42
test_slinalg.py
tests/tensor/test_slinalg.py
+60
-46
没有找到文件。
pytensor/link/numba/dispatch/_LAPACK.py
0 → 100644
浏览文件 @
bbe663d9
import
ctypes
import
numpy
as
np
from
numba.core
import
cgutils
,
types
from
numba.core.extending
import
get_cython_function_address
,
intrinsic
from
numba.np.linalg
import
ensure_lapack
,
get_blas_kind
_PTR
=
ctypes
.
POINTER
_dbl
=
ctypes
.
c_double
_float
=
ctypes
.
c_float
_char
=
ctypes
.
c_char
_int
=
ctypes
.
c_int
_ptr_float
=
_PTR
(
_float
)
_ptr_dbl
=
_PTR
(
_dbl
)
_ptr_char
=
_PTR
(
_char
)
_ptr_int
=
_PTR
(
_int
)
def
_get_lapack_ptr_and_ptr_type
(
dtype
,
name
):
d
=
get_blas_kind
(
dtype
)
func_name
=
f
"{d}{name}"
float_pointer
=
_get_float_pointer_for_dtype
(
d
)
lapack_ptr
=
get_cython_function_address
(
"scipy.linalg.cython_lapack"
,
func_name
)
return
lapack_ptr
,
float_pointer
def
_get_underlying_float
(
dtype
):
s_dtype
=
str
(
dtype
)
out_type
=
s_dtype
if
s_dtype
==
"complex64"
:
out_type
=
"float32"
elif
s_dtype
==
"complex128"
:
out_type
=
"float64"
return
np
.
dtype
(
out_type
)
def
_get_float_pointer_for_dtype
(
blas_dtype
):
if
blas_dtype
in
[
"s"
,
"c"
]:
return
_ptr_float
elif
blas_dtype
in
[
"d"
,
"z"
]:
return
_ptr_dbl
def
_get_output_ctype
(
dtype
):
s_dtype
=
str
(
dtype
)
if
s_dtype
in
[
"float32"
,
"complex64"
]:
return
_float
elif
s_dtype
in
[
"float64"
,
"complex128"
]:
return
_dbl
@intrinsic
def
sptr_to_val
(
typingctx
,
data
):
def
impl
(
context
,
builder
,
signature
,
args
):
val
=
builder
.
load
(
args
[
0
])
return
val
sig
=
types
.
float32
(
types
.
CPointer
(
types
.
float32
))
return
sig
,
impl
@intrinsic
def
dptr_to_val
(
typingctx
,
data
):
def
impl
(
context
,
builder
,
signature
,
args
):
val
=
builder
.
load
(
args
[
0
])
return
val
sig
=
types
.
float64
(
types
.
CPointer
(
types
.
float64
))
return
sig
,
impl
@intrinsic
def
int_ptr_to_val
(
typingctx
,
data
):
def
impl
(
context
,
builder
,
signature
,
args
):
val
=
builder
.
load
(
args
[
0
])
return
val
sig
=
types
.
int32
(
types
.
CPointer
(
types
.
int32
))
return
sig
,
impl
@intrinsic
def
val_to_int_ptr
(
typingctx
,
data
):
def
impl
(
context
,
builder
,
signature
,
args
):
ptr
=
cgutils
.
alloca_once_value
(
builder
,
args
[
0
])
return
ptr
sig
=
types
.
CPointer
(
types
.
int32
)(
types
.
int32
)
return
sig
,
impl
@intrinsic
def
val_to_sptr
(
typingctx
,
data
):
def
impl
(
context
,
builder
,
signature
,
args
):
ptr
=
cgutils
.
alloca_once_value
(
builder
,
args
[
0
])
return
ptr
sig
=
types
.
CPointer
(
types
.
float32
)(
types
.
float32
)
return
sig
,
impl
@intrinsic
def
val_to_zptr
(
typingctx
,
data
):
def
impl
(
context
,
builder
,
signature
,
args
):
ptr
=
cgutils
.
alloca_once_value
(
builder
,
args
[
0
])
return
ptr
sig
=
types
.
CPointer
(
types
.
complex128
)(
types
.
complex128
)
return
sig
,
impl
@intrinsic
def
val_to_dptr
(
typingctx
,
data
):
def
impl
(
context
,
builder
,
signature
,
args
):
ptr
=
cgutils
.
alloca_once_value
(
builder
,
args
[
0
])
return
ptr
sig
=
types
.
CPointer
(
types
.
float64
)(
types
.
float64
)
return
sig
,
impl
class
_LAPACK
:
"""
Functions to return type signatures for wrapped LAPACK functions.
Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74
"""
def
__init__
(
self
):
ensure_lapack
()
@classmethod
def
numba_xtrtrs
(
cls
,
dtype
):
"""
Solve a triangular system of equations of the form A @ X = B or A.T @ X = B.
Called by scipy.linalg.solve_triangular
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"trtrs"
)
functype
=
ctypes
.
CFUNCTYPE
(
None
,
_ptr_int
,
# UPLO
_ptr_int
,
# TRANS
_ptr_int
,
# DIAG
_ptr_int
,
# N
_ptr_int
,
# NRHS
float_pointer
,
# A
_ptr_int
,
# LDA
float_pointer
,
# B
_ptr_int
,
# LDB
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xpotrf
(
cls
,
dtype
):
"""
Compute the Cholesky factorization of a real symmetric positive definite matrix.
Called by scipy.linalg.cholesky
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"potrf"
)
functype
=
ctypes
.
CFUNCTYPE
(
None
,
_ptr_int
,
# UPLO,
_ptr_int
,
# N
float_pointer
,
# A
_ptr_int
,
# LDA
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xpotrs
(
cls
,
dtype
):
"""
Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky
factorization computed by numba_potrf.
Called by scipy.linalg.cho_solve
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"potrs"
)
functype
=
ctypes
.
CFUNCTYPE
(
None
,
_ptr_int
,
# UPLO
_ptr_int
,
# N
_ptr_int
,
# NRHS
float_pointer
,
# A
_ptr_int
,
# LDA
float_pointer
,
# B
_ptr_int
,
# LDB
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xlange
(
cls
,
dtype
):
"""
Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of
a general M-by-N matrix A.
Called by scipy.linalg.solve
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"lange"
)
output_ctype
=
_get_output_ctype
(
dtype
)
functype
=
ctypes
.
CFUNCTYPE
(
output_ctype
,
# Output
_ptr_int
,
# NORM
_ptr_int
,
# M
_ptr_int
,
# N
float_pointer
,
# A
_ptr_int
,
# LDA
float_pointer
,
# WORK
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xlamch
(
cls
,
dtype
):
"""
Determine machine precision for floating point arithmetic.
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"lamch"
)
output_dtype
=
_get_output_ctype
(
dtype
)
functype
=
ctypes
.
CFUNCTYPE
(
output_dtype
,
# Output
_ptr_int
,
# CMACH
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xgecon
(
cls
,
dtype
):
"""
Estimates the condition number of a matrix A, using the LU factorization computed by numba_getrf.
Called by scipy.linalg.solve when assume_a == "gen"
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"gecon"
)
functype
=
ctypes
.
CFUNCTYPE
(
None
,
_ptr_int
,
# NORM
_ptr_int
,
# N
float_pointer
,
# A
_ptr_int
,
# LDA
float_pointer
,
# ANORM
float_pointer
,
# RCOND
float_pointer
,
# WORK
_ptr_int
,
# IWORK
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xgetrf
(
cls
,
dtype
):
"""
Compute partial pivoting LU factorization of a general M-by-N matrix A using row interchanges.
Called by scipy.linalg.lu_factor
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"getrf"
)
functype
=
ctypes
.
CFUNCTYPE
(
None
,
_ptr_int
,
# M
_ptr_int
,
# N
float_pointer
,
# A
_ptr_int
,
# LDA
_ptr_int
,
# IPIV
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xgetrs
(
cls
,
dtype
):
"""
Solve a system of linear equations A @ X = B or A.T @ X = B with a general N-by-N matrix A using the LU
factorization computed by GETRF.
Called by scipy.linalg.lu_solve
"""
...
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"getrs"
)
functype
=
ctypes
.
CFUNCTYPE
(
None
,
_ptr_int
,
# TRANS
_ptr_int
,
# N
_ptr_int
,
# NRHS
float_pointer
,
# A
_ptr_int
,
# LDA
_ptr_int
,
# IPIV
float_pointer
,
# B
_ptr_int
,
# LDB
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xsysv
(
cls
,
dtype
):
"""
Solve a system of linear equations A @ X = B with a symmetric matrix A using the diagonal pivoting method,
factorizing A into LDL^T or UDU^T form, depending on the value of UPLO
Called by scipy.linalg.solve when assume_a == "sym"
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"sysv"
)
functype
=
ctypes
.
CFUNCTYPE
(
None
,
_ptr_int
,
# UPLO
_ptr_int
,
# N
_ptr_int
,
# NRHS
float_pointer
,
# A
_ptr_int
,
# LDA
_ptr_int
,
# IPIV
float_pointer
,
# B
_ptr_int
,
# LDB
float_pointer
,
# WORK
_ptr_int
,
# LWORK
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xsycon
(
cls
,
dtype
):
"""
Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization
computed by xSYTRF.
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"sycon"
)
functype
=
ctypes
.
CFUNCTYPE
(
None
,
_ptr_int
,
# UPLO
_ptr_int
,
# N
float_pointer
,
# A
_ptr_int
,
# LDA
_ptr_int
,
# IPIV
float_pointer
,
# ANORM
float_pointer
,
# RCOND
float_pointer
,
# WORK
_ptr_int
,
# IWORK
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xpocon
(
cls
,
dtype
):
"""
Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization
computed by potrf.
Called by scipy.linalg.solve when assume_a == "pos"
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"pocon"
)
functype
=
ctypes
.
CFUNCTYPE
(
None
,
_ptr_int
,
# UPLO
_ptr_int
,
# N
float_pointer
,
# A
_ptr_int
,
# LDA
float_pointer
,
# ANORM
float_pointer
,
# RCOND
float_pointer
,
# WORK
_ptr_int
,
# IWORK
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xposv
(
cls
,
dtype
):
"""
Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky
factorization computed by potrf.
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"posv"
)
functype
=
ctypes
.
CFUNCTYPE
(
None
,
_ptr_int
,
# UPLO
_ptr_int
,
# N
_ptr_int
,
# NRHS
float_pointer
,
# A
_ptr_int
,
# LDA
float_pointer
,
# B
_ptr_int
,
# LDB
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
pytensor/link/numba/dispatch/basic.py
浏览文件 @
bbe663d9
...
@@ -367,7 +367,7 @@ def numba_typify(data, dtype=None, **kwargs):
...
@@ -367,7 +367,7 @@ def numba_typify(data, dtype=None, **kwargs):
def
generate_fallback_impl
(
op
,
node
=
None
,
storage_map
=
None
,
**
kwargs
):
def
generate_fallback_impl
(
op
,
node
=
None
,
storage_map
=
None
,
**
kwargs
):
"""Create a Numba compatible function from a
n Aesara
`Op`."""
"""Create a Numba compatible function from a
Pytensor
`Op`."""
warnings
.
warn
(
warnings
.
warn
(
f
"Numba will use object mode to run {op}'s perform method"
,
f
"Numba will use object mode to run {op}'s perform method"
,
...
...
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
bbe663d9
import
ctypes
from
collections.abc
import
Callable
import
numba
import
numba
import
numpy
as
np
import
numpy
as
np
from
numba.core
import
cgutils
,
types
from
numba.core
import
types
from
numba.extending
import
get_cython_function_address
,
intrinsic
,
overload
from
numba.extending
import
overload
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
,
get_blas_kind
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
numpy.linalg
import
LinAlgError
from
scipy
import
linalg
from
scipy
import
linalg
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
from
pytensor.tensor.slinalg
import
BlockDiagonal
,
Cholesky
,
SolveTriangular
from
pytensor.tensor.slinalg
import
(
BlockDiagonal
,
Cholesky
,
CholeskySolve
,
Solve
,
SolveTriangular
,
)
_PTR
=
ctypes
.
POINTER
@numba_basic.numba_njit
(
inline
=
"always"
)
def
_solve_check
(
n
,
info
,
lamch
=
False
,
rcond
=
None
):
_dbl
=
ctypes
.
c_double
"""
_float
=
ctypes
.
c_float
Check arguments during the different steps of the solution phase
_char
=
ctypes
.
c_char
Adapted from https://github.com/scipy/scipy/blob/7f7f04caa4a55306a9c6613c89eef91fedbd72d4/scipy/linalg/_basic.py#L38
_int
=
ctypes
.
c_int
"""
if
info
<
0
:
_ptr_float
=
_PTR
(
_float
)
# TODO: figure out how to do an fstring here
_ptr_dbl
=
_PTR
(
_dbl
)
msg
=
"LAPACK reported an illegal value in input"
_ptr_char
=
_PTR
(
_char
)
raise
ValueError
(
msg
)
_ptr_int
=
_PTR
(
_int
)
elif
0
<
info
:
raise
LinAlgError
(
"Matrix is singular."
)
@numba.core.extending.register_jitable
if
lamch
:
def
_check_finite_matrix
(
a
,
func_name
):
E
=
_xlamch
(
"E"
)
for
v
in
np
.
nditer
(
a
)
:
if
rcond
<
E
:
if
not
np
.
isfinite
(
v
.
item
()):
# TODO: This should be a warning, but we can't raise warnings in numba mode
raise
np
.
linalg
.
LinAlgError
(
print
(
# noqa: T201
"
Non-numeric values (nan or inf) in input to "
+
func_name
"
Ill-conditioned matrix, rcond="
,
rcond
,
", result may not be accurate."
)
)
@intrinsic
def
val_to_dptr
(
typingctx
,
data
):
def
impl
(
context
,
builder
,
signature
,
args
):
ptr
=
cgutils
.
alloca_once_value
(
builder
,
args
[
0
])
return
ptr
sig
=
types
.
CPointer
(
types
.
float64
)(
types
.
float64
)
return
sig
,
impl
@intrinsic
def
val_to_zptr
(
typingctx
,
data
):
def
impl
(
context
,
builder
,
signature
,
args
):
ptr
=
cgutils
.
alloca_once_value
(
builder
,
args
[
0
])
return
ptr
sig
=
types
.
CPointer
(
types
.
complex128
)(
types
.
complex128
)
return
sig
,
impl
@intrinsic
def
val_to_sptr
(
typingctx
,
data
):
def
impl
(
context
,
builder
,
signature
,
args
):
ptr
=
cgutils
.
alloca_once_value
(
builder
,
args
[
0
])
return
ptr
sig
=
types
.
CPointer
(
types
.
float32
)(
types
.
float32
)
return
sig
,
impl
@intrinsic
def
val_to_int_ptr
(
typingctx
,
data
):
def
impl
(
context
,
builder
,
signature
,
args
):
ptr
=
cgutils
.
alloca_once_value
(
builder
,
args
[
0
])
return
ptr
sig
=
types
.
CPointer
(
types
.
int32
)(
types
.
int32
)
return
sig
,
impl
@intrinsic
def
int_ptr_to_val
(
typingctx
,
data
):
def
impl
(
context
,
builder
,
signature
,
args
):
val
=
builder
.
load
(
args
[
0
])
return
val
sig
=
types
.
int32
(
types
.
CPointer
(
types
.
int32
))
return
sig
,
impl
@intrinsic
def
dptr_to_val
(
typingctx
,
data
):
def
impl
(
context
,
builder
,
signature
,
args
):
val
=
builder
.
load
(
args
[
0
])
return
val
sig
=
types
.
float64
(
types
.
CPointer
(
types
.
float64
))
return
sig
,
impl
@intrinsic
def
sptr_to_val
(
typingctx
,
data
):
def
impl
(
context
,
builder
,
signature
,
args
):
val
=
builder
.
load
(
args
[
0
])
return
val
sig
=
types
.
float32
(
types
.
CPointer
(
types
.
float32
))
return
sig
,
impl
def
_get_float_pointer_for_dtype
(
blas_dtype
):
if
blas_dtype
in
[
"s"
,
"c"
]:
return
_ptr_float
elif
blas_dtype
in
[
"d"
,
"z"
]:
return
_ptr_dbl
def
_get_underlying_float
(
dtype
):
s_dtype
=
str
(
dtype
)
out_type
=
s_dtype
if
s_dtype
==
"complex64"
:
out_type
=
"float32"
elif
s_dtype
==
"complex128"
:
out_type
=
"float64"
return
np
.
dtype
(
out_type
)
def
_get_lapack_ptr_and_ptr_type
(
dtype
,
name
):
d
=
get_blas_kind
(
dtype
)
func_name
=
f
"{d}{name}"
float_pointer
=
_get_float_pointer_for_dtype
(
d
)
lapack_ptr
=
get_cython_function_address
(
"scipy.linalg.cython_lapack"
,
func_name
)
return
lapack_ptr
,
float_pointer
def
_check_scipy_linalg_matrix
(
a
,
func_name
):
def
_check_scipy_linalg_matrix
(
a
,
func_name
):
"""
"""
Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
...
@@ -152,64 +68,50 @@ def _check_scipy_linalg_matrix(a, func_name):
...
@@ -152,64 +68,50 @@ def _check_scipy_linalg_matrix(a, func_name):
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
raise
numba
.
TypingError
(
msg
,
highlighting
=
False
)
class
_LAPACK
:
def
_solve_triangular
(
A
,
B
,
trans
=
0
,
lower
=
False
,
unit_diagonal
=
False
,
b_ndim
=
1
,
overwrite_b
=
False
):
"""
"""
Functions to return type signatures for wrapped LAPACK functions
.
Thin wrapper around scipy.linalg.solve_triangular
.
Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74
This function is overloaded instead of the original scipy function to avoid unexpected side-effects to users who
"""
import pytensor.
def
__init__
(
self
):
The signature must be the same as solve_triangular_impl, so b_ndim is included, although this argument is not
ensure_lapack
()
used by scipy.linalg.solve_triangular.
@classmethod
def
numba_xtrtrs
(
cls
,
dtype
):
"""
"""
Called by scipy.linalg.solve_triangular
return
linalg
.
solve_triangular
(
"""
A
,
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"trtrs"
)
B
,
trans
=
trans
,
functype
=
ctypes
.
CFUNCTYPE
(
lower
=
lower
,
None
,
unit_diagonal
=
unit_diagonal
,
_ptr_int
,
# UPLO
overwrite_b
=
overwrite_b
,
_ptr_int
,
# TRANS
_ptr_int
,
# DIAG
_ptr_int
,
# N
_ptr_int
,
# NRHS
float_pointer
,
# A
_ptr_int
,
# LDA
float_pointer
,
# B
_ptr_int
,
# LDB
_ptr_int
,
# INFO
)
)
return
functype
(
lapack_ptr
)
@classmethod
@numba_basic.numba_njit
(
inline
=
"always"
)
def
numba_xpotrf
(
cls
,
dtype
):
def
_trans_char_to_int
(
trans
):
"""
if
trans
not
in
[
0
,
1
,
2
]:
Called by scipy.linalg.cholesky
raise
ValueError
(
'Parameter "trans" should be one of 0, 1, 2'
)
"""
if
trans
==
0
:
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"potrf"
)
return
ord
(
"N"
)
functype
=
ctypes
.
CFUNCTYPE
(
elif
trans
==
1
:
None
,
return
ord
(
"T"
)
_ptr_int
,
# UPLO,
else
:
_ptr_int
,
# N
return
ord
(
"C"
)
float_pointer
,
# A
_ptr_int
,
# LDA
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
def
_solve_triangular
(
A
,
B
,
trans
=
0
,
lower
=
False
,
unit_diagonal
=
False
):
@numba_basic.numba_njit
(
inline
=
"always"
)
return
linalg
.
solve_triangular
(
def
_solve_check_input_shapes
(
A
,
B
):
A
,
B
,
trans
=
trans
,
lower
=
lower
,
unit_diagonal
=
unit_diagonal
if
A
.
shape
[
0
]
!=
B
.
shape
[
0
]:
)
raise
linalg
.
LinAlgError
(
"Dimensions of A and B do not conform"
)
if
A
.
shape
[
-
2
]
!=
A
.
shape
[
-
1
]:
raise
linalg
.
LinAlgError
(
"Last 2 dimensions of A must be square"
)
@overload
(
_solve_triangular
)
@overload
(
_solve_triangular
)
def
solve_triangular_impl
(
A
,
B
,
trans
=
0
,
lower
=
False
,
unit_diagonal
=
False
):
def
solve_triangular_impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
b_ndim
,
overwrite_b
):
ensure_lapack
()
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve_triangular"
)
_check_scipy_linalg_matrix
(
A
,
"solve_triangular"
)
...
@@ -218,37 +120,27 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
...
@@ -218,37 +120,27 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
w_type
=
_get_underlying_float
(
dtype
)
w_type
=
_get_underlying_float
(
dtype
)
numba_trtrs
=
_LAPACK
()
.
numba_xtrtrs
(
dtype
)
numba_trtrs
=
_LAPACK
()
.
numba_xtrtrs
(
dtype
)
def
impl
(
A
,
B
,
trans
=
0
,
lower
=
False
,
unit_diagonal
=
False
):
def
impl
(
A
,
B
,
trans
,
lower
,
unit_diagonal
,
b_ndim
,
overwrite_b
):
B_is_1d
=
B
.
ndim
==
1
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
if
A
.
shape
[
-
2
]
!=
_N
:
_solve_check_input_shapes
(
A
,
B
)
raise
linalg
.
LinAlgError
(
"Last 2 dimensions of A must be square"
)
if
A
.
shape
[
0
]
!=
B
.
shape
[
0
]:
B_is_1d
=
B
.
ndim
==
1
raise
linalg
.
LinAlgError
(
"Dimensions of A and B do not conform"
)
if
B_is_1d
:
if
not
overwrite_b
:
B_copy
=
np
.
asfortranarray
(
np
.
expand_dims
(
B
,
-
1
))
else
:
B_copy
=
_copy_to_fortran_order
(
B
)
B_copy
=
_copy_to_fortran_order
(
B
)
if
trans
not
in
[
0
,
1
,
2
]:
raise
ValueError
(
'Parameter "trans" should be one of N, C, T or 0, 1, 2'
)
if
trans
==
0
:
transval
=
ord
(
"N"
)
elif
trans
==
1
:
transval
=
ord
(
"T"
)
else
:
else
:
transval
=
ord
(
"C"
)
B_copy
=
B
if
B_is_1d
:
B_copy
=
np
.
expand_dims
(
B
,
-
1
)
B_NDIM
=
1
if
B_is_1d
else
int
(
B
.
shape
[
1
])
NRHS
=
1
if
B_is_1d
else
int
(
B_copy
.
shape
[
-
1
])
UPLO
=
val_to_int_ptr
(
ord
(
"L"
)
if
lower
else
ord
(
"U"
))
UPLO
=
val_to_int_ptr
(
ord
(
"L"
)
if
lower
else
ord
(
"U"
))
TRANS
=
val_to_int_ptr
(
transval
)
TRANS
=
val_to_int_ptr
(
_trans_char_to_int
(
trans
)
)
DIAG
=
val_to_int_ptr
(
ord
(
"U"
)
if
unit_diagonal
else
ord
(
"N"
))
DIAG
=
val_to_int_ptr
(
ord
(
"U"
)
if
unit_diagonal
else
ord
(
"N"
))
N
=
val_to_int_ptr
(
_N
)
N
=
val_to_int_ptr
(
_N
)
NRHS
=
val_to_int_ptr
(
B_NDIM
)
NRHS
=
val_to_int_ptr
(
NRHS
)
LDA
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
LDB
=
val_to_int_ptr
(
_N
)
LDB
=
val_to_int_ptr
(
_N
)
INFO
=
val_to_int_ptr
(
0
)
INFO
=
val_to_int_ptr
(
0
)
...
@@ -266,19 +158,24 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
...
@@ -266,19 +158,24 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
INFO
,
INFO
,
)
)
_solve_check
(
int_ptr_to_val
(
LDA
),
int_ptr_to_val
(
INFO
))
if
B_is_1d
:
if
B_is_1d
:
return
B_copy
[
...
,
0
],
int_ptr_to_val
(
INFO
)
return
B_copy
[
...
,
0
]
return
B_copy
,
int_ptr_to_val
(
INFO
)
return
B_copy
return
impl
return
impl
@numba_funcify.register
(
SolveTriangular
)
@numba_funcify.register
(
SolveTriangular
)
def
numba_funcify_SolveTriangular
(
op
,
node
,
**
kwargs
):
def
numba_funcify_SolveTriangular
(
op
,
node
,
**
kwargs
):
trans
=
op
.
trans
trans
=
bool
(
op
.
trans
)
lower
=
op
.
lower
lower
=
op
.
lower
unit_diagonal
=
op
.
unit_diagonal
unit_diagonal
=
op
.
unit_diagonal
check_finite
=
op
.
check_finite
check_finite
=
op
.
check_finite
overwrite_b
=
op
.
overwrite_b
b_ndim
=
op
.
b_ndim
dtype
=
node
.
inputs
[
0
]
.
dtype
dtype
=
node
.
inputs
[
0
]
.
dtype
if
str
(
dtype
)
.
startswith
(
"complex"
):
if
str
(
dtype
)
.
startswith
(
"complex"
):
...
@@ -298,11 +195,16 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
...
@@ -298,11 +195,16 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
"Non-numeric values (nan or inf) in input b to solve_triangular"
"Non-numeric values (nan or inf) in input b to solve_triangular"
)
)
res
,
info
=
_solve_triangular
(
a
,
b
,
trans
,
lower
,
unit_diagonal
)
res
=
_solve_triangular
(
if
info
!=
0
:
a
,
raise
np
.
linalg
.
LinAlgError
(
b
,
"Singular matrix in input A to solve_triangular"
trans
=
trans
,
lower
=
lower
,
unit_diagonal
=
unit_diagonal
,
overwrite_b
=
overwrite_b
,
b_ndim
=
b_ndim
,
)
)
return
res
return
res
return
solve_triangular
return
solve_triangular
...
@@ -429,3 +331,853 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
...
@@ -429,3 +331,853 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
return
out
return
out
return
block_diag
return
block_diag
def
_xlamch
(
kind
:
str
=
"E"
):
"""
Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs.
"""
pass
@overload
(
_xlamch
)
def
xlamch_impl
(
kind
:
str
=
"E"
)
->
Callable
[[
str
],
float
]:
"""
Compute the machine precision for a given floating point type.
"""
from
pytensor
import
config
ensure_lapack
()
w_type
=
_get_underlying_float
(
config
.
floatX
)
if
w_type
==
"float32"
:
dtype
=
types
.
float32
elif
w_type
==
"float64"
:
dtype
=
types
.
float64
else
:
raise
NotImplementedError
(
"Unsupported dtype"
)
numba_lamch
=
_LAPACK
()
.
numba_xlamch
(
dtype
)
def
impl
(
kind
:
str
=
"E"
)
->
float
:
KIND
=
val_to_int_ptr
(
ord
(
kind
))
return
numba_lamch
(
KIND
)
# type: ignore
return
impl
def
_xlange
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
)
->
float
:
"""
Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode.
"""
return
# type: ignore
@overload
(
_xlange
)
def
xlange_impl
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
)
->
Callable
[[
np
.
ndarray
,
str
],
float
]:
"""
xLANGE returns the value of the one norm, or the Frobenius norm, or the infinity norm, or the element of
largest absolute value of a matrix A.
"""
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"norm"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_lange
=
_LAPACK
()
.
numba_xlange
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
order
:
str
|
None
=
None
):
_M
,
_N
=
np
.
int32
(
A
.
shape
[
-
2
:])
# type: ignore
A_copy
=
_copy_to_fortran_order
(
A
)
M
=
val_to_int_ptr
(
_M
)
# type: ignore
N
=
val_to_int_ptr
(
_N
)
# type: ignore
LDA
=
val_to_int_ptr
(
_M
)
# type: ignore
NORM
=
(
val_to_int_ptr
(
ord
(
order
))
if
order
is
not
None
else
val_to_int_ptr
(
ord
(
"1"
))
)
WORK
=
np
.
empty
(
_M
,
dtype
=
dtype
)
# type: ignore
result
=
numba_lange
(
NORM
,
M
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
WORK
.
view
(
w_type
)
.
ctypes
)
return
result
return
impl
def
_xgecon
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify
graphs.
"""
return
# type: ignore
@overload
(
_xgecon
)
def
xgecon_impl
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
Callable
[[
np
.
ndarray
,
float
,
str
],
tuple
[
np
.
ndarray
,
int
]]:
"""
Compute the condition number of a matrix A.
"""
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"gecon"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_gecon
=
_LAPACK
()
.
numba_xgecon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
A_norm
:
float
,
norm
:
str
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
A_NORM
=
np
.
array
(
A_norm
,
dtype
=
dtype
)
NORM
=
val_to_int_ptr
(
ord
(
norm
))
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
4
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
1
)
numba_gecon
(
NORM
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
A_NORM
.
view
(
w_type
)
.
ctypes
,
RCOND
.
view
(
w_type
)
.
ctypes
,
WORK
.
view
(
w_type
)
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_getrf
(
A
,
overwrite_a
=
False
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
"""
Placeholder for LU factorization; used by linalg.solve.
# TODO: Implement an LU_factor Op, then dispatch to this function in numba mode.
"""
return
# type: ignore
@overload
(
_getrf
)
def
getrf_impl
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
)
->
Callable
[[
np
.
ndarray
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"getrf"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_getrf
=
_LAPACK
()
.
numba_xgetrf
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
_M
,
_N
=
np
.
int32
(
A
.
shape
[
-
2
:])
# type: ignore
if
not
overwrite_a
:
A_copy
=
_copy_to_fortran_order
(
A
)
else
:
A_copy
=
A
M
=
val_to_int_ptr
(
_M
)
# type: ignore
N
=
val_to_int_ptr
(
_N
)
# type: ignore
LDA
=
val_to_int_ptr
(
_M
)
# type: ignore
IPIV
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
# type: ignore
INFO
=
val_to_int_ptr
(
0
)
numba_getrf
(
M
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
INFO
)
return
A_copy
,
IPIV
,
int_ptr_to_val
(
INFO
)
return
impl
def
_getrs
(
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
int
,
overwrite_b
:
bool
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for solving a linear system with a matrix that has been LU-factored; used by linalg.solve.
# TODO: Implement an LU_solve Op, then dispatch to this function in numba mode.
"""
return
# type: ignore
@overload
(
_getrs
)
def
getrs_impl
(
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
int
,
overwrite_b
:
bool
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
int
,
bool
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
LU
,
"getrs"
)
_check_scipy_linalg_matrix
(
B
,
"getrs"
)
dtype
=
LU
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_getrs
=
_LAPACK
()
.
numba_xgetrs
(
dtype
)
def
impl
(
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
int
,
overwrite_b
:
bool
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
LU
.
shape
[
-
1
])
_solve_check_input_shapes
(
LU
,
B
)
B_is_1d
=
B
.
ndim
==
1
if
not
overwrite_b
:
B_copy
=
_copy_to_fortran_order
(
B
)
else
:
B_copy
=
B
if
B_is_1d
:
B_copy
=
np
.
expand_dims
(
B_copy
,
-
1
)
NRHS
=
1
if
B_is_1d
else
int
(
B_copy
.
shape
[
-
1
])
TRANS
=
val_to_int_ptr
(
_trans_char_to_int
(
trans
))
N
=
val_to_int_ptr
(
_N
)
NRHS
=
val_to_int_ptr
(
NRHS
)
LDA
=
val_to_int_ptr
(
_N
)
LDB
=
val_to_int_ptr
(
_N
)
IPIV
=
_copy_to_fortran_order
(
IPIV
)
INFO
=
val_to_int_ptr
(
0
)
numba_getrs
(
TRANS
,
N
,
NRHS
,
LU
.
view
(
w_type
)
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
INFO
,
)
if
B_is_1d
:
return
B_copy
[
...
,
0
],
int_ptr_to_val
(
INFO
)
return
B_copy
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_gen
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
):
"""Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects
for users who import pytensor."""
return
linalg
.
solve
(
A
,
B
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
,
assume_a
=
"gen"
,
transposed
=
transposed
,
)
@overload
(
_solve_gen
)
def
solve_gen_impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve"
)
_check_scipy_linalg_matrix
(
B
,
"solve"
)
def
impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
_solve_check_input_shapes
(
A
,
B
)
order
=
"I"
if
transposed
else
"1"
norm
=
_xlange
(
A
,
order
=
order
)
N
=
A
.
shape
[
1
]
LU
,
IPIV
,
INFO
=
_getrf
(
A
,
overwrite_a
=
overwrite_a
)
_solve_check
(
N
,
INFO
)
X
,
INFO
=
_getrs
(
LU
=
LU
,
B
=
B
,
IPIV
=
IPIV
,
trans
=
transposed
,
overwrite_b
=
overwrite_b
)
_solve_check
(
N
,
INFO
)
RCOND
,
INFO
=
_xgecon
(
LU
,
norm
,
"1"
)
_solve_check
(
N
,
INFO
,
True
,
RCOND
)
return
X
return
impl
def
_sysv
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
"""
Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve.
"""
return
# type: ignore
@overload
(
_sysv
)
def
sysv_impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
)
->
Callable
[
[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"sysv"
)
_check_scipy_linalg_matrix
(
B
,
"sysv"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_sysv
=
_LAPACK
()
.
numba_xsysv
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
):
_LDA
,
_N
=
np
.
int32
(
A
.
shape
[
-
2
:])
# type: ignore
_solve_check_input_shapes
(
A
,
B
)
if
not
overwrite_a
:
A_copy
=
_copy_to_fortran_order
(
A
)
else
:
A_copy
=
A
B_is_1d
=
B
.
ndim
==
1
if
not
overwrite_b
:
B_copy
=
_copy_to_fortran_order
(
B
)
else
:
B_copy
=
B
if
B_is_1d
:
B_copy
=
np
.
asfortranarray
(
np
.
expand_dims
(
B_copy
,
-
1
))
NRHS
=
1
if
B_is_1d
else
int
(
B
.
shape
[
-
1
])
UPLO
=
val_to_int_ptr
(
ord
(
"L"
)
if
lower
else
ord
(
"U"
))
N
=
val_to_int_ptr
(
_N
)
# type: ignore
NRHS
=
val_to_int_ptr
(
NRHS
)
LDA
=
val_to_int_ptr
(
_LDA
)
# type: ignore
IPIV
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
# type: ignore
LDB
=
val_to_int_ptr
(
_N
)
# type: ignore
WORK
=
np
.
empty
(
1
,
dtype
=
dtype
)
LWORK
=
val_to_int_ptr
(
-
1
)
INFO
=
val_to_int_ptr
(
0
)
# Workspace query
numba_sysv
(
UPLO
,
N
,
NRHS
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
WORK
.
view
(
w_type
)
.
ctypes
,
LWORK
,
INFO
,
)
WS_SIZE
=
np
.
int32
(
WORK
[
0
]
.
real
)
LWORK
=
val_to_int_ptr
(
WS_SIZE
)
WORK
=
np
.
empty
(
WS_SIZE
,
dtype
=
dtype
)
# Actual solve
numba_sysv
(
UPLO
,
N
,
NRHS
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
WORK
.
view
(
w_type
)
.
ctypes
,
LWORK
,
INFO
,
)
if
B_is_1d
:
return
B_copy
[
...
,
0
],
IPIV
,
int_ptr_to_val
(
INFO
)
return
B_copy
,
IPIV
,
int_ptr_to_val
(
INFO
)
return
impl
def
_sycon
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in
python mode.
"""
return
# type: ignore
@overload
(
_sycon
)
def
sycon_impl
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
float
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"sycon"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_sycon
=
_LAPACK
()
.
numba_xsycon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
ipiv
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
UPLO
=
val_to_int_ptr
(
ord
(
"L"
))
ANORM
=
np
.
array
(
anorm
,
dtype
=
dtype
)
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
2
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
0
)
numba_sycon
(
UPLO
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
ipiv
.
ctypes
,
ANORM
.
view
(
w_type
)
.
ctypes
,
RCOND
.
view
(
w_type
)
.
ctypes
,
WORK
.
view
(
w_type
)
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_symmetric
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
):
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
unexpected side-effects when users import pytensor."""
return
linalg
.
solve
(
A
,
B
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
,
assume_a
=
"sym"
,
transposed
=
transposed
,
)
@overload
(
_solve_symmetric
)
def
solve_symmetric_impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve"
)
_check_scipy_linalg_matrix
(
B
,
"solve"
)
def
impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
_solve_check_input_shapes
(
A
,
B
)
x
,
ipiv
,
info
=
_sysv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
)
_solve_check
(
A
.
shape
[
-
1
],
info
)
rcond
,
info
=
_sycon
(
A
,
ipiv
,
_xlange
(
A
,
order
=
"I"
))
_solve_check
(
A
.
shape
[
-
1
],
info
,
True
,
rcond
)
return
x
return
impl
def
_posv
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
"""
return
# type: ignore
@overload
(
_posv
)
def
posv_impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[
[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
tuple
[
np
.
ndarray
,
int
]
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve"
)
_check_scipy_linalg_matrix
(
B
,
"solve"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_posv
=
_LAPACK
()
.
numba_xposv
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
tuple
[
np
.
ndarray
,
int
]:
_solve_check_input_shapes
(
A
,
B
)
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
if
not
overwrite_a
:
A_copy
=
_copy_to_fortran_order
(
A
)
else
:
A_copy
=
A
B_is_1d
=
B
.
ndim
==
1
if
not
overwrite_b
:
B_copy
=
_copy_to_fortran_order
(
B
)
else
:
B_copy
=
B
if
B_is_1d
:
B_copy
=
np
.
expand_dims
(
B_copy
,
-
1
)
UPLO
=
val_to_int_ptr
(
ord
(
"L"
)
if
lower
else
ord
(
"U"
))
NRHS
=
1
if
B_is_1d
else
int
(
B
.
shape
[
-
1
])
N
=
val_to_int_ptr
(
_N
)
NRHS
=
val_to_int_ptr
(
NRHS
)
LDA
=
val_to_int_ptr
(
_N
)
LDB
=
val_to_int_ptr
(
_N
)
INFO
=
val_to_int_ptr
(
0
)
numba_posv
(
UPLO
,
N
,
NRHS
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
INFO
,
)
if
B_is_1d
:
return
B_copy
[
...
,
0
],
int_ptr_to_val
(
INFO
)
return
B_copy
,
int_ptr_to_val
(
INFO
)
return
impl
def
_pocon
(
A
:
np
.
ndarray
,
anorm
:
float
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by
linalg.solve when assume_a = "pos".
"""
return
# type: ignore
@overload
(
_pocon
)
def
pocon_impl
(
A
:
np
.
ndarray
,
anorm
:
float
)
->
Callable
[[
np
.
ndarray
,
float
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"pocon"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_pocon
=
_LAPACK
()
.
numba_xpocon
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
anorm
:
float
):
_N
=
np
.
int32
(
A
.
shape
[
-
1
])
A_copy
=
_copy_to_fortran_order
(
A
)
UPLO
=
val_to_int_ptr
(
ord
(
"L"
))
N
=
val_to_int_ptr
(
_N
)
LDA
=
val_to_int_ptr
(
_N
)
ANORM
=
np
.
array
(
anorm
,
dtype
=
dtype
)
RCOND
=
np
.
empty
(
1
,
dtype
=
dtype
)
WORK
=
np
.
empty
(
3
*
_N
,
dtype
=
dtype
)
IWORK
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
INFO
=
val_to_int_ptr
(
0
)
numba_pocon
(
UPLO
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
ANORM
.
view
(
w_type
)
.
ctypes
,
RCOND
.
view
(
w_type
)
.
ctypes
,
WORK
.
view
(
w_type
)
.
ctypes
,
IWORK
.
ctypes
,
INFO
,
)
return
RCOND
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_psd
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
):
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
avoid unexpected side-effects when users import pytensor."""
return
linalg
.
solve
(
A
,
B
,
lower
=
lower
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
,
transposed
=
transposed
,
assume_a
=
"pos"
,
)
@overload
(
_solve_psd
)
def
solve_psd_impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"solve"
)
_check_scipy_linalg_matrix
(
B
,
"solve"
)
def
impl
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
lower
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
transposed
:
bool
,
)
->
np
.
ndarray
:
_solve_check_input_shapes
(
A
,
B
)
x
,
info
=
_posv
(
A
,
B
,
lower
,
overwrite_a
,
overwrite_b
,
check_finite
,
transposed
)
_solve_check
(
A
.
shape
[
-
1
],
info
)
rcond
,
info
=
_pocon
(
x
,
_xlange
(
A
))
_solve_check
(
A
.
shape
[
-
1
],
info
=
info
,
lamch
=
True
,
rcond
=
rcond
)
return
x
return
impl
@numba_funcify.register
(
Solve
)
def
numba_funcify_Solve
(
op
,
node
,
**
kwargs
):
assume_a
=
op
.
assume_a
lower
=
op
.
lower
check_finite
=
op
.
check_finite
overwrite_a
=
op
.
overwrite_a
overwrite_b
=
op
.
overwrite_b
transposed
=
False
# TODO: Solve doesnt currently allow the transposed argument
dtype
=
node
.
inputs
[
0
]
.
dtype
if
str
(
dtype
)
.
startswith
(
"complex"
):
raise
NotImplementedError
(
"Complex inputs not currently supported by solve in Numba mode"
)
if
assume_a
==
"gen"
:
solve_fn
=
_solve_gen
elif
assume_a
==
"sym"
:
solve_fn
=
_solve_symmetric
elif
assume_a
==
"her"
:
raise
NotImplementedError
(
'Use assume_a = "sym" for symmetric real matrices. If you need compelx support, '
"please open an issue on github."
)
elif
assume_a
==
"pos"
:
solve_fn
=
_solve_psd
else
:
raise
NotImplementedError
(
f
"Assumption {assume_a} not supported in Numba mode"
)
@numba_basic.numba_njit
(
inline
=
"always"
)
def
solve
(
a
,
b
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input A to solve"
)
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
b
),
np
.
isnan
(
b
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input b to solve"
)
res
=
solve_fn
(
a
,
b
,
lower
,
overwrite_a
,
overwrite_b
,
check_finite
,
transposed
)
return
res
return
solve
def
_cho_solve
(
A_and_lower
,
B
,
overwrite_a
=
False
,
overwrite_b
=
False
,
check_finite
=
True
):
"""
Solve a positive-definite linear system using the Cholesky decomposition.
"""
A
,
lower
=
A_and_lower
return
linalg
.
cho_solve
((
A
,
lower
),
B
)
@overload
(
_cho_solve
)
def
cho_solve_impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
,
check_finite
=
True
):
ensure_lapack
()
_check_scipy_linalg_matrix
(
C
,
"cho_solve"
)
_check_scipy_linalg_matrix
(
B
,
"cho_solve"
)
dtype
=
C
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_potrs
=
_LAPACK
()
.
numba_xpotrs
(
dtype
)
def
impl
(
C
,
B
,
lower
=
False
,
overwrite_b
=
False
,
check_finite
=
True
):
_solve_check_input_shapes
(
C
,
B
)
_N
=
np
.
int32
(
C
.
shape
[
-
1
])
C_copy
=
_copy_to_fortran_order
(
C
)
B_is_1d
=
B
.
ndim
==
1
if
B_is_1d
:
B_copy
=
np
.
asfortranarray
(
np
.
expand_dims
(
B
,
-
1
))
else
:
B_copy
=
_copy_to_fortran_order
(
B
)
NRHS
=
1
if
B_is_1d
else
int
(
B
.
shape
[
-
1
])
UPLO
=
val_to_int_ptr
(
ord
(
"L"
)
if
lower
else
ord
(
"U"
))
N
=
val_to_int_ptr
(
_N
)
NRHS
=
val_to_int_ptr
(
NRHS
)
LDA
=
val_to_int_ptr
(
_N
)
LDB
=
val_to_int_ptr
(
_N
)
INFO
=
val_to_int_ptr
(
0
)
numba_potrs
(
UPLO
,
N
,
NRHS
,
C_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
INFO
,
)
if
B_is_1d
:
return
B_copy
[
...
,
0
],
int_ptr_to_val
(
INFO
)
return
B_copy
,
int_ptr_to_val
(
INFO
)
return
impl
@numba_funcify.register
(
CholeskySolve
)
def
numba_funcify_CholeskySolve
(
op
,
node
,
**
kwargs
):
lower
=
op
.
lower
overwrite_b
=
op
.
overwrite_b
check_finite
=
op
.
check_finite
dtype
=
node
.
inputs
[
0
]
.
dtype
if
str
(
dtype
)
.
startswith
(
"complex"
):
raise
NotImplementedError
(
"Complex inputs not currently supported by cho_solve in Numba mode"
)
@numba_basic.numba_njit
(
inline
=
"always"
)
def
cho_solve
(
c
,
b
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
c
),
np
.
isnan
(
c
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input A to cho_solve"
)
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
b
),
np
.
isnan
(
b
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) in input b to cho_solve"
)
res
,
info
=
_cho_solve
(
c
,
b
,
lower
=
lower
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
)
if
info
<
0
:
raise
np
.
linalg
.
LinAlgError
(
"Illegal values found in input to cho_solve"
)
elif
info
>
0
:
raise
np
.
linalg
.
LinAlgError
(
"Matrix is not positive definite in input to cho_solve"
)
return
res
return
cho_solve
pytensor/tensor/slinalg.py
浏览文件 @
bbe663d9
import
logging
import
logging
import
typing
import
warnings
import
warnings
from
collections.abc
import
Sequence
from
functools
import
reduce
from
functools
import
reduce
from
typing
import
Literal
,
cast
from
typing
import
Literal
,
cast
import
numpy
as
np
import
numpy
as
np
import
scipy.linalg
import
scipy.linalg
as
scipy_linalg
import
pytensor
import
pytensor
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
...
@@ -58,7 +58,7 @@ class Cholesky(Op):
...
@@ -58,7 +58,7 @@ class Cholesky(Op):
f
"Cholesky only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
f
"Cholesky only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
)
)
# Call scipy to find output dtype
# Call scipy to find output dtype
dtype
=
scipy
.
linalg
.
cholesky
(
np
.
eye
(
1
,
dtype
=
x
.
type
.
dtype
))
.
dtype
dtype
=
scipy
_
linalg
.
cholesky
(
np
.
eye
(
1
,
dtype
=
x
.
type
.
dtype
))
.
dtype
return
Apply
(
self
,
[
x
],
[
tensor
(
shape
=
x
.
type
.
shape
,
dtype
=
dtype
)])
return
Apply
(
self
,
[
x
],
[
tensor
(
shape
=
x
.
type
.
shape
,
dtype
=
dtype
)])
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
...
@@ -68,21 +68,21 @@ class Cholesky(Op):
...
@@ -68,21 +68,21 @@ class Cholesky(Op):
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it
if
self
.
overwrite_a
and
x
.
flags
[
"C_CONTIGUOUS"
]:
if
self
.
overwrite_a
and
x
.
flags
[
"C_CONTIGUOUS"
]:
out
[
0
]
=
scipy
.
linalg
.
cholesky
(
out
[
0
]
=
scipy
_
linalg
.
cholesky
(
x
.
T
,
x
.
T
,
lower
=
not
self
.
lower
,
lower
=
not
self
.
lower
,
check_finite
=
self
.
check_finite
,
check_finite
=
self
.
check_finite
,
overwrite_a
=
True
,
overwrite_a
=
True
,
)
.
T
)
.
T
else
:
else
:
out
[
0
]
=
scipy
.
linalg
.
cholesky
(
out
[
0
]
=
scipy
_
linalg
.
cholesky
(
x
,
x
,
lower
=
self
.
lower
,
lower
=
self
.
lower
,
check_finite
=
self
.
check_finite
,
check_finite
=
self
.
check_finite
,
overwrite_a
=
self
.
overwrite_a
,
overwrite_a
=
self
.
overwrite_a
,
)
)
except
scipy
.
linalg
.
LinAlgError
:
except
scipy
_
linalg
.
LinAlgError
:
if
self
.
on_error
==
"raise"
:
if
self
.
on_error
==
"raise"
:
raise
raise
else
:
else
:
...
@@ -334,7 +334,7 @@ class CholeskySolve(SolveBase):
...
@@ -334,7 +334,7 @@ class CholeskySolve(SolveBase):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
C
,
b
=
inputs
C
,
b
=
inputs
rval
=
scipy
.
linalg
.
cho_solve
(
rval
=
scipy
_
linalg
.
cho_solve
(
(
C
,
self
.
lower
),
(
C
,
self
.
lower
),
b
,
b
,
check_finite
=
self
.
check_finite
,
check_finite
=
self
.
check_finite
,
...
@@ -401,7 +401,7 @@ class SolveTriangular(SolveBase):
...
@@ -401,7 +401,7 @@ class SolveTriangular(SolveBase):
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
A
,
b
=
inputs
A
,
b
=
inputs
outputs
[
0
][
0
]
=
scipy
.
linalg
.
solve_triangular
(
outputs
[
0
][
0
]
=
scipy
_
linalg
.
solve_triangular
(
A
,
A
,
b
,
b
,
lower
=
self
.
lower
,
lower
=
self
.
lower
,
...
@@ -502,7 +502,7 @@ class Solve(SolveBase):
...
@@ -502,7 +502,7 @@ class Solve(SolveBase):
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
a
,
b
=
inputs
a
,
b
=
inputs
outputs
[
0
][
0
]
=
scipy
.
linalg
.
solve
(
outputs
[
0
][
0
]
=
scipy
_
linalg
.
solve
(
a
=
a
,
a
=
a
,
b
=
b
,
b
=
b
,
lower
=
self
.
lower
,
lower
=
self
.
lower
,
...
@@ -619,9 +619,9 @@ class Eigvalsh(Op):
...
@@ -619,9 +619,9 @@ class Eigvalsh(Op):
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
(
w
,)
=
outputs
(
w
,)
=
outputs
if
len
(
inputs
)
==
2
:
if
len
(
inputs
)
==
2
:
w
[
0
]
=
scipy
.
linalg
.
eigvalsh
(
a
=
inputs
[
0
],
b
=
inputs
[
1
],
lower
=
self
.
lower
)
w
[
0
]
=
scipy
_
linalg
.
eigvalsh
(
a
=
inputs
[
0
],
b
=
inputs
[
1
],
lower
=
self
.
lower
)
else
:
else
:
w
[
0
]
=
scipy
.
linalg
.
eigvalsh
(
a
=
inputs
[
0
],
b
=
None
,
lower
=
self
.
lower
)
w
[
0
]
=
scipy
_
linalg
.
eigvalsh
(
a
=
inputs
[
0
],
b
=
None
,
lower
=
self
.
lower
)
def
grad
(
self
,
inputs
,
g_outputs
):
def
grad
(
self
,
inputs
,
g_outputs
):
a
,
b
=
inputs
a
,
b
=
inputs
...
@@ -675,7 +675,7 @@ class EigvalshGrad(Op):
...
@@ -675,7 +675,7 @@ class EigvalshGrad(Op):
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
(
a
,
b
,
gw
)
=
inputs
(
a
,
b
,
gw
)
=
inputs
w
,
v
=
scipy
.
linalg
.
eigh
(
a
,
b
,
lower
=
self
.
lower
)
w
,
v
=
scipy
_
linalg
.
eigh
(
a
,
b
,
lower
=
self
.
lower
)
gA
=
v
.
dot
(
np
.
diag
(
gw
)
.
dot
(
v
.
T
))
gA
=
v
.
dot
(
np
.
diag
(
gw
)
.
dot
(
v
.
T
))
gB
=
-
v
.
dot
(
np
.
diag
(
gw
*
w
)
.
dot
(
v
.
T
))
gB
=
-
v
.
dot
(
np
.
diag
(
gw
*
w
)
.
dot
(
v
.
T
))
...
@@ -718,7 +718,7 @@ class Expm(Op):
...
@@ -718,7 +718,7 @@ class Expm(Op):
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
(
A
,)
=
inputs
(
A
,)
=
inputs
(
expm
,)
=
outputs
(
expm
,)
=
outputs
expm
[
0
]
=
scipy
.
linalg
.
expm
(
A
)
expm
[
0
]
=
scipy
_
linalg
.
expm
(
A
)
def
grad
(
self
,
inputs
,
outputs
):
def
grad
(
self
,
inputs
,
outputs
):
(
A
,)
=
inputs
(
A
,)
=
inputs
...
@@ -758,8 +758,8 @@ class ExpmGrad(Op):
...
@@ -758,8 +758,8 @@ class ExpmGrad(Op):
# this expression.
# this expression.
(
A
,
gA
)
=
inputs
(
A
,
gA
)
=
inputs
(
out
,)
=
outputs
(
out
,)
=
outputs
w
,
V
=
scipy
.
linalg
.
eig
(
A
,
right
=
True
)
w
,
V
=
scipy
_
linalg
.
eig
(
A
,
right
=
True
)
U
=
scipy
.
linalg
.
inv
(
V
)
.
T
U
=
scipy
_
linalg
.
inv
(
V
)
.
T
exp_w
=
np
.
exp
(
w
)
exp_w
=
np
.
exp
(
w
)
X
=
np
.
subtract
.
outer
(
exp_w
,
exp_w
)
/
np
.
subtract
.
outer
(
w
,
w
)
X
=
np
.
subtract
.
outer
(
exp_w
,
exp_w
)
/
np
.
subtract
.
outer
(
w
,
w
)
...
@@ -800,7 +800,7 @@ class SolveContinuousLyapunov(Op):
...
@@ -800,7 +800,7 @@ class SolveContinuousLyapunov(Op):
X
=
output_storage
[
0
]
X
=
output_storage
[
0
]
out_dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
out_dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
X
[
0
]
=
scipy
.
linalg
.
solve_continuous_lyapunov
(
A
,
B
)
.
astype
(
out_dtype
)
X
[
0
]
=
scipy
_
linalg
.
solve_continuous_lyapunov
(
A
,
B
)
.
astype
(
out_dtype
)
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
return
[
shapes
[
0
]]
return
[
shapes
[
0
]]
...
@@ -870,7 +870,7 @@ class BilinearSolveDiscreteLyapunov(Op):
...
@@ -870,7 +870,7 @@ class BilinearSolveDiscreteLyapunov(Op):
X
=
output_storage
[
0
]
X
=
output_storage
[
0
]
out_dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
out_dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
X
[
0
]
=
scipy
.
linalg
.
solve_discrete_lyapunov
(
A
,
B
,
method
=
"bilinear"
)
.
astype
(
X
[
0
]
=
scipy
_
linalg
.
solve_discrete_lyapunov
(
A
,
B
,
method
=
"bilinear"
)
.
astype
(
out_dtype
out_dtype
)
)
...
@@ -992,7 +992,7 @@ class SolveDiscreteARE(Op):
...
@@ -992,7 +992,7 @@ class SolveDiscreteARE(Op):
Q
=
0.5
*
(
Q
+
Q
.
T
)
Q
=
0.5
*
(
Q
+
Q
.
T
)
out_dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
out_dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
X
[
0
]
=
scipy
.
linalg
.
solve_discrete_are
(
A
,
B
,
Q
,
R
)
.
astype
(
out_dtype
)
X
[
0
]
=
scipy
_
linalg
.
solve_discrete_are
(
A
,
B
,
Q
,
R
)
.
astype
(
out_dtype
)
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
return
[
shapes
[
0
]]
return
[
shapes
[
0
]]
...
@@ -1064,7 +1064,7 @@ def solve_discrete_are(
...
@@ -1064,7 +1064,7 @@ def solve_discrete_are(
)
)
def
_largest_common_dtype
(
tensors
:
typing
.
Sequence
[
TensorVariable
])
->
np
.
dtype
:
def
_largest_common_dtype
(
tensors
:
Sequence
[
TensorVariable
])
->
np
.
dtype
:
return
reduce
(
lambda
l
,
r
:
np
.
promote_types
(
l
,
r
),
[
x
.
dtype
for
x
in
tensors
])
return
reduce
(
lambda
l
,
r
:
np
.
promote_types
(
l
,
r
),
[
x
.
dtype
for
x
in
tensors
])
...
@@ -1118,7 +1118,7 @@ class BlockDiagonal(BaseBlockDiagonal):
...
@@ -1118,7 +1118,7 @@ class BlockDiagonal(BaseBlockDiagonal):
def
perform
(
self
,
node
,
inputs
,
output_storage
,
params
=
None
):
def
perform
(
self
,
node
,
inputs
,
output_storage
,
params
=
None
):
dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
output_storage
[
0
][
0
]
=
scipy
.
linalg
.
block_diag
(
*
inputs
)
.
astype
(
dtype
)
output_storage
[
0
][
0
]
=
scipy
_
linalg
.
block_diag
(
*
inputs
)
.
astype
(
dtype
)
def
block_diag
(
*
matrices
:
TensorVariable
):
def
block_diag
(
*
matrices
:
TensorVariable
):
...
@@ -1175,4 +1175,5 @@ __all__ = [
...
@@ -1175,4 +1175,5 @@ __all__ = [
"solve_discrete_are"
,
"solve_discrete_are"
,
"solve_triangular"
,
"solve_triangular"
,
"block_diag"
,
"block_diag"
,
"cho_solve"
,
]
]
tests/link/numba/test_nlinalg.py
浏览文件 @
bbe663d9
...
@@ -7,58 +7,13 @@ import pytensor.tensor as pt
...
@@ -7,58 +7,13 @@ import pytensor.tensor as pt
from
pytensor.compile.sharedvalue
import
SharedVariable
from
pytensor.compile.sharedvalue
import
SharedVariable
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
nlinalg
,
slinalg
from
pytensor.tensor
import
nlinalg
from
tests.link.numba.test_basic
import
compare_numba_and_py
,
set_test_value
from
tests.link.numba.test_basic
import
compare_numba_and_py
,
set_test_value
rng
=
np
.
random
.
default_rng
(
42849
)
rng
=
np
.
random
.
default_rng
(
42849
)
@pytest.mark.parametrize
(
"A, x, lower, exc"
,
[
(
set_test_value
(
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
set_test_value
(
pt
.
dvector
(),
rng
.
random
(
size
=
(
3
,))
.
astype
(
"float64"
)),
"gen"
,
None
,
),
(
set_test_value
(
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
),
),
set_test_value
(
pt
.
dvector
(),
rng
.
random
(
size
=
(
3
,))
.
astype
(
"float64"
)),
"gen"
,
None
,
),
],
)
def
test_Solve
(
A
,
x
,
lower
,
exc
):
g
=
slinalg
.
Solve
(
lower
=
lower
,
b_ndim
=
1
)(
A
,
x
)
if
isinstance
(
g
,
list
):
g_fg
=
FunctionGraph
(
outputs
=
g
)
else
:
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
compare_numba_and_py
(
g_fg
,
[
i
.
tag
.
test_value
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"x, exc"
,
"x, exc"
,
[
[
...
...
tests/link/numba/test_slinalg.py
浏览文件 @
bbe663d9
import
re
import
re
from
functools
import
partial
from
typing
import
Literal
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
from
numpy.testing
import
assert_allclose
from
scipy
import
linalg
as
scipy_linalg
import
pytensor
import
pytensor
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor
import
config
from
pytensor.graph
import
FunctionGraph
from
pytensor.graph
import
FunctionGraph
from
tests
import
unittest_tools
as
utt
from
tests.link.numba.test_basic
import
compare_numba_and_py
from
tests.link.numba.test_basic
import
compare_numba_and_py
numba
=
pytest
.
importorskip
(
"numba"
)
numba
=
pytest
.
importorskip
(
"numba"
)
ATOL
=
0
if
config
.
floatX
.
endswith
(
"64"
)
else
1e-6
floatX
=
pytensor
.
config
.
floatX
RTOL
=
1e-7
if
config
.
floatX
.
endswith
(
"64"
)
else
1e-6
rng
=
np
.
random
.
default_rng
(
42849
)
rng
=
np
.
random
.
default_rng
(
42849
)
...
@@ -27,8 +31,8 @@ def transpose_func(x, trans):
...
@@ -27,8 +31,8 @@ def transpose_func(x, trans):
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"b_
func, b_siz
e"
,
"b_
shap
e"
,
[(
pt
.
matrix
,
(
5
,
1
)),
(
pt
.
matrix
,
(
5
,
5
)),
(
pt
.
vector
,
(
5
,)
)],
[(
5
,
1
),
(
5
,
5
),
(
5
,
)],
ids
=
[
"b_col_vec"
,
"b_matrix"
,
"b_vec"
],
ids
=
[
"b_col_vec"
,
"b_matrix"
,
"b_vec"
],
)
)
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
],
ids
=
[
"lower=True"
,
"lower=False"
])
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
],
ids
=
[
"lower=True"
,
"lower=False"
])
...
@@ -36,50 +40,88 @@ def transpose_func(x, trans):
...
@@ -36,50 +40,88 @@ def transpose_func(x, trans):
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"unit_diag"
,
[
True
,
False
],
ids
=
[
"unit_diag=True"
,
"unit_diag=False"
]
"unit_diag"
,
[
True
,
False
],
ids
=
[
"unit_diag=True"
,
"unit_diag=False"
]
)
)
@pytest.mark.parametrize
(
"complex"
,
[
True
,
False
],
ids
=
[
"complex"
,
"real"
])
@pytest.mark.parametrize
(
"
is_
complex"
,
[
True
,
False
],
ids
=
[
"complex"
,
"real"
])
@pytest.mark.filterwarnings
(
@pytest.mark.filterwarnings
(
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
)
)
def
test_solve_triangular
(
b_
func
,
b_size
,
lower
,
trans
,
unit_diag
,
complex
):
def
test_solve_triangular
(
b_
shape
:
tuple
[
int
],
lower
,
trans
,
unit_diag
,
is_
complex
):
if
complex
:
if
is_
complex
:
# TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous,
# TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous,
# why?
# why?
pytest
.
skip
(
"Complex inputs currently not supported to solve_triangular"
)
pytest
.
skip
(
"Complex inputs currently not supported to solve_triangular"
)
complex_dtype
=
"complex64"
if
config
.
floatX
.
endswith
(
"32"
)
else
"complex128"
complex_dtype
=
"complex64"
if
floatX
.
endswith
(
"32"
)
else
"complex128"
dtype
=
complex_dtype
if
complex
else
config
.
floatX
dtype
=
complex_dtype
if
is_complex
else
floatX
A
=
pt
.
matrix
(
"A"
,
dtype
=
dtype
)
A
=
pt
.
matrix
(
"A"
,
dtype
=
dtype
)
b
=
b_func
(
"b"
,
dtype
=
dtype
)
b
=
pt
.
tensor
(
"b"
,
shape
=
b_shape
,
dtype
=
dtype
)
def
A_func
(
x
):
x
=
x
@
x
.
conj
()
.
T
x_tri
=
scipy_linalg
.
cholesky
(
x
,
lower
=
lower
)
.
astype
(
dtype
)
if
unit_diag
:
x_tri
[
np
.
diag_indices_from
(
x_tri
)]
=
1.0
return
x_tri
.
astype
(
dtype
)
X
=
pt
.
linalg
.
solve_triangular
(
solve_op
=
partial
(
A
,
b
,
lower
=
lower
,
trans
=
trans
,
unit_diagonal
=
unit_diag
pt
.
linalg
.
solve_triangular
,
lower
=
lower
,
trans
=
trans
,
unit_diagonal
=
unit_diag
)
)
X
=
solve_op
(
A
,
b
)
f
=
pytensor
.
function
([
A
,
b
],
X
,
mode
=
"NUMBA"
)
f
=
pytensor
.
function
([
A
,
b
],
X
,
mode
=
"NUMBA"
)
A_val
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
A_val
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
b
=
np
.
random
.
normal
(
size
=
b_siz
e
)
b
_val
=
np
.
random
.
normal
(
size
=
b_shap
e
)
if
complex
:
if
is_
complex
:
A_val
=
A_val
+
np
.
random
.
normal
(
size
=
(
5
,
5
))
*
1
j
A_val
=
A_val
+
np
.
random
.
normal
(
size
=
(
5
,
5
))
*
1
j
b
=
b
+
np
.
random
.
normal
(
size
=
b_size
)
*
1
j
b_val
=
b_val
+
np
.
random
.
normal
(
size
=
b_shape
)
*
1
j
A_sym
=
A_val
@
A_val
.
conj
()
.
T
X_np
=
f
(
A_func
(
A_val
.
copy
()),
b_val
.
copy
())
test_input
=
transpose_func
(
A_func
(
A_val
.
copy
()),
trans
)
ATOL
=
1e-8
if
floatX
.
endswith
(
"64"
)
else
1e-4
RTOL
=
1e-8
if
floatX
.
endswith
(
"64"
)
else
1e-4
np
.
testing
.
assert_allclose
(
test_input
@
X_np
,
b_val
,
atol
=
ATOL
,
rtol
=
RTOL
)
compare_numba_and_py
(
f
.
maker
.
fgraph
,
[
A_func
(
A_val
.
copy
()),
b_val
.
copy
()])
@pytest.mark.parametrize
(
"lower, unit_diag, trans"
,
[(
True
,
True
,
True
),
(
False
,
False
,
False
)],
ids
=
[
"lower_unit_trans"
,
"defaults"
],
)
def
test_solve_triangular_grad
(
lower
,
unit_diag
,
trans
):
A_val
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
b_val
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
# utt.verify_grad uses small perturbations to the input matrix to calculate the finite difference gradient. When
# a non-triangular matrix is passed to scipy.linalg.solve_triangular, no error is raise, but the result will be
# wrong, resulting in wrong gradients. As a result, it is necessary to add a mapping from the space of all matrices
# to the space of triangular matrices, and test the gradient of that entire graph.
def
A_func_pt
(
x
):
x
=
x
@
x
.
conj
()
.
T
x_tri
=
pt
.
linalg
.
cholesky
(
x
,
lower
=
lower
)
.
astype
(
floatX
)
A_tri
=
np
.
linalg
.
cholesky
(
A_sym
)
.
astype
(
dtype
)
if
unit_diag
:
if
unit_diag
:
adj_mat
=
np
.
ones
((
5
,
5
))
n
=
A_val
.
shape
[
0
]
adj_mat
[
np
.
diag_indices
(
5
)]
=
1
/
np
.
diagonal
(
A_tri
)
x_tri
=
x_tri
[
np
.
diag_indices
(
n
)]
.
set
(
1.0
)
A_tri
=
A_tri
*
adj_mat
A_tri
=
A_tri
.
astype
(
dtype
)
return
transpose_func
(
x_tri
.
astype
(
floatX
),
trans
)
b
=
b
.
astype
(
dtype
)
if
not
lower
:
solve_op
=
partial
(
A_tri
=
A_tri
.
T
pt
.
linalg
.
solve_triangular
,
lower
=
lower
,
trans
=
trans
,
unit_diagonal
=
unit_diag
)
X_np
=
f
(
A_tri
,
b
)
utt
.
verify_grad
(
np
.
testing
.
assert_allclose
(
lambda
A
,
b
:
solve_op
(
A_func_pt
(
A
),
b
),
transpose_func
(
A_tri
,
trans
)
@
X_np
,
b
,
atol
=
ATOL
,
rtol
=
RTOL
[
A_val
.
copy
(),
b_val
.
copy
()],
mode
=
"NUMBA"
,
)
)
...
@@ -93,11 +135,11 @@ def test_solve_triangular_raises_on_nan_inf(value):
...
@@ -93,11 +135,11 @@ def test_solve_triangular_raises_on_nan_inf(value):
X
=
pt
.
linalg
.
solve_triangular
(
A
,
b
,
check_finite
=
True
)
X
=
pt
.
linalg
.
solve_triangular
(
A
,
b
,
check_finite
=
True
)
f
=
pytensor
.
function
([
A
,
b
],
X
,
mode
=
"NUMBA"
)
f
=
pytensor
.
function
([
A
,
b
],
X
,
mode
=
"NUMBA"
)
A_val
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
A_val
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
A_sym
=
A_val
@
A_val
.
conj
()
.
T
A_sym
=
A_val
@
A_val
.
conj
()
.
T
A_tri
=
np
.
linalg
.
cholesky
(
A_sym
)
.
astype
(
config
.
floatX
)
A_tri
=
np
.
linalg
.
cholesky
(
A_sym
)
.
astype
(
floatX
)
b
=
np
.
full
((
5
,
1
),
value
)
b
=
np
.
full
((
5
,
1
),
value
)
.
astype
(
floatX
)
with
pytest
.
raises
(
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
,
np
.
linalg
.
LinAlgError
,
...
@@ -119,19 +161,19 @@ def test_numba_Cholesky(lower, trans):
...
@@ -119,19 +161,19 @@ def test_numba_Cholesky(lower, trans):
fg
=
FunctionGraph
(
outputs
=
[
chol
])
fg
=
FunctionGraph
(
outputs
=
[
chol
])
x
=
np
.
array
([
0.1
,
0.2
,
0.3
])
x
=
np
.
array
([
0.1
,
0.2
,
0.3
])
.
astype
(
floatX
)
val
=
np
.
eye
(
3
)
+
x
[
None
,
:]
*
x
[:,
None
]
val
=
np
.
eye
(
3
)
.
astype
(
floatX
)
+
x
[
None
,
:]
*
x
[:,
None
]
compare_numba_and_py
(
fg
,
[
val
])
compare_numba_and_py
(
fg
,
[
val
])
def
test_numba_Cholesky_raises_on_nan_input
():
def
test_numba_Cholesky_raises_on_nan_input
():
test_value
=
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
config
.
floatX
)
test_value
=
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
floatX
)
test_value
[
0
,
0
]
=
np
.
nan
test_value
[
0
,
0
]
=
np
.
nan
x
=
pt
.
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
3
,
3
))
x
=
pt
.
tensor
(
dtype
=
floatX
,
shape
=
(
3
,
3
))
x
=
x
.
T
.
dot
(
x
)
x
=
x
.
T
.
dot
(
x
)
g
=
pt
.
linalg
.
cholesky
(
x
,
check_finite
=
True
)
g
=
pt
.
linalg
.
cholesky
(
x
)
f
=
pytensor
.
function
([
x
],
g
,
mode
=
"NUMBA"
)
f
=
pytensor
.
function
([
x
],
g
,
mode
=
"NUMBA"
)
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
,
match
=
r"Non-numeric values"
):
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
,
match
=
r"Non-numeric values"
):
...
@@ -140,9 +182,9 @@ def test_numba_Cholesky_raises_on_nan_input():
...
@@ -140,9 +182,9 @@ def test_numba_Cholesky_raises_on_nan_input():
@pytest.mark.parametrize
(
"on_error"
,
[
"nan"
,
"raise"
])
@pytest.mark.parametrize
(
"on_error"
,
[
"nan"
,
"raise"
])
def
test_numba_Cholesky_raise_on
(
on_error
):
def
test_numba_Cholesky_raise_on
(
on_error
):
test_value
=
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
config
.
floatX
)
test_value
=
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
floatX
)
x
=
pt
.
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
3
,
3
))
x
=
pt
.
tensor
(
dtype
=
floatX
,
shape
=
(
3
,
3
))
g
=
pt
.
linalg
.
cholesky
(
x
,
on_error
=
on_error
)
g
=
pt
.
linalg
.
cholesky
(
x
,
on_error
=
on_error
)
f
=
pytensor
.
function
([
x
],
g
,
mode
=
"NUMBA"
)
f
=
pytensor
.
function
([
x
],
g
,
mode
=
"NUMBA"
)
...
@@ -155,6 +197,16 @@ def test_numba_Cholesky_raise_on(on_error):
...
@@ -155,6 +197,16 @@ def test_numba_Cholesky_raise_on(on_error):
assert
np
.
all
(
np
.
isnan
(
f
(
test_value
)))
assert
np
.
all
(
np
.
isnan
(
f
(
test_value
)))
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
],
ids
=
[
"lower=True"
,
"lower=False"
])
def
test_numba_Cholesky_grad
(
lower
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
L
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
X
=
L
@
L
.
T
chol_op
=
partial
(
pt
.
linalg
.
cholesky
,
lower
=
lower
)
utt
.
verify_grad
(
chol_op
,
[
X
],
mode
=
"NUMBA"
)
def
test_block_diag
():
def
test_block_diag
():
A
=
pt
.
matrix
(
"A"
)
A
=
pt
.
matrix
(
"A"
)
B
=
pt
.
matrix
(
"B"
)
B
=
pt
.
matrix
(
"B"
)
...
@@ -162,9 +214,242 @@ def test_block_diag():
...
@@ -162,9 +214,242 @@ def test_block_diag():
D
=
pt
.
matrix
(
"D"
)
D
=
pt
.
matrix
(
"D"
)
X
=
pt
.
linalg
.
block_diag
(
A
,
B
,
C
,
D
)
X
=
pt
.
linalg
.
block_diag
(
A
,
B
,
C
,
D
)
A_val
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
A_val
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
B_val
=
np
.
random
.
normal
(
size
=
(
3
,
3
))
B_val
=
np
.
random
.
normal
(
size
=
(
3
,
3
))
.
astype
(
floatX
)
C_val
=
np
.
random
.
normal
(
size
=
(
2
,
2
))
C_val
=
np
.
random
.
normal
(
size
=
(
2
,
2
))
.
astype
(
floatX
)
D_val
=
np
.
random
.
normal
(
size
=
(
4
,
4
))
D_val
=
np
.
random
.
normal
(
size
=
(
4
,
4
))
.
astype
(
floatX
)
out_fg
=
pytensor
.
graph
.
FunctionGraph
([
A
,
B
,
C
,
D
],
[
X
])
out_fg
=
pytensor
.
graph
.
FunctionGraph
([
A
,
B
,
C
,
D
],
[
X
])
compare_numba_and_py
(
out_fg
,
[
A_val
,
B_val
,
C_val
,
D_val
])
compare_numba_and_py
(
out_fg
,
[
A_val
,
B_val
,
C_val
,
D_val
])
def
test_lamch
():
from
scipy.linalg
import
get_lapack_funcs
from
pytensor.link.numba.dispatch.slinalg
import
_xlamch
@numba.njit
()
def
xlamch
(
kind
):
return
_xlamch
(
kind
)
lamch
=
get_lapack_funcs
(
"lamch"
,
(
np
.
array
([
0.0
],
dtype
=
floatX
),))
np
.
testing
.
assert_allclose
(
xlamch
(
"E"
),
lamch
(
"E"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"S"
),
lamch
(
"S"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"P"
),
lamch
(
"P"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"B"
),
lamch
(
"B"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"R"
),
lamch
(
"R"
))
np
.
testing
.
assert_allclose
(
xlamch
(
"M"
),
lamch
(
"M"
))
@pytest.mark.parametrize
(
"ord_numba, ord_scipy"
,
[(
"F"
,
"fro"
),
(
"1"
,
1
),
(
"I"
,
np
.
inf
)]
)
def
test_xlange
(
ord_numba
,
ord_scipy
):
# xlange is called internally only, we don't dispatch pt.linalg.norm to it
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.slinalg
import
_xlange
@numba.njit
()
def
xlange
(
x
,
ord
):
return
_xlange
(
x
,
ord
)
x
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
np
.
testing
.
assert_allclose
(
xlange
(
x
,
ord_numba
),
linalg
.
norm
(
x
,
ord_scipy
))
@pytest.mark.parametrize
(
"ord_numba, ord_scipy"
,
[(
"1"
,
1
),
(
"I"
,
np
.
inf
)])
def
test_xgecon
(
ord_numba
,
ord_scipy
):
# gecon is called internally only, we don't dispatch pt.linalg.norm to it
from
scipy.linalg
import
get_lapack_funcs
from
pytensor.link.numba.dispatch.slinalg
import
_xgecon
,
_xlange
@numba.njit
()
def
gecon
(
x
,
norm
):
anorm
=
_xlange
(
x
,
norm
)
cond
,
info
=
_xgecon
(
x
,
anorm
,
norm
)
return
cond
,
info
x
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
rcond
,
info
=
gecon
(
x
,
norm
=
ord_numba
)
# Test against direct call to the underlying LAPACK functions
# Solution does **not** agree with 1 / np.linalg.cond(x) !
lange
,
gecon
=
get_lapack_funcs
((
"lange"
,
"gecon"
),
(
x
,))
norm
=
lange
(
ord_numba
,
x
)
rcond2
,
_
=
gecon
(
x
,
norm
,
norm
=
ord_numba
)
assert
info
==
0
np
.
testing
.
assert_allclose
(
rcond
,
rcond2
)
@pytest.mark.parametrize
(
"overwrite_a"
,
[
True
,
False
])
def
test_getrf
(
overwrite_a
):
from
scipy.linalg
import
lu_factor
from
pytensor.link.numba.dispatch.slinalg
import
_getrf
# TODO: Refactor this test to use compare_numba_and_py after we implement lu_factor in pytensor
@numba.njit
()
def
getrf
(
x
,
overwrite_a
):
return
_getrf
(
x
,
overwrite_a
=
overwrite_a
)
x
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
x
=
np
.
asfortranarray
(
x
)
# x needs to be fortran-contiguous going into getrf for the overwrite option to work
lu
,
ipiv
=
lu_factor
(
x
,
overwrite_a
=
False
)
LU
,
IPIV
,
info
=
getrf
(
x
,
overwrite_a
=
overwrite_a
)
assert
info
==
0
assert_allclose
(
LU
,
lu
)
if
overwrite_a
:
assert_allclose
(
x
,
LU
)
# TODO: It seems IPIV is 1-indexed in FORTRAN, so we need to subtract 1. I can't find evidence that scipy is doing
# this, though.
assert_allclose
(
IPIV
-
1
,
ipiv
)
@pytest.mark.parametrize
(
"trans"
,
[
0
,
1
])
@pytest.mark.parametrize
(
"overwrite_a"
,
[
True
,
False
])
@pytest.mark.parametrize
(
"overwrite_b"
,
[
True
,
False
])
@pytest.mark.parametrize
(
"b_shape"
,
[(
5
,),
(
5
,
3
)],
ids
=
[
"b_1d"
,
"b_2d"
])
def
test_getrs
(
trans
,
overwrite_a
,
overwrite_b
,
b_shape
):
from
scipy.linalg
import
lu_factor
from
scipy.linalg
import
lu_solve
as
sp_lu_solve
from
pytensor.link.numba.dispatch.slinalg
import
_getrf
,
_getrs
# TODO: Refactor this test to use compare_numba_and_py after we implement lu_solve in pytensor
@numba.njit
()
def
lu_solve
(
a
,
b
,
trans
,
overwrite_a
,
overwrite_b
):
lu
,
ipiv
,
info
=
_getrf
(
a
,
overwrite_a
=
overwrite_a
)
x
,
info
=
_getrs
(
lu
,
b
,
ipiv
,
trans
=
trans
,
overwrite_b
=
overwrite_b
)
return
x
,
lu
,
info
a
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
b
=
np
.
random
.
normal
(
size
=
b_shape
)
.
astype
(
floatX
)
# inputs need to be fortran-contiguous going into getrf and getrs for the overwrite option to work
a
=
np
.
asfortranarray
(
a
)
b
=
np
.
asfortranarray
(
b
)
lu_and_piv
=
lu_factor
(
a
,
overwrite_a
=
False
)
x_sp
=
sp_lu_solve
(
lu_and_piv
,
b
,
trans
,
overwrite_b
=
False
)
x
,
lu
,
info
=
lu_solve
(
a
,
b
,
trans
,
overwrite_a
=
overwrite_a
,
overwrite_b
=
overwrite_b
)
assert
info
==
0
if
overwrite_a
:
assert_allclose
(
a
,
lu
)
if
overwrite_b
:
assert_allclose
(
b
,
x
)
assert_allclose
(
x
,
x_sp
)
@pytest.mark.parametrize
(
"b_shape"
,
[(
5
,
1
),
(
5
,
5
),
(
5
,)],
ids
=
[
"b_col_vec"
,
"b_matrix"
,
"b_vec"
],
)
@pytest.mark.parametrize
(
"assume_a"
,
[
"gen"
,
"sym"
,
"pos"
],
ids
=
str
)
@pytest.mark.filterwarnings
(
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
)
def
test_solve
(
b_shape
:
tuple
[
int
],
assume_a
:
Literal
[
"gen"
,
"sym"
,
"pos"
]):
A
=
pt
.
matrix
(
"A"
,
dtype
=
floatX
)
b
=
pt
.
tensor
(
"b"
,
shape
=
b_shape
,
dtype
=
floatX
)
A_val
=
np
.
asfortranarray
(
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
))
b_val
=
np
.
asfortranarray
(
np
.
random
.
normal
(
size
=
b_shape
)
.
astype
(
floatX
))
def
A_func
(
x
):
if
assume_a
==
"pos"
:
x
=
x
@
x
.
T
elif
assume_a
==
"sym"
:
x
=
(
x
+
x
.
T
)
/
2
return
x
X
=
pt
.
linalg
.
solve
(
A_func
(
A
),
b
,
assume_a
=
assume_a
,
b_ndim
=
len
(
b_shape
),
)
f
=
pytensor
.
function
(
[
pytensor
.
In
(
A
,
mutable
=
True
),
pytensor
.
In
(
b
,
mutable
=
True
)],
X
,
mode
=
"NUMBA"
)
op
=
f
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
compare_numba_and_py
(([
A
,
b
],
[
X
]),
inputs
=
[
A_val
,
b_val
],
inplace
=
True
)
# Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first.
A_val_copy
=
A_val
.
copy
()
b_val_copy
=
b_val
.
copy
()
X_np
=
f
(
A_val
,
b_val
)
# overwrite_b is preferred when both inputs can be destroyed
assert
op
.
destroy_map
==
{
0
:
[
1
]}
# Confirm inputs were destroyed by checking against the copies
assert
(
A_val
==
A_val_copy
)
.
all
()
==
(
op
.
destroy_map
.
get
(
0
,
None
)
!=
[
0
])
assert
(
b_val
==
b_val_copy
)
.
all
()
==
(
op
.
destroy_map
.
get
(
0
,
None
)
!=
[
1
])
ATOL
=
1e-8
if
floatX
.
endswith
(
"64"
)
else
1e-4
RTOL
=
1e-8
if
floatX
.
endswith
(
"64"
)
else
1e-4
# Confirm b_val is used to store to solution
np
.
testing
.
assert_allclose
(
X_np
,
b_val
,
atol
=
ATOL
,
rtol
=
RTOL
)
assert
not
np
.
allclose
(
b_val
,
b_val_copy
)
# Test that the result is numerically correct. Need to use the unmodified copy
np
.
testing
.
assert_allclose
(
A_func
(
A_val_copy
)
@
X_np
,
b_val_copy
,
atol
=
ATOL
,
rtol
=
RTOL
)
# See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here
utt
.
verify_grad
(
lambda
A
,
b
:
pt
.
linalg
.
solve
(
A_func
(
A
),
b
,
lower
=
False
,
assume_a
=
assume_a
,
b_ndim
=
len
(
b_shape
)
),
[
A_val_copy
,
b_val_copy
],
mode
=
"NUMBA"
,
)
@pytest.mark.parametrize
(
"b_func, b_size"
,
[(
pt
.
matrix
,
(
5
,
1
)),
(
pt
.
matrix
,
(
5
,
5
)),
(
pt
.
vector
,
(
5
,))],
ids
=
[
"b_col_vec"
,
"b_matrix"
,
"b_vec"
],
)
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
],
ids
=
lambda
x
:
f
"lower = {x}"
)
def
test_cho_solve
(
b_func
,
b_size
,
lower
):
A
=
pt
.
matrix
(
"A"
,
dtype
=
floatX
)
b
=
b_func
(
"b"
,
dtype
=
floatX
)
C
=
pt
.
linalg
.
cholesky
(
A
,
lower
=
lower
)
X
=
pt
.
linalg
.
cho_solve
((
C
,
lower
),
b
)
f
=
pytensor
.
function
([
A
,
b
],
X
,
mode
=
"NUMBA"
)
A
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
floatX
)
A
=
A
@
A
.
conj
()
.
T
b
=
np
.
random
.
normal
(
size
=
b_size
)
b
=
b
.
astype
(
floatX
)
X_np
=
f
(
A
,
b
)
ATOL
=
1e-8
if
floatX
.
endswith
(
"64"
)
else
1e-4
RTOL
=
1e-8
if
floatX
.
endswith
(
"64"
)
else
1e-4
np
.
testing
.
assert_allclose
(
A
@
X_np
,
b
,
atol
=
ATOL
,
rtol
=
RTOL
)
tests/tensor/test_slinalg.py
浏览文件 @
bbe663d9
...
@@ -209,12 +209,12 @@ class TestSolveBase:
...
@@ -209,12 +209,12 @@ class TestSolveBase:
)
)
class
TestSolve
(
utt
.
InferShapeTester
):
def
test_solve_raises_on_invalid_A
():
def
test__init__
(
self
):
with
pytest
.
raises
(
ValueError
,
match
=
"is not a recognized matrix structure"
):
with
pytest
.
raises
(
ValueError
)
as
excinfo
:
Solve
(
assume_a
=
"test"
,
b_ndim
=
2
)
Solve
(
assume_a
=
"test"
,
b_ndim
=
2
)
assert
"is not a recognized matrix structure"
in
str
(
excinfo
.
value
)
class
TestSolve
(
utt
.
InferShapeTester
):
@pytest.mark.parametrize
(
"b_shape"
,
[(
5
,
1
),
(
5
,)])
@pytest.mark.parametrize
(
"b_shape"
,
[(
5
,
1
),
(
5
,)])
def
test_infer_shape
(
self
,
b_shape
):
def
test_infer_shape
(
self
,
b_shape
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
...
@@ -232,64 +232,78 @@ class TestSolve(utt.InferShapeTester):
...
@@ -232,64 +232,78 @@ class TestSolve(utt.InferShapeTester):
warn
=
False
,
warn
=
False
,
)
)
def
test_correctness
(
self
):
@pytest.mark.parametrize
(
"b_size"
,
[(
5
,
1
),
(
5
,
5
),
(
5
,)],
ids
=
[
"b_col_vec"
,
"b_matrix"
,
"b_vec"
]
)
@pytest.mark.parametrize
(
"assume_a"
,
[
"gen"
,
"sym"
,
"pos"
],
ids
=
str
)
def
test_solve_correctness
(
self
,
b_size
:
tuple
[
int
],
assume_a
:
str
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
A
=
matrix
()
A
=
pt
.
tensor
(
"A"
,
shape
=
(
5
,
5
))
b
=
matrix
()
b
=
pt
.
tensor
(
"b"
,
shape
=
b_size
)
y
=
solve
(
A
,
b
)
gen_solve_func
=
pytensor
.
function
([
A
,
b
],
y
)
b_val
=
np
.
asarray
(
rng
.
random
((
5
,
1
)),
dtype
=
config
.
floatX
)
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
)
b_val
=
rng
.
normal
(
size
=
b_size
)
.
astype
(
config
.
floatX
)
A_val
=
np
.
asarray
(
rng
.
random
((
5
,
5
)),
dtype
=
config
.
floatX
)
solve_op
=
functools
.
partial
(
solve
,
assume_a
=
assume_a
,
b_ndim
=
len
(
b_size
))
A_val
=
np
.
dot
(
A_val
.
transpose
(),
A_val
)
np
.
testing
.
assert_allclose
(
def
A_func
(
x
):
scipy
.
linalg
.
solve
(
A_val
,
b_val
,
assume_a
=
"gen"
),
if
assume_a
==
"pos"
:
gen_solve_func
(
A_val
,
b_val
),
return
x
@
x
.
T
)
elif
assume_a
==
"sym"
:
return
(
x
+
x
.
T
)
/
2
else
:
return
x
solve_input_val
=
A_func
(
A_val
)
y
=
solve_op
(
A_func
(
A
),
b
)
solve_func
=
pytensor
.
function
([
A
,
b
],
y
)
X_np
=
solve_func
(
A_val
.
copy
(),
b_val
.
copy
())
ATOL
=
1e-8
if
config
.
floatX
.
endswith
(
"64"
)
else
1e-4
RTOL
=
1e-8
if
config
.
floatX
.
endswith
(
"64"
)
else
1e-4
A_undef
=
np
.
array
(
[
[
1
,
0
,
0
,
0
,
0
],
[
0
,
1
,
0
,
0
,
0
],
[
0
,
0
,
1
,
0
,
0
],
[
0
,
0
,
0
,
1
,
1
],
[
0
,
0
,
0
,
1
,
0
],
],
dtype
=
config
.
floatX
,
)
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
scipy
.
linalg
.
solve
(
A_undef
,
b_val
),
gen_solve_func
(
A_undef
,
b_val
)
scipy
.
linalg
.
solve
(
solve_input_val
,
b_val
,
assume_a
=
assume_a
),
X_np
,
atol
=
ATOL
,
rtol
=
RTOL
,
)
)
np
.
testing
.
assert_allclose
(
A_func
(
A_val
)
@
X_np
,
b_val
,
atol
=
ATOL
,
rtol
=
RTOL
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"m, n, assume_a, lower"
,
"b_size"
,
[(
5
,
1
),
(
5
,
5
),
(
5
,)],
ids
=
[
"b_col_vec"
,
"b_matrix"
,
"b_vec"
]
[
)
(
5
,
None
,
"gen"
,
False
),
@pytest.mark.parametrize
(
"assume_a"
,
[
"gen"
,
"sym"
,
"pos"
],
ids
=
str
)
(
5
,
None
,
"gen"
,
True
),
@pytest.mark.skipif
(
(
4
,
2
,
"gen"
,
False
),
config
.
floatX
==
"float32"
,
reason
=
"Gradients not numerically stable in float32"
(
4
,
2
,
"gen"
,
True
),
],
)
)
def
test_solve_grad
(
self
,
m
,
n
,
assume_a
,
lowe
r
):
def
test_solve_grad
ient
(
self
,
b_size
:
tuple
[
int
],
assume_a
:
st
r
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
# Ensure diagonal elements of `A` are relatively large to avoid
eps
=
2e-8
if
config
.
floatX
==
"float64"
else
None
# numerical precision issues
A_val
=
(
rng
.
normal
(
size
=
(
m
,
m
))
*
0.5
+
np
.
eye
(
m
))
.
astype
(
config
.
floatX
)
if
n
is
None
:
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
)
b_val
=
rng
.
normal
(
size
=
m
)
.
astype
(
config
.
floatX
)
b_val
=
rng
.
normal
(
size
=
b_size
)
.
astype
(
config
.
floatX
)
def
A_func
(
x
):
if
assume_a
==
"pos"
:
return
x
@
x
.
T
elif
assume_a
==
"sym"
:
return
(
x
+
x
.
T
)
/
2
else
:
else
:
b_val
=
rng
.
normal
(
size
=
(
m
,
n
))
.
astype
(
config
.
floatX
)
return
x
eps
=
None
solve_op
=
functools
.
partial
(
solve
,
assume_a
=
assume_a
,
b_ndim
=
len
(
b_size
))
if
config
.
floatX
==
"float64"
:
eps
=
2e-8
solve_op
=
Solve
(
assume_a
=
assume_a
,
lower
=
lower
,
b_ndim
=
1
if
n
is
None
else
2
)
# To correctly check the gradients, we need to include a transformation from the space of unconstrained matrices
utt
.
verify_grad
(
solve_op
,
[
A_val
,
b_val
],
3
,
rng
,
eps
=
eps
)
# (A) to a valid input matrix for the given solver. This is done by the A_func function. If this isn't included,
# the random perturbations used by verify_grad will result in invalid input matrices, and
# LAPACK will silently do the wrong thing, making the gradients wrong
utt
.
verify_grad
(
lambda
A
,
b
:
solve_op
(
A_func
(
A
),
b
),
[
A_val
,
b_val
],
3
,
rng
,
eps
=
eps
)
class
TestSolveTriangular
(
utt
.
InferShapeTester
):
class
TestSolveTriangular
(
utt
.
InferShapeTester
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论