Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
617964ff
Unverified
提交
617964ff
authored
7月 15, 2025
作者:
Jesse Grabowski
提交者:
GitHub
7月 15, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor and update QR Op (#1518)
* Refactor QR * Update JAX QR dispatch * Update Torch QR dispatch * Update numba QR dispatch
上级
5024d54e
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
20 个修改的文件
包含
446 行增加
和
421 行删除
+446
-421
nlinalg.py
pytensor/link/jax/dispatch/nlinalg.py
+0
-11
slinalg.py
pytensor/link/jax/dispatch/slinalg.py
+11
-0
_LAPACK.py
pytensor/link/numba/dispatch/linalg/_LAPACK.py
+87
-1
qr.py
pytensor/link/numba/dispatch/linalg/decomposition/qr.py
+0
-0
nlinalg.py
pytensor/link/numba/dispatch/nlinalg.py
+0
-36
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+104
-1
__init__.py
pytensor/link/pytorch/dispatch/__init__.py
+1
-0
nlinalg.py
pytensor/link/pytorch/dispatch/nlinalg.py
+0
-16
slinalg.py
pytensor/link/pytorch/dispatch/slinalg.py
+23
-0
nlinalg.py
pytensor/tensor/nlinalg.py
+0
-171
slinalg.py
pytensor/tensor/slinalg.py
+0
-0
test_nlinalg.py
tests/link/jax/test_nlinalg.py
+0
-6
test_slinalg.py
tests/link/jax/test_slinalg.py
+12
-0
test_nlinalg.py
tests/link/numba/test_nlinalg.py
+0
-54
test_slinalg.py
tests/link/numba/test_slinalg.py
+68
-0
conftest.py
tests/link/pytorch/conftest.py
+16
-0
test_nlinalg.py
tests/link/pytorch/test_nlinalg.py
+0
-27
test_slinalg.py
tests/link/pytorch/test_slinalg.py
+20
-0
test_nlinalg.py
tests/tensor/test_nlinalg.py
+0
-98
test_slinalg.py
tests/tensor/test_slinalg.py
+104
-0
没有找到文件。
pytensor/link/jax/dispatch/nlinalg.py
浏览文件 @
617964ff
...
@@ -9,7 +9,6 @@ from pytensor.tensor.nlinalg import (
...
@@ -9,7 +9,6 @@ from pytensor.tensor.nlinalg import (
KroneckerProduct
,
KroneckerProduct
,
MatrixInverse
,
MatrixInverse
,
MatrixPinv
,
MatrixPinv
,
QRFull
,
SLogDet
,
SLogDet
,
)
)
...
@@ -67,16 +66,6 @@ def jax_funcify_MatrixInverse(op, **kwargs):
...
@@ -67,16 +66,6 @@ def jax_funcify_MatrixInverse(op, **kwargs):
return
matrix_inverse
return
matrix_inverse
@jax_funcify.register
(
QRFull
)
def
jax_funcify_QRFull
(
op
,
**
kwargs
):
mode
=
op
.
mode
def
qr_full
(
x
,
mode
=
mode
):
return
jnp
.
linalg
.
qr
(
x
,
mode
=
mode
)
return
qr_full
@jax_funcify.register
(
MatrixPinv
)
@jax_funcify.register
(
MatrixPinv
)
def
jax_funcify_Pinv
(
op
,
**
kwargs
):
def
jax_funcify_Pinv
(
op
,
**
kwargs
):
def
pinv
(
x
):
def
pinv
(
x
):
...
...
pytensor/link/jax/dispatch/slinalg.py
浏览文件 @
617964ff
...
@@ -5,6 +5,7 @@ import jax
...
@@ -5,6 +5,7 @@ import jax
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
from
pytensor.tensor.slinalg
import
(
from
pytensor.tensor.slinalg
import
(
LU
,
LU
,
QR
,
BlockDiagonal
,
BlockDiagonal
,
Cholesky
,
Cholesky
,
CholeskySolve
,
CholeskySolve
,
...
@@ -168,3 +169,13 @@ def jax_funcify_ChoSolve(op, **kwargs):
...
@@ -168,3 +169,13 @@ def jax_funcify_ChoSolve(op, **kwargs):
)
)
return
cho_solve
return
cho_solve
@jax_funcify.register
(
QR
)
def
jax_funcify_QR
(
op
,
**
kwargs
):
mode
=
op
.
mode
def
qr
(
x
,
mode
=
mode
):
return
jax
.
scipy
.
linalg
.
qr
(
x
,
mode
=
mode
)
return
qr
pytensor/link/numba/dispatch/linalg/_LAPACK.py
浏览文件 @
617964ff
...
@@ -283,7 +283,6 @@ class _LAPACK:
...
@@ -283,7 +283,6 @@ class _LAPACK:
Called by scipy.linalg.lu_solve
Called by scipy.linalg.lu_solve
"""
"""
...
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"getrs"
)
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"getrs"
)
functype
=
ctypes
.
CFUNCTYPE
(
functype
=
ctypes
.
CFUNCTYPE
(
None
,
None
,
...
@@ -457,3 +456,90 @@ class _LAPACK:
...
@@ -457,3 +456,90 @@ class _LAPACK:
_ptr_int
,
# INFO
_ptr_int
,
# INFO
)
)
return
functype
(
lapack_ptr
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xgeqrf
(
cls
,
dtype
):
"""
Compute the QR factorization of a general M-by-N matrix A.
Used in QR decomposition (no pivoting).
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"geqrf"
)
functype
=
ctypes
.
CFUNCTYPE
(
None
,
_ptr_int
,
# M
_ptr_int
,
# N
float_pointer
,
# A
_ptr_int
,
# LDA
float_pointer
,
# TAU
float_pointer
,
# WORK
_ptr_int
,
# LWORK
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xgeqp3
(
cls
,
dtype
):
"""
Compute the QR factorization with column pivoting of a general M-by-N matrix A.
Used in QR decomposition with pivoting.
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"geqp3"
)
functype
=
ctypes
.
CFUNCTYPE
(
None
,
_ptr_int
,
# M
_ptr_int
,
# N
float_pointer
,
# A
_ptr_int
,
# LDA
_ptr_int
,
# JPVT
float_pointer
,
# TAU
float_pointer
,
# WORK
_ptr_int
,
# LWORK
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xorgqr
(
cls
,
dtype
):
"""
Generate the orthogonal matrix Q from a QR factorization (real types).
Used in QR decomposition to form Q.
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"orgqr"
)
functype
=
ctypes
.
CFUNCTYPE
(
None
,
_ptr_int
,
# M
_ptr_int
,
# N
_ptr_int
,
# K
float_pointer
,
# A
_ptr_int
,
# LDA
float_pointer
,
# TAU
float_pointer
,
# WORK
_ptr_int
,
# LWORK
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
@classmethod
def
numba_xungqr
(
cls
,
dtype
):
"""
Generate the unitary matrix Q from a QR factorization (complex types).
Used in QR decomposition to form Q for complex types.
"""
lapack_ptr
,
float_pointer
=
_get_lapack_ptr_and_ptr_type
(
dtype
,
"ungqr"
)
functype
=
ctypes
.
CFUNCTYPE
(
None
,
_ptr_int
,
# M
_ptr_int
,
# N
_ptr_int
,
# K
float_pointer
,
# A
_ptr_int
,
# LDA
float_pointer
,
# TAU
float_pointer
,
# WORK
_ptr_int
,
# LWORK
_ptr_int
,
# INFO
)
return
functype
(
lapack_ptr
)
pytensor/link/numba/dispatch/linalg/decomposition/qr.py
0 → 100644
浏览文件 @
617964ff
差异被折叠。
点击展开。
pytensor/link/numba/dispatch/nlinalg.py
浏览文件 @
617964ff
...
@@ -16,7 +16,6 @@ from pytensor.tensor.nlinalg import (
...
@@ -16,7 +16,6 @@ from pytensor.tensor.nlinalg import (
Eigh
,
Eigh
,
MatrixInverse
,
MatrixInverse
,
MatrixPinv
,
MatrixPinv
,
QRFull
,
SLogDet
,
SLogDet
,
)
)
...
@@ -146,38 +145,3 @@ def numba_funcify_MatrixPinv(op, node, **kwargs):
...
@@ -146,38 +145,3 @@ def numba_funcify_MatrixPinv(op, node, **kwargs):
return
np
.
linalg
.
pinv
(
inputs_cast
(
x
))
.
astype
(
out_dtype
)
return
np
.
linalg
.
pinv
(
inputs_cast
(
x
))
.
astype
(
out_dtype
)
return
matrixpinv
return
matrixpinv
@numba_funcify.register
(
QRFull
)
def
numba_funcify_QRFull
(
op
,
node
,
**
kwargs
):
mode
=
op
.
mode
if
mode
!=
"reduced"
:
warnings
.
warn
(
(
"Numba will use object mode to allow the "
"`mode` argument to `numpy.linalg.qr`."
),
UserWarning
,
)
if
len
(
node
.
outputs
)
>
1
:
ret_sig
=
numba
.
types
.
Tuple
([
get_numba_type
(
o
.
type
)
for
o
in
node
.
outputs
])
else
:
ret_sig
=
get_numba_type
(
node
.
outputs
[
0
]
.
type
)
@numba_basic.numba_njit
def
qr_full
(
x
):
with
numba
.
objmode
(
ret
=
ret_sig
):
ret
=
np
.
linalg
.
qr
(
x
,
mode
=
mode
)
return
ret
else
:
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
out_dtype
)
@numba_basic.numba_njit
(
inline
=
"always"
)
def
qr_full
(
x
):
return
np
.
linalg
.
qr
(
inputs_cast
(
x
))
return
qr_full
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
617964ff
...
@@ -2,6 +2,7 @@ import warnings
...
@@ -2,6 +2,7 @@ import warnings
import
numpy
as
np
import
numpy
as
np
from
pytensor
import
config
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
,
numba_njit
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
,
numba_njit
from
pytensor.link.numba.dispatch.linalg.decomposition.cholesky
import
_cholesky
from
pytensor.link.numba.dispatch.linalg.decomposition.cholesky
import
_cholesky
from
pytensor.link.numba.dispatch.linalg.decomposition.lu
import
(
from
pytensor.link.numba.dispatch.linalg.decomposition.lu
import
(
...
@@ -11,6 +12,14 @@ from pytensor.link.numba.dispatch.linalg.decomposition.lu import (
...
@@ -11,6 +12,14 @@ from pytensor.link.numba.dispatch.linalg.decomposition.lu import (
_pivot_to_permutation
,
_pivot_to_permutation
,
)
)
from
pytensor.link.numba.dispatch.linalg.decomposition.lu_factor
import
_lu_factor
from
pytensor.link.numba.dispatch.linalg.decomposition.lu_factor
import
_lu_factor
from
pytensor.link.numba.dispatch.linalg.decomposition.qr
import
(
_qr_full_no_pivot
,
_qr_full_pivot
,
_qr_r_no_pivot
,
_qr_r_pivot
,
_qr_raw_no_pivot
,
_qr_raw_pivot
,
)
from
pytensor.link.numba.dispatch.linalg.solve.cholesky
import
_cho_solve
from
pytensor.link.numba.dispatch.linalg.solve.cholesky
import
_cho_solve
from
pytensor.link.numba.dispatch.linalg.solve.general
import
_solve_gen
from
pytensor.link.numba.dispatch.linalg.solve.general
import
_solve_gen
from
pytensor.link.numba.dispatch.linalg.solve.posdef
import
_solve_psd
from
pytensor.link.numba.dispatch.linalg.solve.posdef
import
_solve_psd
...
@@ -19,6 +28,7 @@ from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangul
...
@@ -19,6 +28,7 @@ from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangul
from
pytensor.link.numba.dispatch.linalg.solve.tridiagonal
import
_solve_tridiagonal
from
pytensor.link.numba.dispatch.linalg.solve.tridiagonal
import
_solve_tridiagonal
from
pytensor.tensor.slinalg
import
(
from
pytensor.tensor.slinalg
import
(
LU
,
LU
,
QR
,
BlockDiagonal
,
BlockDiagonal
,
Cholesky
,
Cholesky
,
CholeskySolve
,
CholeskySolve
,
...
@@ -27,7 +37,7 @@ from pytensor.tensor.slinalg import (
...
@@ -27,7 +37,7 @@ from pytensor.tensor.slinalg import (
Solve
,
Solve
,
SolveTriangular
,
SolveTriangular
,
)
)
from
pytensor.tensor.type
import
complex_dtypes
from
pytensor.tensor.type
import
complex_dtypes
,
integer_dtypes
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
=
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
=
(
...
@@ -311,3 +321,96 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
...
@@ -311,3 +321,96 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
)
)
return
cho_solve
return
cho_solve
@numba_funcify.register
(
QR
)
def
numba_funcify_QR
(
op
,
node
,
**
kwargs
):
mode
=
op
.
mode
check_finite
=
op
.
check_finite
pivoting
=
op
.
pivoting
overwrite_a
=
op
.
overwrite_a
dtype
=
node
.
inputs
[
0
]
.
dtype
if
dtype
in
complex_dtypes
:
raise
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
integer_input
=
dtype
in
integer_dtypes
in_dtype
=
config
.
floatX
if
integer_input
else
dtype
@numba_njit
(
cache
=
False
)
def
qr
(
a
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
raise
np
.
linalg
.
LinAlgError
(
"Non-numeric values (nan or inf) found in input to qr"
)
if
integer_input
:
a
=
a
.
astype
(
in_dtype
)
if
(
mode
==
"full"
or
mode
==
"economic"
)
and
pivoting
:
Q
,
R
,
P
=
_qr_full_pivot
(
a
,
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
return
Q
,
R
,
P
elif
(
mode
==
"full"
or
mode
==
"economic"
)
and
not
pivoting
:
Q
,
R
=
_qr_full_no_pivot
(
a
,
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
return
Q
,
R
elif
mode
==
"r"
and
pivoting
:
R
,
P
=
_qr_r_pivot
(
a
,
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
return
R
,
P
elif
mode
==
"r"
and
not
pivoting
:
(
R
,)
=
_qr_r_no_pivot
(
a
,
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
return
R
elif
mode
==
"raw"
and
pivoting
:
H
,
tau
,
R
,
P
=
_qr_raw_pivot
(
a
,
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
return
H
,
tau
,
R
,
P
elif
mode
==
"raw"
and
not
pivoting
:
H
,
tau
,
R
=
_qr_raw_no_pivot
(
a
,
mode
=
mode
,
pivoting
=
pivoting
,
overwrite_a
=
overwrite_a
,
check_finite
=
check_finite
,
)
return
H
,
tau
,
R
else
:
raise
NotImplementedError
(
f
"QR mode={mode}, pivoting={pivoting} not supported in numba mode."
)
return
qr
pytensor/link/pytorch/dispatch/__init__.py
浏览文件 @
617964ff
...
@@ -8,6 +8,7 @@ import pytensor.link.pytorch.dispatch.elemwise
...
@@ -8,6 +8,7 @@ import pytensor.link.pytorch.dispatch.elemwise
import
pytensor.link.pytorch.dispatch.math
import
pytensor.link.pytorch.dispatch.math
import
pytensor.link.pytorch.dispatch.extra_ops
import
pytensor.link.pytorch.dispatch.extra_ops
import
pytensor.link.pytorch.dispatch.nlinalg
import
pytensor.link.pytorch.dispatch.nlinalg
import
pytensor.link.pytorch.dispatch.slinalg
import
pytensor.link.pytorch.dispatch.shape
import
pytensor.link.pytorch.dispatch.shape
import
pytensor.link.pytorch.dispatch.sort
import
pytensor.link.pytorch.dispatch.sort
import
pytensor.link.pytorch.dispatch.subtensor
import
pytensor.link.pytorch.dispatch.subtensor
...
...
pytensor/link/pytorch/dispatch/nlinalg.py
浏览文件 @
617964ff
...
@@ -9,7 +9,6 @@ from pytensor.tensor.nlinalg import (
...
@@ -9,7 +9,6 @@ from pytensor.tensor.nlinalg import (
KroneckerProduct
,
KroneckerProduct
,
MatrixInverse
,
MatrixInverse
,
MatrixPinv
,
MatrixPinv
,
QRFull
,
SLogDet
,
SLogDet
,
)
)
...
@@ -70,21 +69,6 @@ def pytorch_funcify_MatrixInverse(op, **kwargs):
...
@@ -70,21 +69,6 @@ def pytorch_funcify_MatrixInverse(op, **kwargs):
return
matrix_inverse
return
matrix_inverse
@pytorch_funcify.register
(
QRFull
)
def
pytorch_funcify_QRFull
(
op
,
**
kwargs
):
mode
=
op
.
mode
if
mode
==
"raw"
:
raise
NotImplementedError
(
"raw mode not implemented in PyTorch"
)
def
qr_full
(
x
):
Q
,
R
=
torch
.
linalg
.
qr
(
x
,
mode
=
mode
)
if
mode
==
"r"
:
return
R
return
Q
,
R
return
qr_full
@pytorch_funcify.register
(
MatrixPinv
)
@pytorch_funcify.register
(
MatrixPinv
)
def
pytorch_funcify_Pinv
(
op
,
**
kwargs
):
def
pytorch_funcify_Pinv
(
op
,
**
kwargs
):
hermitian
=
op
.
hermitian
hermitian
=
op
.
hermitian
...
...
pytensor/link/pytorch/dispatch/slinalg.py
0 → 100644
浏览文件 @
617964ff
import
torch
from
pytensor.link.pytorch.dispatch
import
pytorch_funcify
from
pytensor.tensor.slinalg
import
QR
@pytorch_funcify.register
(
QR
)
def
pytorch_funcify_QR
(
op
,
**
kwargs
):
mode
=
op
.
mode
if
mode
==
"raw"
:
raise
NotImplementedError
(
"raw mode not implemented in PyTorch"
)
elif
mode
==
"full"
:
mode
=
"complete"
elif
mode
==
"economic"
:
mode
=
"reduced"
def
qr
(
x
):
Q
,
R
=
torch
.
linalg
.
qr
(
x
,
mode
=
mode
)
if
mode
==
"r"
:
return
R
return
Q
,
R
return
qr
pytensor/tensor/nlinalg.py
浏览文件 @
617964ff
...
@@ -5,15 +5,12 @@ from typing import Literal, cast
...
@@ -5,15 +5,12 @@ from typing import Literal, cast
import
numpy
as
np
import
numpy
as
np
import
pytensor.tensor
as
pt
from
pytensor
import
scalar
as
ps
from
pytensor
import
scalar
as
ps
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.gradient
import
DisconnectedType
from
pytensor.gradient
import
DisconnectedType
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.ifelse
import
ifelse
from
pytensor.npy_2_compat
import
normalize_axis_tuple
from
pytensor.npy_2_compat
import
normalize_axis_tuple
from
pytensor.raise_op
import
Assert
from
pytensor.tensor
import
TensorLike
from
pytensor.tensor
import
TensorLike
from
pytensor.tensor
import
basic
as
ptb
from
pytensor.tensor
import
basic
as
ptb
from
pytensor.tensor
import
math
as
ptm
from
pytensor.tensor
import
math
as
ptm
...
@@ -468,173 +465,6 @@ def eigh(a, UPLO="L"):
...
@@ -468,173 +465,6 @@ def eigh(a, UPLO="L"):
return
Eigh
(
UPLO
)(
a
)
return
Eigh
(
UPLO
)(
a
)
class
QRFull
(
Op
):
"""
Full QR Decomposition.
Computes the QR decomposition of a matrix.
Factor the matrix a as qr, where q is orthonormal
and r is upper-triangular.
"""
__props__
=
(
"mode"
,)
def
__init__
(
self
,
mode
):
self
.
mode
=
mode
def
make_node
(
self
,
x
):
x
=
as_tensor_variable
(
x
)
assert
x
.
ndim
==
2
,
"The input of qr function should be a matrix."
in_dtype
=
x
.
type
.
numpy_dtype
out_dtype
=
np
.
dtype
(
f
"f{in_dtype.itemsize}"
)
q
=
matrix
(
dtype
=
out_dtype
)
if
self
.
mode
!=
"raw"
:
r
=
matrix
(
dtype
=
out_dtype
)
else
:
r
=
vector
(
dtype
=
out_dtype
)
if
self
.
mode
!=
"r"
:
q
=
matrix
(
dtype
=
out_dtype
)
outputs
=
[
q
,
r
]
else
:
outputs
=
[
r
]
return
Apply
(
self
,
[
x
],
outputs
)
def
perform
(
self
,
node
,
inputs
,
outputs
):
(
x
,)
=
inputs
assert
x
.
ndim
==
2
,
"The input of qr function should be a matrix."
res
=
np
.
linalg
.
qr
(
x
,
self
.
mode
)
if
self
.
mode
!=
"r"
:
outputs
[
0
][
0
],
outputs
[
1
][
0
]
=
res
else
:
outputs
[
0
][
0
]
=
res
def
L_op
(
self
,
inputs
,
outputs
,
output_grads
):
"""
Reverse-mode gradient of the QR function.
References
----------
.. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/
.. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2
"""
from
pytensor.tensor.slinalg
import
solve_triangular
(
A
,)
=
(
cast
(
ptb
.
TensorVariable
,
x
)
for
x
in
inputs
)
m
,
n
=
A
.
shape
def
_H
(
x
:
ptb
.
TensorVariable
):
return
x
.
conj
()
.
mT
def
_copyltu
(
x
:
ptb
.
TensorVariable
):
return
ptb
.
tril
(
x
,
k
=
0
)
+
_H
(
ptb
.
tril
(
x
,
k
=-
1
))
if
self
.
mode
==
"raw"
:
raise
NotImplementedError
(
"Gradient of qr not implemented for mode=raw"
)
elif
self
.
mode
==
"r"
:
# We need all the components of the QR to compute the gradient of A even if we only
# use the upper triangular component in the cost function.
Q
,
R
=
qr
(
A
,
mode
=
"reduced"
)
dQ
=
Q
.
zeros_like
()
dR
=
cast
(
ptb
.
TensorVariable
,
output_grads
[
0
])
else
:
Q
,
R
=
(
cast
(
ptb
.
TensorVariable
,
x
)
for
x
in
outputs
)
if
self
.
mode
==
"complete"
:
qr_assert_op
=
Assert
(
"Gradient of qr not implemented for m x n matrices with m > n and mode=complete"
)
R
=
qr_assert_op
(
R
,
ptm
.
le
(
m
,
n
))
new_output_grads
=
[]
is_disconnected
=
[
isinstance
(
x
.
type
,
DisconnectedType
)
for
x
in
output_grads
]
if
all
(
is_disconnected
):
# This should never be reached by Pytensor
return
[
DisconnectedType
()()]
# pragma: no cover
for
disconnected
,
output_grad
,
output
in
zip
(
is_disconnected
,
output_grads
,
[
Q
,
R
],
strict
=
True
):
if
disconnected
:
new_output_grads
.
append
(
output
.
zeros_like
())
else
:
new_output_grads
.
append
(
output_grad
)
(
dQ
,
dR
)
=
(
cast
(
ptb
.
TensorVariable
,
x
)
for
x
in
new_output_grads
)
# gradient expression when m >= n
M
=
R
@
_H
(
dR
)
-
_H
(
dQ
)
@
Q
K
=
dQ
+
Q
@
_copyltu
(
M
)
A_bar_m_ge_n
=
_H
(
solve_triangular
(
R
,
_H
(
K
)))
# gradient expression when m < n
Y
=
A
[:,
m
:]
U
=
R
[:,
:
m
]
dU
,
dV
=
dR
[:,
:
m
],
dR
[:,
m
:]
dQ_Yt_dV
=
dQ
+
Y
@
_H
(
dV
)
M
=
U
@
_H
(
dU
)
-
_H
(
dQ_Yt_dV
)
@
Q
X_bar
=
_H
(
solve_triangular
(
U
,
_H
(
dQ_Yt_dV
+
Q
@
_copyltu
(
M
))))
Y_bar
=
Q
@
dV
A_bar_m_lt_n
=
pt
.
concatenate
([
X_bar
,
Y_bar
],
axis
=
1
)
return
[
ifelse
(
ptm
.
ge
(
m
,
n
),
A_bar_m_ge_n
,
A_bar_m_lt_n
)]
def
qr
(
a
,
mode
=
"reduced"
):
"""
Computes the QR decomposition of a matrix.
Factor the matrix a as qr, where q
is orthonormal and r is upper-triangular.
Parameters
----------
a : array_like, shape (M, N)
Matrix to be factored.
mode : {'reduced', 'complete', 'r', 'raw'}, optional
If K = min(M, N), then
'reduced'
returns q, r with dimensions (M, K), (K, N)
'complete'
returns q, r with dimensions (M, M), (M, N)
'r'
returns r only with dimensions (K, N)
'raw'
returns h, tau with dimensions (N, M), (K,)
Note that array h returned in 'raw' mode is
transposed for calling Fortran.
Default mode is 'reduced'
Returns
-------
q : matrix of float or complex, optional
A matrix with orthonormal columns. When mode = 'complete' the
result is an orthogonal/unitary matrix depending on whether or
not a is real/complex. The determinant may be either +/- 1 in
that case.
r : matrix of float or complex, optional
The upper-triangular matrix.
"""
return
QRFull
(
mode
)(
a
)
class
SVD
(
Op
):
class
SVD
(
Op
):
"""
"""
Computes singular value decomposition of matrix A, into U, S, V such that A = U @ S @ V
Computes singular value decomposition of matrix A, into U, S, V such that A = U @ S @ V
...
@@ -1291,7 +1121,6 @@ __all__ = [
...
@@ -1291,7 +1121,6 @@ __all__ = [
"det"
,
"det"
,
"eig"
,
"eig"
,
"eigh"
,
"eigh"
,
"qr"
,
"svd"
,
"svd"
,
"lstsq"
,
"lstsq"
,
"matrix_power"
,
"matrix_power"
,
...
...
pytensor/tensor/slinalg.py
浏览文件 @
617964ff
差异被折叠。
点击展开。
tests/link/jax/test_nlinalg.py
浏览文件 @
617964ff
...
@@ -29,12 +29,6 @@ def test_jax_basic_multiout():
...
@@ -29,12 +29,6 @@ def test_jax_basic_multiout():
outs
=
pt_nlinalg
.
eigh
(
x
)
outs
=
pt_nlinalg
.
eigh
(
x
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
pt_nlinalg
.
qr
(
x
,
mode
=
"full"
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
pt_nlinalg
.
qr
(
x
,
mode
=
"reduced"
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
pt_nlinalg
.
svd
(
x
)
outs
=
pt_nlinalg
.
svd
(
x
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
...
...
tests/link/jax/test_slinalg.py
浏览文件 @
617964ff
...
@@ -103,6 +103,18 @@ def test_jax_basic():
...
@@ -103,6 +103,18 @@ def test_jax_basic():
],
],
)
)
def
assert_fn
(
x
,
y
):
np
.
testing
.
assert_allclose
(
x
.
astype
(
config
.
floatX
),
y
,
rtol
=
1e-3
)
M
=
rng
.
normal
(
size
=
(
3
,
3
))
X
=
M
.
dot
(
M
.
T
)
outs
=
pt_slinalg
.
qr
(
x
,
mode
=
"full"
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
pt_slinalg
.
qr
(
x
,
mode
=
"economic"
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
def
test_jax_solve
():
def
test_jax_solve
():
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
...
...
tests/link/numba/test_nlinalg.py
浏览文件 @
617964ff
...
@@ -186,60 +186,6 @@ def test_matrix_inverses(op, x, exc, op_args):
...
@@ -186,60 +186,6 @@ def test_matrix_inverses(op, x, exc, op_args):
)
)
@pytest.mark.parametrize
(
"x, mode, exc"
,
[
(
(
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
"reduced"
,
None
,
),
(
(
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
"r"
,
None
,
),
(
(
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
),
),
"reduced"
,
None
,
),
(
(
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
),
),
"complete"
,
UserWarning
,
),
],
)
def
test_QRFull
(
x
,
mode
,
exc
):
x
,
test_x
=
x
g
=
nlinalg
.
QRFull
(
mode
)(
x
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
compare_numba_and_py
(
[
x
],
g
,
[
test_x
],
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"x, full_matrices, compute_uv, exc"
,
"x, full_matrices, compute_uv, exc"
,
[
[
...
...
tests/link/numba/test_slinalg.py
浏览文件 @
617964ff
...
@@ -10,6 +10,7 @@ import pytensor.tensor as pt
...
@@ -10,6 +10,7 @@ import pytensor.tensor as pt
from
pytensor
import
In
,
config
from
pytensor
import
In
,
config
from
pytensor.tensor.slinalg
import
(
from
pytensor.tensor.slinalg
import
(
LU
,
LU
,
QR
,
Cholesky
,
Cholesky
,
CholeskySolve
,
CholeskySolve
,
LUFactor
,
LUFactor
,
...
@@ -720,3 +721,70 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo
...
@@ -720,3 +721,70 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo
# Can never destroy non-contiguous inputs
# Can never destroy non-contiguous inputs
np
.
testing
.
assert_allclose
(
b_val_not_contig
,
b_val
)
np
.
testing
.
assert_allclose
(
b_val_not_contig
,
b_val
)
@pytest.mark.parametrize
(
"mode, pivoting"
,
[(
"economic"
,
False
),
(
"full"
,
True
),
(
"r"
,
False
),
(
"raw"
,
True
)],
ids
=
[
"economic"
,
"full_pivot"
,
"r"
,
"raw_pivot"
],
)
@pytest.mark.parametrize
(
"overwrite_a"
,
[
True
,
False
],
ids
=
[
"overwrite_a"
,
"no_overwrite"
]
)
def
test_qr
(
mode
,
pivoting
,
overwrite_a
):
shape
=
(
5
,
5
)
rng
=
np
.
random
.
default_rng
()
A
=
pt
.
tensor
(
"A"
,
shape
=
shape
,
dtype
=
config
.
floatX
,
)
A_val
=
rng
.
normal
(
size
=
shape
)
.
astype
(
config
.
floatX
)
qr_outputs
=
pt
.
linalg
.
qr
(
A
,
mode
=
mode
,
pivoting
=
pivoting
)
fn
,
res
=
compare_numba_and_py
(
[
In
(
A
,
mutable
=
overwrite_a
)],
qr_outputs
,
[
A_val
],
numba_mode
=
numba_inplace_mode
,
inplace
=
True
,
)
op
=
fn
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
assert
isinstance
(
op
,
QR
)
destroy_map
=
op
.
destroy_map
if
overwrite_a
:
assert
destroy_map
==
{
0
:
[
0
]}
else
:
assert
destroy_map
==
{}
# Test F-contiguous input
val_f_contig
=
np
.
copy
(
A_val
,
order
=
"F"
)
res_f_contig
=
fn
(
val_f_contig
)
for
x
,
x_f_contig
in
zip
(
res
,
res_f_contig
,
strict
=
True
):
np
.
testing
.
assert_allclose
(
x
,
x_f_contig
)
# Should always be destroyable
assert
(
A_val
==
val_f_contig
)
.
all
()
==
(
not
overwrite_a
)
# Test C-contiguous input
val_c_contig
=
np
.
copy
(
A_val
,
order
=
"C"
)
res_c_contig
=
fn
(
val_c_contig
)
for
x
,
x_c_contig
in
zip
(
res
,
res_c_contig
,
strict
=
True
):
np
.
testing
.
assert_allclose
(
x
,
x_c_contig
)
# Cannot destroy C-contiguous input
np
.
testing
.
assert_allclose
(
val_c_contig
,
A_val
)
# Test non-contiguous input
val_not_contig
=
np
.
repeat
(
A_val
,
2
,
axis
=
0
)[::
2
]
res_not_contig
=
fn
(
val_not_contig
)
for
x
,
x_not_contig
in
zip
(
res
,
res_not_contig
,
strict
=
True
):
np
.
testing
.
assert_allclose
(
x
,
x_not_contig
)
# Cannot destroy non-contiguous input
np
.
testing
.
assert_allclose
(
val_not_contig
,
A_val
)
tests/link/pytorch/conftest.py
0 → 100644
浏览文件 @
617964ff
import
numpy
as
np
import
pytest
from
pytensor
import
config
from
pytensor.tensor.type
import
matrix
@pytest.fixture
def
matrix_test
():
rng
=
np
.
random
.
default_rng
(
213234
)
M
=
rng
.
normal
(
size
=
(
3
,
3
))
test_value
=
M
.
dot
(
M
.
T
)
.
astype
(
config
.
floatX
)
x
=
matrix
(
"x"
)
return
x
,
test_value
tests/link/pytorch/test_nlinalg.py
浏览文件 @
617964ff
...
@@ -8,17 +8,6 @@ from pytensor.tensor.type import matrix
...
@@ -8,17 +8,6 @@ from pytensor.tensor.type import matrix
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
@pytest.fixture
def
matrix_test
():
rng
=
np
.
random
.
default_rng
(
213234
)
M
=
rng
.
normal
(
size
=
(
3
,
3
))
test_value
=
M
.
dot
(
M
.
T
)
.
astype
(
config
.
floatX
)
x
=
matrix
(
"x"
)
return
(
x
,
test_value
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"func"
,
"func"
,
(
pt_nla
.
eig
,
pt_nla
.
eigh
,
pt_nla
.
SLogDet
(),
pt_nla
.
inv
,
pt_nla
.
det
),
(
pt_nla
.
eig
,
pt_nla
.
eigh
,
pt_nla
.
SLogDet
(),
pt_nla
.
inv
,
pt_nla
.
det
),
...
@@ -34,22 +23,6 @@ def test_lin_alg_no_params(func, matrix_test):
...
@@ -34,22 +23,6 @@ def test_lin_alg_no_params(func, matrix_test):
compare_pytorch_and_py
([
x
],
outs
,
[
test_value
],
assert_fn
=
assert_fn
)
compare_pytorch_and_py
([
x
],
outs
,
[
test_value
],
assert_fn
=
assert_fn
)
@pytest.mark.parametrize
(
"mode"
,
(
"complete"
,
"reduced"
,
"r"
,
pytest
.
param
(
"raw"
,
marks
=
pytest
.
mark
.
xfail
(
raises
=
NotImplementedError
)),
),
)
def
test_qr
(
mode
,
matrix_test
):
x
,
test_value
=
matrix_test
outs
=
pt_nla
.
qr
(
x
,
mode
=
mode
)
compare_pytorch_and_py
([
x
],
outs
,
[
test_value
])
@pytest.mark.parametrize
(
"compute_uv"
,
[
True
,
False
])
@pytest.mark.parametrize
(
"compute_uv"
,
[
True
,
False
])
@pytest.mark.parametrize
(
"full_matrices"
,
[
True
,
False
])
@pytest.mark.parametrize
(
"full_matrices"
,
[
True
,
False
])
def
test_svd
(
compute_uv
,
full_matrices
,
matrix_test
):
def
test_svd
(
compute_uv
,
full_matrices
,
matrix_test
):
...
...
tests/link/pytorch/test_slinalg.py
0 → 100644
浏览文件 @
617964ff
import
pytest
import
pytensor
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
@pytest.mark.parametrize
(
"mode"
,
(
"complete"
,
"reduced"
,
"r"
,
pytest
.
param
(
"raw"
,
marks
=
pytest
.
mark
.
xfail
(
raises
=
NotImplementedError
)),
),
)
def
test_qr
(
mode
,
matrix_test
):
x
,
test_value
=
matrix_test
outs
=
pytensor
.
tensor
.
slinalg
.
qr
(
x
,
mode
=
mode
)
compare_pytorch_and_py
([
x
],
outs
,
[
test_value
])
tests/tensor/test_nlinalg.py
浏览文件 @
617964ff
from
functools
import
partial
from
functools
import
partial
import
numpy
as
np
import
numpy
as
np
import
numpy.linalg
import
pytest
import
pytest
from
numpy.testing
import
assert_array_almost_equal
from
numpy.testing
import
assert_array_almost_equal
...
@@ -25,7 +24,6 @@ from pytensor.tensor.nlinalg import (
...
@@ -25,7 +24,6 @@ from pytensor.tensor.nlinalg import (
matrix_power
,
matrix_power
,
norm
,
norm
,
pinv
,
pinv
,
qr
,
slogdet
,
slogdet
,
svd
,
svd
,
tensorinv
,
tensorinv
,
...
@@ -122,102 +120,6 @@ def test_matrix_dot():
...
@@ -122,102 +120,6 @@ def test_matrix_dot():
assert
_allclose
(
numpy_sol
,
pytensor_sol
)
assert
_allclose
(
numpy_sol
,
pytensor_sol
)
def
test_qr_modes
():
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
A
=
matrix
(
"A"
,
dtype
=
config
.
floatX
)
a
=
rng
.
random
((
4
,
4
))
.
astype
(
config
.
floatX
)
f
=
function
([
A
],
qr
(
A
))
t_qr
=
f
(
a
)
n_qr
=
np
.
linalg
.
qr
(
a
)
assert
_allclose
(
n_qr
,
t_qr
)
for
mode
in
[
"reduced"
,
"r"
,
"raw"
]:
f
=
function
([
A
],
qr
(
A
,
mode
))
t_qr
=
f
(
a
)
n_qr
=
np
.
linalg
.
qr
(
a
,
mode
)
if
isinstance
(
n_qr
,
list
|
tuple
):
assert
_allclose
(
n_qr
[
0
],
t_qr
[
0
])
assert
_allclose
(
n_qr
[
1
],
t_qr
[
1
])
else
:
assert
_allclose
(
n_qr
,
t_qr
)
try
:
n_qr
=
np
.
linalg
.
qr
(
a
,
"complete"
)
f
=
function
([
A
],
qr
(
A
,
"complete"
))
t_qr
=
f
(
a
)
assert
_allclose
(
n_qr
,
t_qr
)
except
TypeError
as
e
:
assert
"name 'complete' is not defined"
in
str
(
e
)
@pytest.mark.parametrize
(
"shape, gradient_test_case, mode"
,
(
[(
s
,
c
,
"reduced"
)
for
s
in
[(
3
,
3
),
(
6
,
3
),
(
3
,
6
)]
for
c
in
[
0
,
1
,
2
]]
+
[(
s
,
c
,
"complete"
)
for
s
in
[(
3
,
3
),
(
6
,
3
),
(
3
,
6
)]
for
c
in
[
0
,
1
,
2
]]
+
[(
s
,
0
,
"r"
)
for
s
in
[(
3
,
3
),
(
6
,
3
),
(
3
,
6
)]]
+
[((
3
,
3
),
0
,
"raw"
)]
),
ids
=
(
[
f
"shape={s}, gradient_test_case={c}, mode=reduced"
for
s
in
[(
3
,
3
),
(
6
,
3
),
(
3
,
6
)]
for
c
in
[
"Q"
,
"R"
,
"both"
]
]
+
[
f
"shape={s}, gradient_test_case={c}, mode=complete"
for
s
in
[(
3
,
3
),
(
6
,
3
),
(
3
,
6
)]
for
c
in
[
"Q"
,
"R"
,
"both"
]
]
+
[
f
"shape={s}, gradient_test_case=R, mode=r"
for
s
in
[(
3
,
3
),
(
6
,
3
),
(
3
,
6
)]]
+
[
"shape=(3, 3), gradient_test_case=Q, mode=raw"
]
),
)
@pytest.mark.parametrize
(
"is_complex"
,
[
True
,
False
],
ids
=
[
"complex"
,
"real"
])
def
test_qr_grad
(
shape
,
gradient_test_case
,
mode
,
is_complex
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
def
_test_fn
(
x
,
case
=
2
,
mode
=
"reduced"
):
if
case
==
0
:
return
qr
(
x
,
mode
=
mode
)[
0
]
.
sum
()
elif
case
==
1
:
return
qr
(
x
,
mode
=
mode
)[
1
]
.
sum
()
elif
case
==
2
:
Q
,
R
=
qr
(
x
,
mode
=
mode
)
return
Q
.
sum
()
+
R
.
sum
()
if
is_complex
:
pytest
.
xfail
(
"Complex inputs currently not supported by verify_grad"
)
m
,
n
=
shape
a
=
rng
.
standard_normal
(
shape
)
.
astype
(
config
.
floatX
)
if
is_complex
:
a
+=
1
j
*
rng
.
standard_normal
(
shape
)
.
astype
(
config
.
floatX
)
if
mode
==
"raw"
:
with
pytest
.
raises
(
NotImplementedError
):
utt
.
verify_grad
(
partial
(
_test_fn
,
case
=
gradient_test_case
,
mode
=
mode
),
[
a
],
rng
=
np
.
random
,
)
elif
mode
==
"complete"
and
m
>
n
:
with
pytest
.
raises
(
AssertionError
):
utt
.
verify_grad
(
partial
(
_test_fn
,
case
=
gradient_test_case
,
mode
=
mode
),
[
a
],
rng
=
np
.
random
,
)
else
:
utt
.
verify_grad
(
partial
(
_test_fn
,
case
=
gradient_test_case
,
mode
=
mode
),
[
a
],
rng
=
np
.
random
)
class
TestSvd
(
utt
.
InferShapeTester
):
class
TestSvd
(
utt
.
InferShapeTester
):
op_class
=
SVD
op_class
=
SVD
...
...
tests/tensor/test_slinalg.py
浏览文件 @
617964ff
import
functools
import
functools
import
itertools
import
itertools
from
functools
import
partial
from
typing
import
Literal
from
typing
import
Literal
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
scipy
import
scipy
from
scipy
import
linalg
as
scipy_linalg
from
pytensor
import
function
,
grad
from
pytensor
import
function
,
grad
from
pytensor
import
tensor
as
pt
from
pytensor
import
tensor
as
pt
...
@@ -26,6 +28,7 @@ from pytensor.tensor.slinalg import (
...
@@ -26,6 +28,7 @@ from pytensor.tensor.slinalg import (
lu_factor
,
lu_factor
,
lu_solve
,
lu_solve
,
pivot_to_permutation
,
pivot_to_permutation
,
qr
,
solve
,
solve
,
solve_continuous_lyapunov
,
solve_continuous_lyapunov
,
solve_discrete_are
,
solve_discrete_are
,
...
@@ -1088,3 +1091,104 @@ def test_block_diagonal_blockwise():
...
@@ -1088,3 +1091,104 @@ def test_block_diagonal_blockwise():
B
=
np
.
random
.
normal
(
size
=
(
1
,
batch_size
,
4
,
4
))
.
astype
(
config
.
floatX
)
B
=
np
.
random
.
normal
(
size
=
(
1
,
batch_size
,
4
,
4
))
.
astype
(
config
.
floatX
)
result
=
block_diag
(
A
,
B
)
.
eval
()
result
=
block_diag
(
A
,
B
)
.
eval
()
assert
result
.
shape
==
(
10
,
batch_size
,
6
,
6
)
assert
result
.
shape
==
(
10
,
batch_size
,
6
,
6
)
@pytest.mark.parametrize(
    "mode, names",
    [
        ("economic", ["Q", "R"]),
        ("full", ["Q", "R"]),
        ("r", ["R"]),
        ("raw", ["H", "tau", "R"]),
    ],
)
@pytest.mark.parametrize("pivoting", [True, False])
def test_qr_modes(mode, names, pivoting):
    """Compare pytensor's qr against scipy.linalg.qr for every mode/pivoting combo."""
    rng = np.random.default_rng(utt.fetch_seed())
    A_val = rng.random((4, 4)).astype(config.floatX)

    if pivoting:
        names = [*names, "pivots"]

    A = tensor("A", dtype=config.floatX, shape=(None, None))
    f = function([A], qr(A, mode=mode, pivoting=pivoting))

    outputs_pt = f(A_val)
    outputs_sp = scipy_linalg.qr(A_val, mode=mode, pivoting=pivoting)

    if mode == "raw":
        # scipy packs (H, tau) into a leading tuple in raw mode; flatten it so
        # every factor can be compared in a single loop below.
        head, *rest = outputs_sp
        outputs_sp = (*head, *rest)
    elif mode == "r" and not pivoting:
        # pytensor returns a bare R here; wrap it so iteration still works.
        outputs_pt = [outputs_pt]

    for name, out_pt, out_sp in zip(names, outputs_pt, outputs_sp):
        np.testing.assert_allclose(out_pt, out_sp, err_msg=f"{name} disagrees")
@pytest.mark.parametrize(
    "shape, gradient_test_case, mode",
    (
        [(s, c, "economic") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
        + [(s, c, "full") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
        + [(s, 0, "r") for s in [(3, 3), (6, 3), (3, 6)]]
        + [((3, 3), 0, "raw")]
    ),
    ids=(
        [
            f"shape={s}, gradient_test_case={c}, mode=economic"
            for s in [(3, 3), (6, 3), (3, 6)]
            for c in ["Q", "R", "both"]
        ]
        + [
            f"shape={s}, gradient_test_case={c}, mode=full"
            for s in [(3, 3), (6, 3), (3, 6)]
            for c in ["Q", "R", "both"]
        ]
        + [
            f"shape={s}, gradient_test_case=R, mode=r"
            for s in [(3, 3), (6, 3), (3, 6)]
        ]
        + ["shape=(3, 3), gradient_test_case=Q, mode=raw"]
    ),
)
@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
def test_qr_grad(shape, gradient_test_case, mode, is_complex):
    """Verify gradients of the scipy-style ``qr`` for each mode.

    ``raw`` mode has no gradient, and the Q gradient in ``full`` mode is
    undefined for tall matrices (m > n), so those cases are asserted to
    raise instead of passing the gradient check.
    """
    gen = np.random.default_rng(utt.fetch_seed())

    def _test_fn(x, case=2, mode="reduced"):
        # case 0 -> scalar loss from Q, case 1 -> from R, case 2 -> from both.
        outputs = qr(x, mode=mode)
        if case == 0:
            return outputs[0].sum()
        if case == 1:
            return outputs[1].sum()
        if case == 2:
            Q, R = outputs
            return Q.sum() + R.sum()

    if is_complex:
        pytest.xfail("Complex inputs currently not supported by verify_grad")

    m, n = shape
    a = gen.standard_normal(shape).astype(config.floatX)
    if is_complex:
        a += 1j * gen.standard_normal(shape).astype(config.floatX)

    loss_fn = partial(_test_fn, case=gradient_test_case, mode=mode)

    if mode == "raw":
        # No gradient is implemented for the raw factorization outputs.
        with pytest.raises(NotImplementedError):
            utt.verify_grad(loss_fn, [a], rng=np.random)
    elif mode == "full" and m > n:
        # Gradient check is expected to fail for tall full-mode inputs.
        with pytest.raises(AssertionError):
            utt.verify_grad(loss_fn, [a], rng=np.random)
    else:
        utt.verify_grad(loss_fn, [a], rng=np.random)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论