pytensor · Commits · 617964ff

Unverified commit 617964ff, authored Jul 15, 2025 by Jesse Grabowski; committed by GitHub on Jul 15, 2025.
Refactor and update QR Op (#1518)

* Refactor QR
* Update JAX QR dispatch
* Update Torch QR dispatch
* Update numba QR dispatch

Parent: 5024d54e

Showing 20 changed files with 1703 additions and 424 deletions (+1703 -424).
- pytensor/link/jax/dispatch/nlinalg.py (+0 -11)
- pytensor/link/jax/dispatch/slinalg.py (+11 -0)
- pytensor/link/numba/dispatch/linalg/_LAPACK.py (+87 -1)
- pytensor/link/numba/dispatch/linalg/decomposition/qr.py (+880 -0)
- pytensor/link/numba/dispatch/nlinalg.py (+0 -36)
- pytensor/link/numba/dispatch/slinalg.py (+104 -1)
- pytensor/link/pytorch/dispatch/__init__.py (+1 -0)
- pytensor/link/pytorch/dispatch/nlinalg.py (+0 -16)
- pytensor/link/pytorch/dispatch/slinalg.py (+23 -0)
- pytensor/tensor/nlinalg.py (+0 -171)
- pytensor/tensor/slinalg.py (+377 -3)
- tests/link/jax/test_nlinalg.py (+0 -6)
- tests/link/jax/test_slinalg.py (+12 -0)
- tests/link/numba/test_nlinalg.py (+0 -54)
- tests/link/numba/test_slinalg.py (+68 -0)
- tests/link/pytorch/conftest.py (+16 -0)
- tests/link/pytorch/test_nlinalg.py (+0 -27)
- tests/link/pytorch/test_slinalg.py (+20 -0)
- tests/tensor/test_nlinalg.py (+0 -98)
- tests/tensor/test_slinalg.py (+104 -0)
pytensor/link/jax/dispatch/nlinalg.py

```diff
@@ -9,7 +9,6 @@ from pytensor.tensor.nlinalg import (
     KroneckerProduct,
     MatrixInverse,
     MatrixPinv,
-    QRFull,
     SLogDet,
 )
@@ -67,16 +66,6 @@ def jax_funcify_MatrixInverse(op, **kwargs):
     return matrix_inverse


-@jax_funcify.register(QRFull)
-def jax_funcify_QRFull(op, **kwargs):
-    mode = op.mode
-
-    def qr_full(x, mode=mode):
-        return jnp.linalg.qr(x, mode=mode)
-
-    return qr_full
-
-
 @jax_funcify.register(MatrixPinv)
 def jax_funcify_Pinv(op, **kwargs):
     def pinv(x):
```
pytensor/link/jax/dispatch/slinalg.py

```diff
@@ -5,6 +5,7 @@ import jax
 from pytensor.link.jax.dispatch.basic import jax_funcify
 from pytensor.tensor.slinalg import (
     LU,
+    QR,
     BlockDiagonal,
     Cholesky,
     CholeskySolve,
```

Added at the end of the file (@@ -168,3 +169,13 @@, after `jax_funcify_ChoSolve`):

```python
@jax_funcify.register(QR)
def jax_funcify_QR(op, **kwargs):
    mode = op.mode

    def qr(x, mode=mode):
        return jax.scipy.linalg.qr(x, mode=mode)

    return qr
```
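The registration simply closes over `op.mode` and defers to `jax.scipy.linalg.qr`, which follows SciPy's mode names and return shapes. A minimal sketch of what the dispatched callable computes (illustrative only, not part of the diff; assumes a working JAX install):

```python
import jax.numpy as jnp
import jax.scipy.linalg as jsl

x = jnp.arange(12.0).reshape(4, 3)

Q, R = jsl.qr(x, mode="full")        # Q: (4, 4), R: (4, 3)
Qe, Re = jsl.qr(x, mode="economic")  # Qe: (4, 3), Re: (3, 3)

# Q has orthonormal columns and Q @ R reconstructs x
assert jnp.allclose(Q @ R, x, atol=1e-5)
```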
pytensor/link/numba/dispatch/linalg/_LAPACK.py

```diff
@@ -283,7 +283,6 @@ class _LAPACK:
         Called by scipy.linalg.lu_solve
         """
-
         lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs")
         functype = ctypes.CFUNCTYPE(
             None,
```

Added at the end of the class (@@ -457,3 +456,90 @@, after the previous binding ends with `return functype(lapack_ptr)`), four new bindings for the QR-related LAPACK routines:

```python
    @classmethod
    def numba_xgeqrf(cls, dtype):
        """
        Compute the QR factorization of a general M-by-N matrix A.

        Used in QR decomposition (no pivoting).
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqrf")
        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # M
            _ptr_int,  # N
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # TAU
            float_pointer,  # WORK
            _ptr_int,  # LWORK
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)

    @classmethod
    def numba_xgeqp3(cls, dtype):
        """
        Compute the QR factorization with column pivoting of a general M-by-N matrix A.

        Used in QR decomposition with pivoting.
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqp3")
        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # M
            _ptr_int,  # N
            float_pointer,  # A
            _ptr_int,  # LDA
            _ptr_int,  # JPVT
            float_pointer,  # TAU
            float_pointer,  # WORK
            _ptr_int,  # LWORK
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)

    @classmethod
    def numba_xorgqr(cls, dtype):
        """
        Generate the orthogonal matrix Q from a QR factorization (real types).

        Used in QR decomposition to form Q.
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "orgqr")
        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # M
            _ptr_int,  # N
            _ptr_int,  # K
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # TAU
            float_pointer,  # WORK
            _ptr_int,  # LWORK
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)

    @classmethod
    def numba_xungqr(cls, dtype):
        """
        Generate the unitary matrix Q from a QR factorization (complex types).

        Used in QR decomposition to form Q for complex types.
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "ungqr")
        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # M
            _ptr_int,  # N
            _ptr_int,  # K
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # TAU
            float_pointer,  # WORK
            _ptr_int,  # LWORK
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)
```
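All four bindings rely on LAPACK's two-pass workspace-query convention: a call with LWORK = -1 performs no factorization and instead writes the optimal workspace size into WORK[0]. A minimal sketch of the same pattern through SciPy's f2py wrappers (illustrative; the variable names here are not from the diff):

```python
import numpy as np
from scipy.linalg import get_lapack_funcs

A = np.asfortranarray(np.random.default_rng(0).normal(size=(5, 3)))
(geqrf,) = get_lapack_funcs(("geqrf",), (A,))

# Pass 1: workspace query -- lwork=-1 returns the optimal size in work[0]
qr_raw, tau, work, info = geqrf(A, lwork=-1, overwrite_a=False)
lwork = int(work[0].real)

# Pass 2: the actual factorization, now with an optimally sized workspace
qr_raw, tau, work, info = geqrf(A, lwork=lwork, overwrite_a=False)
assert info == 0  # nonzero info signals an illegal argument
```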
pytensor/link/numba/dispatch/linalg/decomposition/qr.py (new file, mode 100644)

```python
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy.linalg import get_lapack_funcs, qr

from pytensor.link.numba.dispatch.linalg._LAPACK import (
    _LAPACK,
    _get_underlying_float,
    int_ptr_to_val,
    val_to_int_ptr,
)


def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int):
    """LAPACK geqrf: Computes a QR factorization of a general M-by-N matrix A."""
    (geqrf,) = get_lapack_funcs(("geqrf",), (A,))
    return geqrf(A, overwrite_a=overwrite_a, lwork=lwork)


@overload(_xgeqrf)
def xgeqrf_impl(A, overwrite_a, lwork):
    ensure_lapack()
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    geqrf = _LAPACK().numba_xgeqrf(dtype)

    def impl(A, overwrite_a, lwork):
        M = np.int32(A.shape[0])
        N = np.int32(A.shape[1])

        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        LDA = val_to_int_ptr(M)
        TAU = np.empty(min(M, N), dtype=dtype)

        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            LWORK = val_to_int_ptr(-1)
        else:
            WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype)
            LWORK = val_to_int_ptr(WORK.size)
        INFO = val_to_int_ptr(1)

        geqrf(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            A_copy.view(w_type).ctypes,
            LDA,
            TAU.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            LWORK,
            INFO,
        )
        return A_copy, TAU, WORK, int_ptr_to_val(INFO)

    return impl


def _xgeqp3(A: np.ndarray, overwrite_a: bool, lwork: int):
    """LAPACK geqp3: Computes a QR factorization with column pivoting of a general M-by-N matrix A."""
    (geqp3,) = get_lapack_funcs(("geqp3",), (A,))
    return geqp3(A, overwrite_a=overwrite_a, lwork=lwork)


@overload(_xgeqp3)
def xgeqp3_impl(A, overwrite_a, lwork):
    ensure_lapack()
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    geqp3 = _LAPACK().numba_xgeqp3(dtype)

    def impl(A, overwrite_a, lwork):
        M = np.int32(A.shape[0])
        N = np.int32(A.shape[1])

        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        LDA = val_to_int_ptr(M)
        JPVT = np.zeros(N, dtype=np.int32)
        TAU = np.empty(min(M, N), dtype=dtype)

        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            LWORK = val_to_int_ptr(-1)
        else:
            WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype)
            LWORK = val_to_int_ptr(WORK.size)
        INFO = val_to_int_ptr(1)

        geqp3(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            A_copy.view(w_type).ctypes,
            LDA,
            JPVT.ctypes,
            TAU.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            LWORK,
            INFO,
        )
        return A_copy, JPVT, TAU, WORK, int_ptr_to_val(INFO)

    return impl


def _xorgqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
    """LAPACK orgqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (real types)."""
    (orgqr,) = get_lapack_funcs(("orgqr",), (A,))
    return orgqr(A, tau, overwrite_a=overwrite_a, lwork=lwork)


@overload(_xorgqr)
def xorgqr_impl(A, tau, overwrite_a, lwork):
    ensure_lapack()
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    orgqr = _LAPACK().numba_xorgqr(dtype)

    def impl(A, tau, overwrite_a, lwork):
        M = np.int32(A.shape[0])
        N = np.int32(A.shape[1])
        K = np.int32(tau.shape[0])

        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            LWORK = val_to_int_ptr(-1)
        else:
            WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype)
            LWORK = val_to_int_ptr(WORK.size)

        LDA = val_to_int_ptr(M)
        INFO = val_to_int_ptr(1)

        orgqr(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            val_to_int_ptr(K),
            A_copy.view(w_type).ctypes,
            LDA,
            tau.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            LWORK,
            INFO,
        )
        return A_copy, WORK, int_ptr_to_val(INFO)

    return impl


def _xungqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
    """LAPACK ungqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (complex types)."""
    (ungqr,) = get_lapack_funcs(("ungqr",), (A,))
    return ungqr(A, tau, overwrite_a=overwrite_a, lwork=lwork)


@overload(_xungqr)
def xungqr_impl(A, tau, overwrite_a, lwork):
    ensure_lapack()
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    ungqr = _LAPACK().numba_xungqr(dtype)

    def impl(A, tau, overwrite_a, lwork):
        M = np.int32(A.shape[0])
        N = np.int32(A.shape[1])
        K = np.int32(tau.shape[0])

        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        LDA = val_to_int_ptr(M)

        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            LWORK = val_to_int_ptr(-1)
        else:
            WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype)
            LWORK = val_to_int_ptr(WORK.size)

        INFO = val_to_int_ptr(1)

        ungqr(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            val_to_int_ptr(K),
            A_copy.view(w_type).ctypes,
            LDA,
            tau.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            LWORK,
            INFO,
        )
        return A_copy, WORK, int_ptr_to_val(INFO)

    return impl


def _qr_full_pivot(
    x: np.ndarray,
    mode: str = "full",
    pivoting: bool = True,
    overwrite_a: bool = False,
    check_finite: bool = False,
    lwork: int | None = None,
):
    """
    Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same
    script.

    Corresponds to the case where mode is not "r" or "raw", and pivoting is True, resulting in a return of arrays Q, R,
    and P.
    """
    return qr(
        x,
        mode=mode,
        pivoting=pivoting,
        overwrite_a=overwrite_a,
        check_finite=check_finite,
        lwork=lwork,
    )


def _qr_full_no_pivot(
    x: np.ndarray,
    mode: str = "full",
    pivoting: bool = False,
    overwrite_a: bool = False,
    check_finite: bool = False,
    lwork: int | None = None,
):
    """
    Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same
    script.

    Corresponds to the case where mode is not "r" or "raw", and pivoting is False, resulting in a return of arrays Q
    and R.
    """
    return qr(
        x,
        mode=mode,
        pivoting=pivoting,
        overwrite_a=overwrite_a,
        check_finite=check_finite,
        lwork=lwork,
    )


def _qr_r_pivot(
    x: np.ndarray,
    mode: str = "r",
    pivoting: bool = True,
    overwrite_a: bool = False,
    check_finite: bool = False,
    lwork: int | None = None,
):
    """
    Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same
    script.

    Corresponds to the case where mode is "r" or "raw", and pivoting is True, resulting in a return of arrays R and P.
    """
    return qr(
        x,
        mode=mode,
        pivoting=pivoting,
        overwrite_a=overwrite_a,
        check_finite=check_finite,
        lwork=lwork,
    )


def _qr_r_no_pivot(
    x: np.ndarray,
    mode: str = "r",
    pivoting: bool = False,
    overwrite_a: bool = False,
    check_finite: bool = False,
    lwork: int | None = None,
):
    """
    Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same
    script.

    Corresponds to the case where mode is "r" or "raw", and pivoting is False, resulting in a return of array R.
    """
    return qr(
        x,
        mode=mode,
        pivoting=pivoting,
        overwrite_a=overwrite_a,
        check_finite=check_finite,
        lwork=lwork,
    )


def _qr_raw_no_pivot(
    x: np.ndarray,
    mode: str = "raw",
    pivoting: bool = False,
    overwrite_a: bool = False,
    check_finite: bool = False,
    lwork: int | None = None,
):
    """
    Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same
    script.

    Corresponds to the case where mode is "raw", and pivoting is False, resulting in a return of arrays H, tau, and R.
    """
    (H, tau), R = qr(
        x,
        mode=mode,
        pivoting=pivoting,
        overwrite_a=overwrite_a,
        check_finite=check_finite,
        lwork=lwork,
    )
    return H, tau, R


def _qr_raw_pivot(
    x: np.ndarray,
    mode: str = "raw",
    pivoting: bool = True,
    overwrite_a: bool = False,
    check_finite: bool = False,
    lwork: int | None = None,
):
    """
    Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same
    script.

    Corresponds to the case where mode is "raw", and pivoting is True, resulting in a return of arrays H, tau, R,
    and P.
    """
    (H, tau), R, P = qr(
        x,
        mode=mode,
        pivoting=pivoting,
        overwrite_a=overwrite_a,
        check_finite=check_finite,
        lwork=lwork,
    )
    return H, tau, R, P


@overload(_qr_full_pivot)
def qr_full_pivot_impl(
    x, mode="full", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
    ensure_lapack()
    dtype = x.dtype
    w_type = _get_underlying_float(dtype)
    geqp3 = _LAPACK().numba_xgeqp3(dtype)
    orgqr = _LAPACK().numba_xorgqr(dtype)

    def impl(
        x,
        mode="full",
        pivoting=True,
        overwrite_a=False,
        check_finite=False,
        lwork=None,
    ):
        M = np.int32(x.shape[0])
        N = np.int32(x.shape[1])
        K = min(M, N)

        if overwrite_a and x.flags.f_contiguous:
            x_copy = x
        else:
            x_copy = _copy_to_fortran_order(x)

        LDA = val_to_int_ptr(M)
        TAU = np.empty(K, dtype=dtype)
        JPVT = np.zeros(N, dtype=np.int32)

        if lwork is None:
            lwork = -1

        if lwork == -1:
            # Workspace query: LWORK=-1 writes the optimal size into WORK[0]
            WORK = np.empty(1, dtype=dtype)
            geqp3(
                val_to_int_ptr(M),
                val_to_int_ptr(N),
                x_copy.view(w_type).ctypes,
                LDA,
                JPVT.ctypes,
                TAU.view(w_type).ctypes,
                WORK.view(w_type).ctypes,
                val_to_int_ptr(-1),
                val_to_int_ptr(1),
            )
            lwork_val = int(WORK.item())
        else:
            lwork_val = lwork

        WORK = np.empty(lwork_val, dtype=dtype)
        INFO = val_to_int_ptr(1)

        geqp3(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            x_copy.view(w_type).ctypes,
            LDA,
            JPVT.ctypes,
            TAU.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            val_to_int_ptr(lwork_val),
            INFO,
        )

        # geqp3 returns 1-based pivot indices; convert to 0-based
        JPVT = (JPVT - 1).astype(np.int32)

        if mode == "full" or M < N:
            R = np.triu(x_copy)
        else:
            R = np.triu(x_copy[:N, :])

        if M < N:
            Q_in = x_copy[:, :M]
        elif M == N or mode == "economic":
            Q_in = x_copy
        else:
            # Transpose to put the matrix into Fortran order
            Q_in = np.empty((M, M), dtype=dtype).T
            Q_in[:, :N] = x_copy

        if lwork == -1:
            WORKQ = np.empty(1, dtype=dtype)
            orgqr(
                val_to_int_ptr(M),
                val_to_int_ptr(Q_in.shape[1]),
                val_to_int_ptr(K),
                Q_in.view(w_type).ctypes,
                val_to_int_ptr(M),
                TAU.view(w_type).ctypes,
                WORKQ.view(w_type).ctypes,
                val_to_int_ptr(-1),
                val_to_int_ptr(1),
            )
            lwork_q = int(WORKQ.item())
        else:
            lwork_q = lwork

        WORKQ = np.empty(lwork_q, dtype=dtype)
        INFOQ = val_to_int_ptr(1)

        orgqr(
            val_to_int_ptr(M),
            val_to_int_ptr(Q_in.shape[1]),
            val_to_int_ptr(K),
            Q_in.view(w_type).ctypes,
            val_to_int_ptr(M),
            TAU.view(w_type).ctypes,
            WORKQ.view(w_type).ctypes,
            val_to_int_ptr(lwork_q),
            INFOQ,
        )
        return Q_in, R, JPVT

    return impl


@overload(_qr_full_no_pivot)
def qr_full_no_pivot_impl(
    x, mode="full", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
):
    ensure_lapack()
    dtype = x.dtype
    w_type = _get_underlying_float(dtype)
    geqrf = _LAPACK().numba_xgeqrf(dtype)
    orgqr = _LAPACK().numba_xorgqr(dtype)

    def impl(
        x,
        mode="full",
        pivoting=False,
        overwrite_a=False,
        check_finite=False,
        lwork=None,
    ):
        M = np.int32(x.shape[0])
        N = np.int32(x.shape[1])
        K = min(M, N)

        if overwrite_a and x.flags.f_contiguous:
            x_copy = x
        else:
            x_copy = _copy_to_fortran_order(x)

        LDA = val_to_int_ptr(M)
        TAU = np.empty(K, dtype=dtype)

        if lwork is None:
            lwork = -1

        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            geqrf(
                val_to_int_ptr(M),
                val_to_int_ptr(N),
                x_copy.view(w_type).ctypes,
                LDA,
                TAU.view(w_type).ctypes,
                WORK.view(w_type).ctypes,
                val_to_int_ptr(-1),
                val_to_int_ptr(1),
            )
            lwork_val = int(WORK.item())
        else:
            lwork_val = lwork

        WORK = np.empty(lwork_val, dtype=dtype)
        INFO = val_to_int_ptr(1)

        geqrf(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            x_copy.view(w_type).ctypes,
            LDA,
            TAU.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            val_to_int_ptr(lwork_val),
            INFO,
        )

        if M < N or mode == "full":
            R = np.triu(x_copy)
        else:
            R = np.triu(x_copy[:N, :])

        if M < N:
            Q_in = x_copy[:, :M]
        elif M == N or mode == "economic":
            Q_in = x_copy
        else:
            # Transpose to put the matrix into Fortran order
            Q_in = np.empty((M, M), dtype=dtype).T
            Q_in[:, :N] = x_copy

        if lwork == -1:
            WORKQ = np.empty(1, dtype=dtype)
            orgqr(
                val_to_int_ptr(M),
                val_to_int_ptr(Q_in.shape[1]),
                val_to_int_ptr(K),
                Q_in.view(w_type).ctypes,
                val_to_int_ptr(M),
                TAU.view(w_type).ctypes,
                WORKQ.view(w_type).ctypes,
                val_to_int_ptr(-1),
                val_to_int_ptr(1),
            )
            lwork_q = int(WORKQ.item())
        else:
            lwork_q = lwork

        WORKQ = np.empty(lwork_q, dtype=dtype)
        INFOQ = val_to_int_ptr(1)

        orgqr(
            val_to_int_ptr(M),  # M
            val_to_int_ptr(Q_in.shape[1]),  # N
            val_to_int_ptr(K),  # K
            Q_in.view(w_type).ctypes,  # A
            val_to_int_ptr(M),  # LDA
            TAU.view(w_type).ctypes,  # TAU
            WORKQ.view(w_type).ctypes,  # WORK
            val_to_int_ptr(lwork_q),  # LWORK
            INFOQ,  # INFO
        )
        return Q_in, R

    return impl


@overload(_qr_r_pivot)
def qr_r_pivot_impl(
    x, mode="r", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
    ensure_lapack()
    dtype = x.dtype
    w_type = _get_underlying_float(dtype)
    geqp3 = _LAPACK().numba_xgeqp3(dtype)

    def impl(
        x,
        mode="r",
        pivoting=True,
        overwrite_a=False,
        check_finite=False,
        lwork=None,
    ):
        M = np.int32(x.shape[0])
        N = np.int32(x.shape[1])

        if overwrite_a and x.flags.f_contiguous:
            x_copy = x
        else:
            x_copy = _copy_to_fortran_order(x)

        LDA = val_to_int_ptr(M)
        K = min(M, N)
        TAU = np.empty(K, dtype=dtype)
        JPVT = np.zeros(N, dtype=np.int32)

        if lwork is None:
            lwork = -1

        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            geqp3(
                val_to_int_ptr(M),
                val_to_int_ptr(N),
                x_copy.view(w_type).ctypes,
                LDA,
                JPVT.ctypes,
                TAU.view(w_type).ctypes,
                WORK.view(w_type).ctypes,
                val_to_int_ptr(-1),
                val_to_int_ptr(1),
            )
            lwork_val = int(WORK.item())
        else:
            lwork_val = lwork

        WORK = np.empty(lwork_val, dtype=dtype)
        INFO = val_to_int_ptr(1)

        geqp3(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            x_copy.view(w_type).ctypes,
            LDA,
            JPVT.ctypes,
            TAU.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            val_to_int_ptr(lwork_val),
            INFO,
        )

        JPVT = (JPVT - 1).astype(np.int32)

        if M < N:
            R = np.triu(x_copy)
        else:
            R = np.triu(x_copy[:N, :])

        return R, JPVT

    return impl


@overload(_qr_r_no_pivot)
def qr_r_no_pivot_impl(
    x, mode="r", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
):
    ensure_lapack()
    dtype = x.dtype
    w_type = _get_underlying_float(dtype)
    geqrf = _LAPACK().numba_xgeqrf(dtype)

    def impl(
        x,
        mode="r",
        pivoting=False,
        overwrite_a=False,
        check_finite=False,
        lwork=None,
    ):
        M = np.int32(x.shape[0])
        N = np.int32(x.shape[1])

        if overwrite_a and x.flags.f_contiguous:
            x_copy = x
        else:
            x_copy = _copy_to_fortran_order(x)

        LDA = val_to_int_ptr(M)
        K = min(M, N)
        TAU = np.empty(K, dtype=dtype)

        if lwork is None:
            lwork = -1

        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            geqrf(
                val_to_int_ptr(M),
                val_to_int_ptr(N),
                x_copy.view(w_type).ctypes,
                LDA,
                TAU.view(w_type).ctypes,
                WORK.view(w_type).ctypes,
                val_to_int_ptr(-1),
                val_to_int_ptr(1),
            )
            lwork_val = int(WORK.item())
        else:
            lwork_val = lwork

        WORK = np.empty(lwork_val, dtype=dtype)
        INFO = val_to_int_ptr(1)

        geqrf(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            x_copy.view(w_type).ctypes,
            LDA,
            TAU.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            val_to_int_ptr(lwork_val),
            INFO,
        )

        if M < N:
            R = np.triu(x_copy)
        else:
            R = np.triu(x_copy[:N, :])

        # Return a tuple with R only to match the scipy qr interface
        return (R,)

    return impl


@overload(_qr_raw_no_pivot)
def qr_raw_no_pivot_impl(
    x, mode="raw", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
):
    ensure_lapack()
    dtype = x.dtype
    w_type = _get_underlying_float(dtype)
    geqrf = _LAPACK().numba_xgeqrf(dtype)

    def impl(
        x,
        mode="raw",
        pivoting=False,
        overwrite_a=False,
        check_finite=False,
        lwork=None,
    ):
        M = np.int32(x.shape[0])
        N = np.int32(x.shape[1])

        if overwrite_a and x.flags.f_contiguous:
            x_copy = x
        else:
            x_copy = _copy_to_fortran_order(x)

        LDA = val_to_int_ptr(M)
        K = min(M, N)
        TAU = np.empty(K, dtype=dtype)

        if lwork is None:
            lwork = -1

        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            geqrf(
                val_to_int_ptr(M),
                val_to_int_ptr(N),
                x_copy.view(w_type).ctypes,
                LDA,
                TAU.view(w_type).ctypes,
                WORK.view(w_type).ctypes,
                val_to_int_ptr(-1),
                val_to_int_ptr(1),
            )
            lwork_val = int(WORK.item())
        else:
            lwork_val = lwork

        WORK = np.empty(lwork_val, dtype=dtype)
        INFO = val_to_int_ptr(1)

        geqrf(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            x_copy.view(w_type).ctypes,
            LDA,
            TAU.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            val_to_int_ptr(lwork_val),
            INFO,
        )

        if M < N:
            R = np.triu(x_copy)
        else:
            R = np.triu(x_copy[:N, :])

        return x_copy, TAU, R

    return impl


@overload(_qr_raw_pivot)
def qr_raw_pivot_impl(
    x, mode="raw", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
    ensure_lapack()
    dtype = x.dtype
    w_type = _get_underlying_float(dtype)
    geqp3 = _LAPACK().numba_xgeqp3(dtype)

    def impl(
        x,
        mode="raw",
        pivoting=True,
        overwrite_a=False,
        check_finite=False,
        lwork=None,
    ):
        M = np.int32(x.shape[0])
        N = np.int32(x.shape[1])

        if overwrite_a and x.flags.f_contiguous:
            x_copy = x
        else:
            x_copy = _copy_to_fortran_order(x)

        LDA = val_to_int_ptr(M)
        K = min(M, N)
        TAU = np.empty(K, dtype=dtype)
        JPVT = np.zeros(N, dtype=np.int32)

        if lwork is None:
            lwork = -1

        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            geqp3(
                val_to_int_ptr(M),
                val_to_int_ptr(N),
                x_copy.view(w_type).ctypes,
                LDA,
                JPVT.ctypes,
                TAU.view(w_type).ctypes,
                WORK.view(w_type).ctypes,
                val_to_int_ptr(-1),
                val_to_int_ptr(1),
            )
            lwork_val = int(WORK.item())
        else:
            lwork_val = lwork

        WORK = np.empty(lwork_val, dtype=dtype)
        INFO = val_to_int_ptr(1)

        geqp3(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            x_copy.view(w_type).ctypes,
            LDA,
            JPVT.ctypes,
            TAU.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            val_to_int_ptr(lwork_val),
            INFO,
        )

        JPVT = (JPVT - 1).astype(np.int32)

        if M < N:
            R = np.triu(x_copy)
        else:
            R = np.triu(x_copy[:N, :])

        return x_copy, TAU, R, JPVT

    return impl
```
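Six separate wrapper/overload pairs are needed because numba resolves an `@overload` to a single implementation with a fixed return arity, while `scipy.linalg.qr` changes its return arity with `mode` and `pivoting`. The SciPy behavior the wrappers pin down, case by case (a plain-SciPy illustration, consistent with the wrappers above):

```python
import numpy as np
from scipy.linalg import qr

a = np.random.default_rng(0).normal(size=(5, 3))

Q, R = qr(a, mode="full")                    # Q: (5, 5), R: (5, 3)
Q, R = qr(a, mode="economic")                # Q: (5, 3), R: (3, 3)
(R,) = qr(a, mode="r")                       # 1-tuple holding only R
(H, tau), R = qr(a, mode="raw")              # Householder data plus R
Q, R, P = qr(a, mode="full", pivoting=True)  # adds the permutation vector P
```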
pytensor/link/numba/dispatch/nlinalg.py

```diff
@@ -16,7 +16,6 @@ from pytensor.tensor.nlinalg import (
     Eigh,
     MatrixInverse,
     MatrixPinv,
-    QRFull,
     SLogDet,
 )
@@ -146,38 +145,3 @@ def numba_funcify_MatrixPinv(op, node, **kwargs):
         return np.linalg.pinv(inputs_cast(x)).astype(out_dtype)

     return matrixpinv
-
-
-@numba_funcify.register(QRFull)
-def numba_funcify_QRFull(op, node, **kwargs):
-    mode = op.mode
-
-    if mode != "reduced":
-        warnings.warn(
-            (
-                "Numba will use object mode to allow the "
-                "`mode` argument to `numpy.linalg.qr`."
-            ),
-            UserWarning,
-        )
-
-        if len(node.outputs) > 1:
-            ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
-        else:
-            ret_sig = get_numba_type(node.outputs[0].type)
-
-        @numba_basic.numba_njit
-        def qr_full(x):
-            with numba.objmode(ret=ret_sig):
-                ret = np.linalg.qr(x, mode=mode)
-            return ret
-
-    else:
-        out_dtype = node.outputs[0].type.numpy_dtype
-        inputs_cast = int_to_float_fn(node.inputs, out_dtype)
-
-        @numba_basic.numba_njit(inline="always")
-        def qr_full(x):
-            return np.linalg.qr(inputs_cast(x))
-
-    return qr_full
```
pytensor/link/numba/dispatch/slinalg.py

```diff
@@ -2,6 +2,7 @@ import warnings

 import numpy as np

+from pytensor import config
 from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
 from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky
 from pytensor.link.numba.dispatch.linalg.decomposition.lu import (
@@ -11,6 +12,14 @@ from pytensor.link.numba.dispatch.linalg.decomposition.lu import (
     _pivot_to_permutation,
 )
 from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _lu_factor
+from pytensor.link.numba.dispatch.linalg.decomposition.qr import (
+    _qr_full_no_pivot,
+    _qr_full_pivot,
+    _qr_r_no_pivot,
+    _qr_r_pivot,
+    _qr_raw_no_pivot,
+    _qr_raw_pivot,
+)
 from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve
 from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen
 from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd
@@ -19,6 +28,7 @@ from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangul
 from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal
 from pytensor.tensor.slinalg import (
     LU,
+    QR,
     BlockDiagonal,
     Cholesky,
     CholeskySolve,
@@ -27,7 +37,7 @@ from pytensor.tensor.slinalg import (
     Solve,
     SolveTriangular,
 )
-from pytensor.tensor.type import complex_dtypes
+from pytensor.tensor.type import complex_dtypes, integer_dtypes


 _COMPLEX_DTYPE_NOT_SUPPORTED_MSG = (
```

Added at the end of the file (@@ -311,3 +321,96 @@, after `numba_funcify_CholeskySolve`):

```python
@numba_funcify.register(QR)
def numba_funcify_QR(op, node, **kwargs):
    mode = op.mode
    check_finite = op.check_finite
    pivoting = op.pivoting
    overwrite_a = op.overwrite_a

    dtype = node.inputs[0].dtype
    if dtype in complex_dtypes:
        raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))

    integer_input = dtype in integer_dtypes
    in_dtype = config.floatX if integer_input else dtype

    @numba_njit(cache=False)
    def qr(a):
        if check_finite:
            if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
                raise np.linalg.LinAlgError(
                    "Non-numeric values (nan or inf) found in input to qr"
                )

        if integer_input:
            a = a.astype(in_dtype)

        if (mode == "full" or mode == "economic") and pivoting:
            Q, R, P = _qr_full_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return Q, R, P

        elif (mode == "full" or mode == "economic") and not pivoting:
            Q, R = _qr_full_no_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return Q, R

        elif mode == "r" and pivoting:
            R, P = _qr_r_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return R, P

        elif mode == "r" and not pivoting:
            (R,) = _qr_r_no_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return R

        elif mode == "raw" and pivoting:
            H, tau, R, P = _qr_raw_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return H, tau, R, P

        elif mode == "raw" and not pivoting:
            H, tau, R = _qr_raw_no_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return H, tau, R

        else:
            raise NotImplementedError(
                f"QR mode={mode}, pivoting={pivoting} not supported in numba mode."
            )

    return qr
```
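With this registration in place, a graph containing the new QR Op compiles end-to-end through the Numba backend. A minimal sketch, assuming the numba backend is available ("NUMBA" is the mode string that selects it; `pt.linalg.qr` is the entry point exercised by the tests below):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

A = pt.matrix("A", dtype="float64")
Q, R = pt.linalg.qr(A, mode="economic")

# mode="NUMBA" routes compilation through numba_funcify_QR above
fn = pytensor.function([A], [Q, R], mode="NUMBA")

A_val = np.random.default_rng(0).normal(size=(5, 3))
Q_val, R_val = fn(A_val)
np.testing.assert_allclose(Q_val @ R_val, A_val, atol=1e-8)
```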
pytensor/link/pytorch/dispatch/__init__.py

```diff
@@ -8,6 +8,7 @@ import pytensor.link.pytorch.dispatch.elemwise
 import pytensor.link.pytorch.dispatch.math
 import pytensor.link.pytorch.dispatch.extra_ops
 import pytensor.link.pytorch.dispatch.nlinalg
+import pytensor.link.pytorch.dispatch.slinalg
 import pytensor.link.pytorch.dispatch.shape
 import pytensor.link.pytorch.dispatch.sort
 import pytensor.link.pytorch.dispatch.subtensor
```
pytensor/link/pytorch/dispatch/nlinalg.py

```diff
@@ -9,7 +9,6 @@ from pytensor.tensor.nlinalg import (
     KroneckerProduct,
     MatrixInverse,
     MatrixPinv,
-    QRFull,
     SLogDet,
 )
@@ -70,21 +69,6 @@ def pytorch_funcify_MatrixInverse(op, **kwargs):
     return matrix_inverse


-@pytorch_funcify.register(QRFull)
-def pytorch_funcify_QRFull(op, **kwargs):
-    mode = op.mode
-    if mode == "raw":
-        raise NotImplementedError("raw mode not implemented in PyTorch")
-
-    def qr_full(x):
-        Q, R = torch.linalg.qr(x, mode=mode)
-        if mode == "r":
-            return R
-        return Q, R
-
-    return qr_full
-
-
 @pytorch_funcify.register(MatrixPinv)
 def pytorch_funcify_Pinv(op, **kwargs):
     hermitian = op.hermitian
```
pytensor/link/pytorch/dispatch/slinalg.py (new file, mode 100644)

```python
import torch

from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.slinalg import QR


@pytorch_funcify.register(QR)
def pytorch_funcify_QR(op, **kwargs):
    mode = op.mode
    if mode == "raw":
        raise NotImplementedError("raw mode not implemented in PyTorch")
    elif mode == "full":
        mode = "complete"
    elif mode == "economic":
        mode = "reduced"

    def qr(x):
        Q, R = torch.linalg.qr(x, mode=mode)
        if mode == "r":
            return R
        return Q, R

    return qr
```
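The renaming at the top is needed because `torch.linalg.qr` uses NumPy-style mode names while the new QR Op uses SciPy-style ones ("full" → "complete", "economic" → "reduced"). A quick illustration of the target API, not part of the diff (shapes shown for a 4×3 input; torch returns an empty `Q` when mode="r"):

```python
import torch

x = torch.arange(12.0).reshape(4, 3)

Q, R = torch.linalg.qr(x, mode="complete")  # scipy "full": Q (4, 4), R (4, 3)
Q, R = torch.linalg.qr(x, mode="reduced")   # scipy "economic": Q (4, 3), R (3, 3)
Q, R = torch.linalg.qr(x, mode="r")         # Q is an empty tensor; R is (3, 3)
```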
pytensor/tensor/nlinalg.py

```diff
@@ -5,15 +5,12 @@ from typing import Literal, cast

 import numpy as np

-import pytensor.tensor as pt
 from pytensor import scalar as ps
 from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
-from pytensor.ifelse import ifelse
 from pytensor.npy_2_compat import normalize_axis_tuple
-from pytensor.raise_op import Assert
 from pytensor.tensor import TensorLike
 from pytensor.tensor import basic as ptb
 from pytensor.tensor import math as ptm
```

Removed after `eigh` (@@ -468,173 +465,6 @@; the QRFull Op and the `qr` helper move to `pytensor.tensor.slinalg` as the new QR Op):

```python
class QRFull(Op):
    """
    Full QR Decomposition.

    Computes the QR decomposition of a matrix.
    Factor the matrix a as qr, where q is orthonormal
    and r is upper-triangular.

    """

    __props__ = ("mode",)

    def __init__(self, mode):
        self.mode = mode

    def make_node(self, x):
        x = as_tensor_variable(x)

        assert x.ndim == 2, "The input of qr function should be a matrix."

        in_dtype = x.type.numpy_dtype
        out_dtype = np.dtype(f"f{in_dtype.itemsize}")

        q = matrix(dtype=out_dtype)

        if self.mode != "raw":
            r = matrix(dtype=out_dtype)
        else:
            r = vector(dtype=out_dtype)

        if self.mode != "r":
            q = matrix(dtype=out_dtype)
            outputs = [q, r]
        else:
            outputs = [r]

        return Apply(self, [x], outputs)

    def perform(self, node, inputs, outputs):
        (x,) = inputs
        assert x.ndim == 2, "The input of qr function should be a matrix."
        res = np.linalg.qr(x, self.mode)
        if self.mode != "r":
            outputs[0][0], outputs[1][0] = res
        else:
            outputs[0][0] = res

    def L_op(self, inputs, outputs, output_grads):
        """
        Reverse-mode gradient of the QR function.

        References
        ----------
        .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/
        .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2
        """
        from pytensor.tensor.slinalg import solve_triangular

        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
        m, n = A.shape

        def _H(x: ptb.TensorVariable):
            return x.conj().mT

        def _copyltu(x: ptb.TensorVariable):
            return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1))

        if self.mode == "raw":
            raise NotImplementedError("Gradient of qr not implemented for mode=raw")

        elif self.mode == "r":
            # We need all the components of the QR to compute the gradient of A even if we only
            # use the upper triangular component in the cost function.
            Q, R = qr(A, mode="reduced")
            dQ = Q.zeros_like()
            dR = cast(ptb.TensorVariable, output_grads[0])

        else:
            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
            if self.mode == "complete":
                qr_assert_op = Assert(
                    "Gradient of qr not implemented for m x n matrices with m > n and mode=complete"
                )
                R = qr_assert_op(R, ptm.le(m, n))

            new_output_grads = []
            is_disconnected = [
                isinstance(x.type, DisconnectedType) for x in output_grads
            ]
            if all(is_disconnected):
                # This should never be reached by Pytensor
                return [DisconnectedType()()]  # pragma: no cover

            for disconnected, output_grad, output in zip(
                is_disconnected, output_grads, [Q, R], strict=True
            ):
                if disconnected:
                    new_output_grads.append(output.zeros_like())
                else:
                    new_output_grads.append(output_grad)

            (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)

        # gradient expression when m >= n
        M = R @ _H(dR) - _H(dQ) @ Q
        K = dQ + Q @ _copyltu(M)
        A_bar_m_ge_n = _H(solve_triangular(R, _H(K)))

        # gradient expression when m < n
        Y = A[:, m:]
        U = R[:, :m]
        dU, dV = dR[:, :m], dR[:, m:]
        dQ_Yt_dV = dQ + Y @ _H(dV)
        M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q
        X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M))))
        Y_bar = Q @ dV
        A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1)

        return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]


def qr(a, mode="reduced"):
    """
    Computes the QR decomposition of a matrix.
    Factor the matrix a as qr, where q
    is orthonormal and r is upper-triangular.

    Parameters
    ----------
    a : array_like, shape (M, N)
        Matrix to be factored.

    mode : {'reduced', 'complete', 'r', 'raw'}, optional
        If K = min(M, N), then

        'reduced'
          returns q, r with dimensions (M, K), (K, N)

        'complete'
          returns q, r with dimensions (M, M), (M, N)

        'r'
          returns r only with dimensions (K, N)

        'raw'
          returns h, tau with dimensions (N, M), (K,)

        Note that array h returned in 'raw' mode is
        transposed for calling Fortran.

        Default mode is 'reduced'

    Returns
    -------
    q : matrix of float or complex, optional
        A matrix with orthonormal columns. When mode = 'complete' the
        result is an orthogonal/unitary matrix depending on whether or
        not a is real/complex. The determinant may be either +/- 1 in
        that case.
    r : matrix of float or complex, optional
        The upper-triangular matrix.

    """
    return QRFull(mode)(a)
```

And `"qr"` is dropped from the module's `__all__`:

```diff
@@ -1291,7 +1121,6 @@ __all__ = [
     "det",
     "eig",
     "eigh",
-    "qr",
     "svd",
     "lstsq",
     "matrix_power",
```
pytensor/tensor/slinalg.py

```diff
@@ -7,16 +7,19 @@ from typing import Literal, cast

 import numpy as np
 import scipy.linalg as scipy_linalg
 from numpy.exceptions import ComplexWarning
+from scipy.linalg import get_lapack_funcs

 import pytensor
-import pytensor.tensor as pt
+from pytensor import ifelse
+from pytensor import tensor as pt
 from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
-from pytensor.tensor import TensorLike, as_tensor_variable
+from pytensor.raise_op import Assert
+from pytensor.tensor import TensorLike
 from pytensor.tensor import basic as ptb
 from pytensor.tensor import math as ptm
-from pytensor.tensor.basic import diagonal
+from pytensor.tensor.basic import as_tensor_variable, diagonal
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.nlinalg import kron, matrix_dot
 from pytensor.tensor.shape import reshape
```

Added after `block_diag` (@@ -1714,6 +1717,376 @@), the new QR Op and `qr` function:

```python
class QR(Op):
    """
    QR Decomposition
    """

    __props__ = (
        "overwrite_a",
        "mode",
        "pivoting",
        "check_finite",
    )

    def __init__(
        self,
        mode: Literal["full", "r", "economic", "raw"] = "full",
        overwrite_a: bool = False,
        pivoting: bool = False,
        check_finite: bool = False,
    ):
        self.mode = mode
        self.overwrite_a = overwrite_a
        self.pivoting = pivoting
        self.check_finite = check_finite

        self.destroy_map = {}

        if overwrite_a:
            self.destroy_map = {0: [0]}

        match self.mode:
            case "economic":
                self.gufunc_signature = "(m,n)->(m,k),(k,n)"
            case "full":
                self.gufunc_signature = "(m,n)->(m,m),(m,n)"
            case "r":
                self.gufunc_signature = "(m,n)->(m,n)"
            case "raw":
                self.gufunc_signature = "(m,n)->(n,m),(k),(m,n)"
            case _:
                raise ValueError(
                    f"Invalid mode '{mode}'. Supported modes are 'full', 'economic', 'r', and 'raw'."
                )

        if pivoting:
            self.gufunc_signature += ",(n)"

    def make_node(self, x):
        x = as_tensor_variable(x)

        assert x.ndim == 2, "The input of qr function should be a matrix."

        # Preserve static shape information if possible
        M, N = x.type.shape
        if M is not None and N is not None:
            K = min(M, N)
        else:
            K = None

        in_dtype = x.type.numpy_dtype
        out_dtype = np.dtype(f"f{in_dtype.itemsize}")

        match self.mode:
            case "full":
                outputs = [
                    tensor(shape=(M, M), dtype=out_dtype),
                    tensor(shape=(M, N), dtype=out_dtype),
                ]
            case "economic":
                outputs = [
                    tensor(shape=(M, K), dtype=out_dtype),
                    tensor(shape=(K, N), dtype=out_dtype),
                ]
            case "r":
                outputs = [
                    tensor(shape=(M, N), dtype=out_dtype),
                ]
            case "raw":
                outputs = [
                    tensor(shape=(M, M), dtype=out_dtype),
                    tensor(shape=(K,), dtype=out_dtype),
                    tensor(shape=(M, N), dtype=out_dtype),
                ]
            case _:
                raise NotImplementedError

        if self.pivoting:
            outputs = [*outputs, tensor(shape=(N,), dtype="int32")]

        return Apply(self, [x], outputs)

    def infer_shape(self, fgraph, node, shapes):
        (x_shape,) = shapes
        M, N = x_shape
        K = ptm.minimum(M, N)

        Q_shape = None
        R_shape = None
        tau_shape = None
        P_shape = None

        match self.mode:
            case "full":
                Q_shape = (M, M)
                R_shape = (M, N)
            case "economic":
                Q_shape = (M, K)
                R_shape = (K, N)
            case "r":
                R_shape = (M, N)
            case "raw":
                Q_shape = (M, M)  # Actually this is H in this case
                tau_shape = (K,)
                R_shape = (M, N)

        if self.pivoting:
            P_shape = (N,)

        return [
            shape
            for shape in (Q_shape, tau_shape, R_shape, P_shape)
            if shape is not None
        ]

    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
        if not allowed_inplace_inputs:
            return self
        new_props = self._props_dict()  # type: ignore
        new_props["overwrite_a"] = True
        return type(self)(**new_props)

    def _call_and_get_lwork(self, fn, *args, lwork, **kwargs):
        if lwork in [-1, None]:
            *_, work, info = fn(*args, lwork=-1, **kwargs)
            lwork = work.item()

        return fn(*args, lwork=lwork, **kwargs)

    def perform(self, node, inputs, outputs):
        (x,) = inputs
        M, N = x.shape

        if self.pivoting:
            (geqp3,) = get_lapack_funcs(("geqp3",), (x,))
            qr, jpvt, tau, *work_info = self._call_and_get_lwork(
                geqp3, x, lwork=-1, overwrite_a=self.overwrite_a
            )
            jpvt -= 1  # geqp3 returns a 1-based index array, so subtract 1
        else:
            (geqrf,) = get_lapack_funcs(("geqrf",), (x,))
            qr, tau, *work_info = self._call_and_get_lwork(
                geqrf, x, lwork=-1, overwrite_a=self.overwrite_a
            )

        if self.mode not in ["economic", "raw"] or M < N:
            R = np.triu(qr)
        else:
            R = np.triu(qr[:N, :])

        if self.mode == "r" and self.pivoting:
            outputs[0][0] = R
            outputs[1][0] = jpvt
            return

        elif self.mode == "r":
            outputs[0][0] = R
            return

        elif self.mode == "raw" and self.pivoting:
            outputs[0][0] = qr
            outputs[1][0] = tau
            outputs[2][0] = R
            outputs[3][0] = jpvt
            return

        elif self.mode == "raw":
            outputs[0][0] = qr
            outputs[1][0] = tau
            outputs[2][0] = R
            return

        (gor_un_gqr,) = get_lapack_funcs(("orgqr",), (qr,))

        if M < N:
            Q, work, info = self._call_and_get_lwork(
                gor_un_gqr, qr[:, :M], tau, lwork=-1, overwrite_a=1
            )
        elif self.mode == "economic":
            Q, work, info = self._call_and_get_lwork(
                gor_un_gqr, qr, tau, lwork=-1, overwrite_a=1
            )
        else:
            t = qr.dtype.char
            qqr = np.empty((M, M), dtype=t)
            qqr[:, :N] = qr

            # Always overwrite qqr -- it's a meaningless intermediate value
            Q, work, info = self._call_and_get_lwork(
                gor_un_gqr, qqr, tau, lwork=-1, overwrite_a=1
            )

        outputs[0][0] = Q
        outputs[1][0] = R

        if self.pivoting:
            outputs[2][0] = jpvt

    def L_op(self, inputs, outputs, output_grads):
        """
        Reverse-mode gradient of the QR function.

        References
        ----------
        .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/
        .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2
        """
        from pytensor.tensor.slinalg import solve_triangular

        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
        m, n = A.shape

        # Check if we have static shape info, if so we can get a better graph
        # (avoiding the ifelse Op in the output)
        M_static, N_static = A.type.shape
        shapes_unknown = M_static is None or N_static is None

        def _H(x: ptb.TensorVariable):
            return x.conj().mT

        def _copyltu(x: ptb.TensorVariable):
            return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1))

        if self.mode == "raw":
            raise NotImplementedError("Gradient of qr not implemented for mode=raw")

        elif self.mode == "r":
            k = pt.minimum(m, n)

            # We need all the components of the QR to compute the gradient of A even if we only
            # use the upper triangular component in the cost function.
            props_dict = self._props_dict()
            props_dict["mode"] = "economic"
            props_dict["pivoting"] = False

            qr_op = type(self)(**props_dict)

            Q, R = qr_op(A)
            dQ = Q.zeros_like()

            # Unlike numpy.linalg.qr, scipy.linalg.qr returns the full (m,n) matrix when mode='r', *not* the (k,n)
            # matrix that is computed by mode='economic'. The gradient assumes that dR is of shape (k,n), so we need to
            # slice it to the first k rows. Note that if m <= n, then k = m, so this is safe in all cases.
            dR = cast(ptb.TensorVariable, output_grads[0][:k, :])

        else:
            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
            if self.mode == "full":
                qr_assert_op = Assert(
                    "Gradient of qr not implemented for m x n matrices with m > n and mode=full"
                )
                R = qr_assert_op(R, ptm.le(m, n))

            new_output_grads = []
            is_disconnected = [
                isinstance(x.type, DisconnectedType) for x in output_grads
            ]
            if all(is_disconnected):
                # This should never be reached by Pytensor
                return [DisconnectedType()()]  # pragma: no cover

            for disconnected, output_grad, output in zip(
                is_disconnected, output_grads, [Q, R], strict=True
            ):
                if disconnected:
                    new_output_grads.append(output.zeros_like())
                else:
                    new_output_grads.append(output_grad)

            (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)

        if shapes_unknown or M_static >= N_static:
            # gradient expression when m >= n
            M = R @ _H(dR) - _H(dQ) @ Q
            K = dQ + Q @ _copyltu(M)
            A_bar_m_ge_n = _H(solve_triangular(R, _H(K)))

            if not shapes_unknown:
                return [A_bar_m_ge_n]

        # We have to trigger both branches if shapes_unknown is True, so this is purposefully not an elif branch
        if shapes_unknown or M_static < N_static:
            # gradient expression when m < n
            Y = A[:, m:]
            U = R[:, :m]
            dU, dV = dR[:, :m], dR[:, m:]
            dQ_Yt_dV = dQ + Y @ _H(dV)
            M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q
            X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M))))
            Y_bar = Q @ dV
            A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1)

            if not shapes_unknown:
                return [A_bar_m_lt_n]

        return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]


def qr(
    A: TensorLike,
    mode: Literal["full", "r", "economic", "raw", "complete", "reduced"] = "full",
    overwrite_a: bool = False,
    pivoting: bool = False,
    lwork: int | None = None,
):
    """
    QR Decomposition of input matrix `a`.

    The QR decomposition of a matrix `A` is a factorization of the form :math:`A = QR`, where `Q` is an orthogonal
    matrix (:math:`Q Q^T = I`) and `R` is an upper triangular matrix.

    This decomposition is useful in various numerical methods, including solving linear systems and least squares
    problems.

    Parameters
    ----------
    A: TensorLike
        Input matrix of shape (M, N) to be decomposed.

    mode: str, one of "full", "economic", "r", or "raw"
        How the QR decomposition is computed and returned. Choosing the mode can avoid unnecessary computations,
        depending on which of the return matrices are needed. Choices are:

        - "full" (or "complete"): returns `Q` and `R` with dimensions `(M, M)` and `(M, N)`.
        - "economic" (or "reduced"): returns `Q` and `R` with dimensions `(M, K)` and `(K, N)`,
          where `K = min(M, N)`.
        - "r": returns only `R` with dimensions `(K, N)`.
        - "raw": returns `H` and `tau` with dimensions `(N, M)` and `(K,)`, where `H` is the matrix of
          Householder reflections, and tau is the vector of Householder coefficients.

    pivoting: bool, default False
        If True, also return a vector of rank-revealing permutations `P` such that `A[:, P] = QR`.

    overwrite_a: bool, ignored
        Ignored. Included only for consistency with the function signature of `scipy.linalg.qr`. Pytensor will always
        automatically overwrite the input matrix `A` if it is safe to do so.

    lwork: int, ignored
        Ignored. Included only for consistency with the function signature of `scipy.linalg.qr`. Pytensor will
        automatically determine the optimal workspace size for the QR decomposition.

    Returns
    -------
    Q or H: TensorVariable, optional
        A matrix with orthonormal columns. When mode = 'complete', the result is an orthogonal/unitary matrix
        depending on whether a is real/complex. The determinant may be either +/- 1 in that case. If
        mode = 'raw', it is the matrix of Householder reflections. If mode = 'r', Q is not returned.

    R or tau : TensorVariable, optional
        Upper-triangular matrix. If mode = 'raw', it is the vector of Householder coefficients.
    """
    # backwards compatibility from the numpy API
    if mode == "complete":
        mode = "full"
    elif mode == "reduced":
        mode = "economic"

    return Blockwise(QR(mode=mode, pivoting=pivoting, overwrite_a=False))(A)
```

And `"qr"` is added to the module's `__all__`:

```diff
@@ -1728,4 +2101,5 @@ __all__ = [
     "lu",
     "lu_factor",
     "lu_solve",
+    "qr",
 ]
```
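For reference, the m ≥ n branch of `L_op` implements the standard QR pullback from refs. [1] and [2]. Transcribed by hand from the code above (not stated in this form in the diff), with bars denoting output cotangents:

```latex
\begin{aligned}
M       &= R\,\bar{R}^{\mathsf{H}} - \bar{Q}^{\mathsf{H}} Q \\
\bar{A} &= \bigl(\bar{Q} + Q\,\operatorname{copyltu}(M)\bigr)\,R^{-\mathsf{H}},
\qquad \operatorname{copyltu}(X) = \operatorname{tril}(X) + \operatorname{tril}(X,-1)^{\mathsf{H}}
\end{aligned}
```

The code never forms :math:`R^{-\mathsf{H}}` explicitly; it obtains :math:`\bar{A}` through `solve_triangular`.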
tests/link/jax/test_nlinalg.py

```diff
@@ -29,12 +29,6 @@ def test_jax_basic_multiout():
     outs = pt_nlinalg.eigh(x)
     compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)

-    outs = pt_nlinalg.qr(x, mode="full")
-    compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
-
-    outs = pt_nlinalg.qr(x, mode="reduced")
-    compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
-
     outs = pt_nlinalg.svd(x)
     compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
```
tests/link/jax/test_slinalg.py

```diff
@@ -103,6 +103,18 @@ def test_jax_basic():
         ],
     )

+    def assert_fn(x, y):
+        np.testing.assert_allclose(x.astype(config.floatX), y, rtol=1e-3)
+
+    M = rng.normal(size=(3, 3))
+    X = M.dot(M.T)
+
+    outs = pt_slinalg.qr(x, mode="full")
+    compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
+
+    outs = pt_slinalg.qr(x, mode="economic")
+    compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
+

 def test_jax_solve():
     rng = np.random.default_rng(utt.fetch_seed())
```
tests/link/numba/test_nlinalg.py

```diff
@@ -186,60 +186,6 @@ def test_matrix_inverses(op, x, exc, op_args):
         )


-@pytest.mark.parametrize(
-    "x, mode, exc",
-    [
-        (
-            (
-                pt.dmatrix(),
-                (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
-            ),
-            "reduced",
-            None,
-        ),
-        (
-            (
-                pt.dmatrix(),
-                (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
-            ),
-            "r",
-            None,
-        ),
-        (
-            (
-                pt.lmatrix(),
-                (lambda x: x.T.dot(x))(rng.integers(1, 10, size=(3, 3)).astype("int64")),
-            ),
-            "reduced",
-            None,
-        ),
-        (
-            (
-                pt.lmatrix(),
-                (lambda x: x.T.dot(x))(rng.integers(1, 10, size=(3, 3)).astype("int64")),
-            ),
-            "complete",
-            UserWarning,
-        ),
-    ],
-)
-def test_QRFull(x, mode, exc):
-    x, test_x = x
-    g = nlinalg.QRFull(mode)(x)
-
-    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
-
-    with cm:
-        compare_numba_and_py(
-            [x],
-            g,
-            [test_x],
-        )
-
-
 @pytest.mark.parametrize(
     "x, full_matrices, compute_uv, exc",
     [
```
tests/link/numba/test_slinalg.py
浏览文件 @
617964ff
...
@@ -10,6 +10,7 @@ import pytensor.tensor as pt
from pytensor import In, config
from pytensor.tensor.slinalg import (
    LU,
    QR,
    Cholesky,
    CholeskySolve,
    LUFactor,
...
@@ -720,3 +721,70 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo
    # Can never destroy non-contiguous inputs
    np.testing.assert_allclose(b_val_not_contig, b_val)


@pytest.mark.parametrize(
    "mode, pivoting",
    [("economic", False), ("full", True), ("r", False), ("raw", True)],
    ids=["economic", "full_pivot", "r", "raw_pivot"],
)
@pytest.mark.parametrize("overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"])
def test_qr(mode, pivoting, overwrite_a):
    shape = (5, 5)
    rng = np.random.default_rng()
    A = pt.tensor(
        "A",
        shape=shape,
        dtype=config.floatX,
    )
    A_val = rng.normal(size=shape).astype(config.floatX)

    qr_outputs = pt.linalg.qr(A, mode=mode, pivoting=pivoting)

    fn, res = compare_numba_and_py(
        [In(A, mutable=overwrite_a)],
        qr_outputs,
        [A_val],
        numba_mode=numba_inplace_mode,
        inplace=True,
    )

    op = fn.maker.fgraph.outputs[0].owner.op
    assert isinstance(op, QR)

    destroy_map = op.destroy_map

    if overwrite_a:
        assert destroy_map == {0: [0]}
    else:
        assert destroy_map == {}

    # Test F-contiguous input
    val_f_contig = np.copy(A_val, order="F")
    res_f_contig = fn(val_f_contig)

    for x, x_f_contig in zip(res, res_f_contig, strict=True):
        np.testing.assert_allclose(x, x_f_contig)

    # Should always be destroyable
    assert (A_val == val_f_contig).all() == (not overwrite_a)

    # Test C-contiguous input
    val_c_contig = np.copy(A_val, order="C")
    res_c_contig = fn(val_c_contig)
    for x, x_c_contig in zip(res, res_c_contig, strict=True):
        np.testing.assert_allclose(x, x_c_contig)

    # Cannot destroy C-contiguous input
    np.testing.assert_allclose(val_c_contig, A_val)

    # Test non-contiguous input
    val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
    res_not_contig = fn(val_not_contig)
    for x, x_not_contig in zip(res, res_not_contig, strict=True):
        np.testing.assert_allclose(x, x_not_contig)

    # Cannot destroy non-contiguous input
    np.testing.assert_allclose(val_not_contig, A_val)
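The contiguity cases above follow from how the Numba backend hands buffers to LAPACK: the QR routines expect Fortran-ordered memory, so only an F-contiguous input can be overwritten in place; C-ordered and strided inputs get copied first and therefore survive intact. A small illustration of the distinction (my gloss on the test, not code from the diff):

import numpy as np

A = np.random.default_rng(0).normal(size=(5, 5))
A_f = np.copy(A, order="F")
A_c = np.copy(A, order="C")

# Only the Fortran-ordered copy can be handed straight to LAPACK,
# so only it is a candidate for in-place destruction under overwrite_a=True
assert A_f.flags.f_contiguous
assert not A_c.flags.f_contiguous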
tests/link/pytorch/conftest.py
0 → 100644
View file @ 617964ff
import numpy as np
import pytest

from pytensor import config
from pytensor.tensor.type import matrix


@pytest.fixture
def matrix_test():
    rng = np.random.default_rng(213234)

    M = rng.normal(size=(3, 3))
    test_value = M.dot(M.T).astype(config.floatX)

    x = matrix("x")
    return x, test_value
tests/link/pytorch/test_nlinalg.py
View file @ 617964ff
...
@@ -8,17 +8,6 @@ from pytensor.tensor.type import matrix
from tests.link.pytorch.test_basic import compare_pytorch_and_py


@pytest.fixture
def matrix_test():
    rng = np.random.default_rng(213234)

    M = rng.normal(size=(3, 3))
    test_value = M.dot(M.T).astype(config.floatX)

    x = matrix("x")
    return (x, test_value)


@pytest.mark.parametrize(
    "func",
    (pt_nla.eig, pt_nla.eigh, pt_nla.SLogDet(), pt_nla.inv, pt_nla.det),
...
@@ -34,22 +23,6 @@ def test_lin_alg_no_params(func, matrix_test):
    compare_pytorch_and_py([x], outs, [test_value], assert_fn=assert_fn)


@pytest.mark.parametrize(
    "mode",
    (
        "complete",
        "reduced",
        "r",
        pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)),
    ),
)
def test_qr(mode, matrix_test):
    x, test_value = matrix_test
    outs = pt_nla.qr(x, mode=mode)
    compare_pytorch_and_py([x], outs, [test_value])


@pytest.mark.parametrize("compute_uv", [True, False])
@pytest.mark.parametrize("full_matrices", [True, False])
def test_svd(compute_uv, full_matrices, matrix_test):
...
tests/link/pytorch/test_slinalg.py
0 → 100644
View file @ 617964ff
import pytest

import pytensor
from tests.link.pytorch.test_basic import compare_pytorch_and_py


@pytest.mark.parametrize(
    "mode",
    (
        "complete",
        "reduced",
        "r",
        pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)),
    ),
)
def test_qr(mode, matrix_test):
    x, test_value = matrix_test
    outs = pytensor.tensor.slinalg.qr(x, mode=mode)
    compare_pytorch_and_py([x], outs, [test_value])
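The xfail on "raw" reflects the backend API: torch.linalg.qr implements only the "reduced", "complete", and "r" modes, so the torch dispatch has nothing to map "raw" onto and raises NotImplementedError. A quick sketch of the underlying call (assuming torch is installed; not part of the diff):

import torch

A = torch.randn(3, 3)
Q, R = torch.linalg.qr(A, mode="reduced")  # "complete" and "r" also work; "raw" does not exist
torch.testing.assert_close(Q @ R, A)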
tests/tensor/test_nlinalg.py
View file @ 617964ff
from functools import partial

import numpy as np
import numpy.linalg
import pytest
from numpy.testing import assert_array_almost_equal
...
@@ -25,7 +24,6 @@ from pytensor.tensor.nlinalg import (
    matrix_power,
    norm,
    pinv,
    qr,
    slogdet,
    svd,
    tensorinv,
...
@@ -122,102 +120,6 @@ def test_matrix_dot():
    assert _allclose(numpy_sol, pytensor_sol)


def test_qr_modes():
    rng = np.random.default_rng(utt.fetch_seed())

    A = matrix("A", dtype=config.floatX)
    a = rng.random((4, 4)).astype(config.floatX)

    f = function([A], qr(A))
    t_qr = f(a)
    n_qr = np.linalg.qr(a)
    assert _allclose(n_qr, t_qr)

    for mode in ["reduced", "r", "raw"]:
        f = function([A], qr(A, mode))
        t_qr = f(a)
        n_qr = np.linalg.qr(a, mode)
        if isinstance(n_qr, list | tuple):
            assert _allclose(n_qr[0], t_qr[0])
            assert _allclose(n_qr[1], t_qr[1])
        else:
            assert _allclose(n_qr, t_qr)

    try:
        n_qr = np.linalg.qr(a, "complete")
        f = function([A], qr(A, "complete"))
        t_qr = f(a)
        assert _allclose(n_qr, t_qr)
    except TypeError as e:
        assert "name 'complete' is not defined" in str(e)


@pytest.mark.parametrize(
    "shape, gradient_test_case, mode",
    (
        [(s, c, "reduced") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
        + [(s, c, "complete") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
        + [(s, 0, "r") for s in [(3, 3), (6, 3), (3, 6)]]
        + [((3, 3), 0, "raw")]
    ),
    ids=(
        [
            f"shape={s}, gradient_test_case={c}, mode=reduced"
            for s in [(3, 3), (6, 3), (3, 6)]
            for c in ["Q", "R", "both"]
        ]
        + [
            f"shape={s}, gradient_test_case={c}, mode=complete"
            for s in [(3, 3), (6, 3), (3, 6)]
            for c in ["Q", "R", "both"]
        ]
        + [f"shape={s}, gradient_test_case=R, mode=r" for s in [(3, 3), (6, 3), (3, 6)]]
        + ["shape=(3, 3), gradient_test_case=Q, mode=raw"]
    ),
)
@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
def test_qr_grad(shape, gradient_test_case, mode, is_complex):
    rng = np.random.default_rng(utt.fetch_seed())

    def _test_fn(x, case=2, mode="reduced"):
        if case == 0:
            return qr(x, mode=mode)[0].sum()
        elif case == 1:
            return qr(x, mode=mode)[1].sum()
        elif case == 2:
            Q, R = qr(x, mode=mode)
            return Q.sum() + R.sum()

    if is_complex:
        pytest.xfail("Complex inputs currently not supported by verify_grad")

    m, n = shape
    a = rng.standard_normal(shape).astype(config.floatX)
    if is_complex:
        a += 1j * rng.standard_normal(shape).astype(config.floatX)

    if mode == "raw":
        with pytest.raises(NotImplementedError):
            utt.verify_grad(
                partial(_test_fn, case=gradient_test_case, mode=mode),
                [a],
                rng=np.random,
            )

    elif mode == "complete" and m > n:
        with pytest.raises(AssertionError):
            utt.verify_grad(
                partial(_test_fn, case=gradient_test_case, mode=mode),
                [a],
                rng=np.random,
            )

    else:
        utt.verify_grad(
            partial(_test_fn, case=gradient_test_case, mode=mode), [a], rng=np.random
        )


class TestSvd(utt.InferShapeTester):
    op_class = SVD
...
tests/tensor/test_slinalg.py
View file @ 617964ff
import functools
import itertools
from functools import partial
from typing import Literal

import numpy as np
import pytest
import scipy
from scipy import linalg as scipy_linalg

from pytensor import function, grad
from pytensor import tensor as pt
...
@@ -26,6 +28,7 @@ from pytensor.tensor.slinalg import (
    lu_factor,
    lu_solve,
    pivot_to_permutation,
    qr,
    solve,
    solve_continuous_lyapunov,
    solve_discrete_are,
...
@@ -1088,3 +1091,104 @@ def test_block_diagonal_blockwise():
    B = np.random.normal(size=(1, batch_size, 4, 4)).astype(config.floatX)

    result = block_diag(A, B).eval()
    assert result.shape == (10, batch_size, 6, 6)
@pytest.mark.parametrize(
    "mode, names",
    [
        ("economic", ["Q", "R"]),
        ("full", ["Q", "R"]),
        ("r", ["R"]),
        ("raw", ["H", "tau", "R"]),
    ],
)
@pytest.mark.parametrize("pivoting", [True, False])
def test_qr_modes(mode, names, pivoting):
    rng = np.random.default_rng(utt.fetch_seed())
    A_val = rng.random((4, 4)).astype(config.floatX)

    if pivoting:
        names = [*names, "pivots"]

    A = tensor("A", dtype=config.floatX, shape=(None, None))

    f = function([A], qr(A, mode=mode, pivoting=pivoting))
    outputs_pt = f(A_val)
    outputs_sp = scipy_linalg.qr(A_val, mode=mode, pivoting=pivoting)

    if mode == "raw":
        # The first output of scipy's qr is a tuple when mode is raw; flatten it for easier iteration
        outputs_sp = (*outputs_sp[0], *outputs_sp[1:])
    elif mode == "r" and not pivoting:
        # Here there's only one output from the pytensor function; wrap it in a list for iteration
        outputs_pt = [outputs_pt]

    for out_pt, out_sp, name in zip(outputs_pt, outputs_sp, names):
        np.testing.assert_allclose(out_pt, out_sp, err_msg=f"{name} disagrees")
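For reference, the raw-mode flattening above exists because scipy returns the Householder data as a nested tuple. A small sketch of the convention (my illustration, reusing the scipy_linalg import from this file):

import numpy as np

(H, tau), R = scipy_linalg.qr(np.eye(4), mode="raw")
# H packs the Householder reflectors, tau their scaling coefficients;
# (*outputs_sp[0], *outputs_sp[1:]) turns ((H, tau), R) into (H, tau, R)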
@pytest.mark.parametrize(
    "shape, gradient_test_case, mode",
    (
        [(s, c, "economic") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
        + [(s, c, "full") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
        + [(s, 0, "r") for s in [(3, 3), (6, 3), (3, 6)]]
        + [((3, 3), 0, "raw")]
    ),
    ids=(
        [
            f"shape={s}, gradient_test_case={c}, mode=economic"
            for s in [(3, 3), (6, 3), (3, 6)]
            for c in ["Q", "R", "both"]
        ]
        + [
            f"shape={s}, gradient_test_case={c}, mode=full"
            for s in [(3, 3), (6, 3), (3, 6)]
            for c in ["Q", "R", "both"]
        ]
        + [f"shape={s}, gradient_test_case=R, mode=r" for s in [(3, 3), (6, 3), (3, 6)]]
        + ["shape=(3, 3), gradient_test_case=Q, mode=raw"]
    ),
)
@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
def test_qr_grad(shape, gradient_test_case, mode, is_complex):
    rng = np.random.default_rng(utt.fetch_seed())

    def _test_fn(x, case=2, mode="reduced"):
        if case == 0:
            return qr(x, mode=mode)[0].sum()
        elif case == 1:
            return qr(x, mode=mode)[1].sum()
        elif case == 2:
            Q, R = qr(x, mode=mode)
            return Q.sum() + R.sum()

    if is_complex:
        pytest.xfail("Complex inputs currently not supported by verify_grad")

    m, n = shape
    a = rng.standard_normal(shape).astype(config.floatX)
    if is_complex:
        a += 1j * rng.standard_normal(shape).astype(config.floatX)

    if mode == "raw":
        with pytest.raises(NotImplementedError):
            utt.verify_grad(
                partial(_test_fn, case=gradient_test_case, mode=mode),
                [a],
                rng=np.random,
            )

    elif mode == "full" and m > n:
        with pytest.raises(AssertionError):
            utt.verify_grad(
                partial(_test_fn, case=gradient_test_case, mode=mode),
                [a],
                rng=np.random,
            )

    else:
        utt.verify_grad(
            partial(_test_fn, case=gradient_test_case, mode=mode), [a], rng=np.random
        )
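A note on the two expected-failure branches (my reading of the test, not text from the diff): "raw" exposes the Householder representation, for which no gradient is implemented, hence NotImplementedError; and for mode="full" with m > n, the trailing m - n columns of Q are pinned down only by orthogonality (Q.T @ Q = I), not by A itself, so verify_grad's finite differences cannot agree with any particular analytic choice and the check fails with an AssertionError. A quick sanity check of that null-space property:

import numpy as np

A = np.random.default_rng(0).normal(size=(6, 3))
Q_full, R_full = scipy_linalg.qr(A, mode="full")
# The trailing columns of Q are orthogonal to the column space of A,
# so any orthonormal completion would serve equally well
np.testing.assert_allclose(Q_full[:, 3:].T @ A, 0.0, atol=1e-12)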