Unverified commit 617964ff, authored by Jesse Grabowski, committed by GitHub

Refactor and update QR Op (#1518)

* Refactor QR
* Update JAX QR dispatch
* Update Torch QR dispatch
* Update numba QR dispatch
Parent 5024d54e
@@ -9,7 +9,6 @@ from pytensor.tensor.nlinalg import (
    KroneckerProduct,
    MatrixInverse,
    MatrixPinv,
    QRFull,
    SLogDet,
)
@@ -67,16 +66,6 @@ def jax_funcify_MatrixInverse(op, **kwargs):
    return matrix_inverse


@jax_funcify.register(QRFull)
def jax_funcify_QRFull(op, **kwargs):
    mode = op.mode

    def qr_full(x, mode=mode):
        return jnp.linalg.qr(x, mode=mode)

    return qr_full
@jax_funcify.register(MatrixPinv)
def jax_funcify_Pinv(op, **kwargs):
    def pinv(x):
...
@@ -5,6 +5,7 @@ import jax
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.slinalg import (
    LU,
    QR,
    BlockDiagonal,
    Cholesky,
    CholeskySolve,
@@ -168,3 +169,13 @@ def jax_funcify_ChoSolve(op, **kwargs):
    )

    return cho_solve


@jax_funcify.register(QR)
def jax_funcify_QR(op, **kwargs):
    mode = op.mode

    def qr(x, mode=mode):
        return jax.scipy.linalg.qr(x, mode=mode)

    return qr
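Since jax.scipy.linalg.qr mirrors scipy's mode names ("full", "economic", "r"), the generated callable can forward `mode` unchanged. A minimal sketch of what it computes (assumes a working JAX install):

import jax.numpy as jnp
import jax.scipy.linalg

x = jnp.arange(12.0).reshape(4, 3)
Q, R = jax.scipy.linalg.qr(x, mode="economic")  # Q: (4, 3), R: (3, 3)
assert jnp.allclose(Q @ R, x, atol=1e-5)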
@@ -283,7 +283,6 @@ class _LAPACK:
        Called by scipy.linalg.lu_solve
        """
        ...
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs")
        functype = ctypes.CFUNCTYPE(
            None,
@@ -457,3 +456,90 @@ class _LAPACK:
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)
    @classmethod
    def numba_xgeqrf(cls, dtype):
        """
        Compute the QR factorization of a general M-by-N matrix A.

        Used in QR decomposition (no pivoting).
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqrf")
        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # M
            _ptr_int,  # N
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # TAU
            float_pointer,  # WORK
            _ptr_int,  # LWORK
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)

    @classmethod
    def numba_xgeqp3(cls, dtype):
        """
        Compute the QR factorization with column pivoting of a general M-by-N matrix A.

        Used in QR decomposition with pivoting.
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqp3")
        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # M
            _ptr_int,  # N
            float_pointer,  # A
            _ptr_int,  # LDA
            _ptr_int,  # JPVT
            float_pointer,  # TAU
            float_pointer,  # WORK
            _ptr_int,  # LWORK
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)

    @classmethod
    def numba_xorgqr(cls, dtype):
        """
        Generate the orthogonal matrix Q from a QR factorization (real types).

        Used in QR decomposition to form Q.
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "orgqr")
        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # M
            _ptr_int,  # N
            _ptr_int,  # K
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # TAU
            float_pointer,  # WORK
            _ptr_int,  # LWORK
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)

    @classmethod
    def numba_xungqr(cls, dtype):
        """
        Generate the unitary matrix Q from a QR factorization (complex types).

        Used in QR decomposition to form Q for complex types.
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "ungqr")
        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # M
            _ptr_int,  # N
            _ptr_int,  # K
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # TAU
            float_pointer,  # WORK
            _ptr_int,  # LWORK
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)
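All four wrappers expose LAPACK's workspace-query convention: calling a routine with LWORK = -1 performs no factorization and instead writes the optimal workspace size into WORK[0]. The same two-phase pattern, sketched here with scipy's f2py LAPACK bindings rather than the ctypes pointers above:

import numpy as np
from scipy.linalg import lapack

A = np.asfortranarray(np.random.default_rng(0).normal(size=(5, 3)))

# Phase 1: workspace query -- lwork=-1 returns the optimal size in work[0]
qr_out, tau, work, info = lapack.dgeqrf(A, lwork=-1)
lwork = int(work[0])

# Phase 2: the actual factorization with the optimal workspace
qr_out, tau, work, info = lapack.dgeqrf(A, lwork=lwork)
assert info == 0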
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy.linalg import get_lapack_funcs, qr
from pytensor.link.numba.dispatch.linalg._LAPACK import (
    _LAPACK,
    _get_underlying_float,
    int_ptr_to_val,
    val_to_int_ptr,
)
def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int):
    """LAPACK geqrf: Computes a QR factorization of a general M-by-N matrix A."""
    (geqrf,) = get_lapack_funcs(("geqrf",), (A,))
    return geqrf(A, overwrite_a=overwrite_a, lwork=lwork)


@overload(_xgeqrf)
def xgeqrf_impl(A, overwrite_a, lwork):
    ensure_lapack()
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    geqrf = _LAPACK().numba_xgeqrf(dtype)

    def impl(A, overwrite_a, lwork):
        M = np.int32(A.shape[0])
        N = np.int32(A.shape[1])

        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        LDA = val_to_int_ptr(M)
        TAU = np.empty(min(M, N), dtype=dtype)

        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            LWORK = val_to_int_ptr(-1)
        else:
            WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype)
            LWORK = val_to_int_ptr(WORK.size)
        INFO = val_to_int_ptr(1)

        geqrf(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            A_copy.view(w_type).ctypes,
            LDA,
            TAU.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            LWORK,
            INFO,
        )
        return A_copy, TAU, WORK, int_ptr_to_val(INFO)

    return impl
def _xgeqp3(A: np.ndarray, overwrite_a: bool, lwork: int):
    """LAPACK geqp3: Computes a QR factorization with column pivoting of a general M-by-N matrix A."""
    (geqp3,) = get_lapack_funcs(("geqp3",), (A,))
    return geqp3(A, overwrite_a=overwrite_a, lwork=lwork)


@overload(_xgeqp3)
def xgeqp3_impl(A, overwrite_a, lwork):
    ensure_lapack()
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    geqp3 = _LAPACK().numba_xgeqp3(dtype)

    def impl(A, overwrite_a, lwork):
        M = np.int32(A.shape[0])
        N = np.int32(A.shape[1])

        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        LDA = val_to_int_ptr(M)
        JPVT = np.zeros(N, dtype=np.int32)
        TAU = np.empty(min(M, N), dtype=dtype)

        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            LWORK = val_to_int_ptr(-1)
        else:
            WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype)
            LWORK = val_to_int_ptr(WORK.size)
        INFO = val_to_int_ptr(1)

        geqp3(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            A_copy.view(w_type).ctypes,
            LDA,
            JPVT.ctypes,
            TAU.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            LWORK,
            INFO,
        )
        return A_copy, JPVT, TAU, WORK, int_ptr_to_val(INFO)

    return impl
def _xorgqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
    """LAPACK orgqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (real types)."""
    (orgqr,) = get_lapack_funcs(("orgqr",), (A,))
    return orgqr(A, tau, overwrite_a=overwrite_a, lwork=lwork)


@overload(_xorgqr)
def xorgqr_impl(A, tau, overwrite_a, lwork):
    ensure_lapack()
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    orgqr = _LAPACK().numba_xorgqr(dtype)

    def impl(A, tau, overwrite_a, lwork):
        M = np.int32(A.shape[0])
        N = np.int32(A.shape[1])
        K = np.int32(tau.shape[0])

        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            LWORK = val_to_int_ptr(-1)
        else:
            WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype)
            LWORK = val_to_int_ptr(WORK.size)

        LDA = val_to_int_ptr(M)
        INFO = val_to_int_ptr(1)

        orgqr(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            val_to_int_ptr(K),
            A_copy.view(w_type).ctypes,
            LDA,
            tau.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            LWORK,
            INFO,
        )
        return A_copy, WORK, int_ptr_to_val(INFO)

    return impl
def _xungqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
    """LAPACK ungqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (complex types)."""
    (ungqr,) = get_lapack_funcs(("ungqr",), (A,))
    return ungqr(A, tau, overwrite_a=overwrite_a, lwork=lwork)


@overload(_xungqr)
def xungqr_impl(A, tau, overwrite_a, lwork):
    ensure_lapack()
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    ungqr = _LAPACK().numba_xungqr(dtype)

    def impl(A, tau, overwrite_a, lwork):
        M = np.int32(A.shape[0])
        N = np.int32(A.shape[1])
        K = np.int32(tau.shape[0])

        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        LDA = val_to_int_ptr(M)

        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            LWORK = val_to_int_ptr(-1)
        else:
            WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype)
            LWORK = val_to_int_ptr(WORK.size)
        INFO = val_to_int_ptr(1)

        ungqr(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            val_to_int_ptr(K),
            A_copy.view(w_type).ctypes,
            LDA,
            tau.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            LWORK,
            INFO,
        )
        return A_copy, WORK, int_ptr_to_val(INFO)

    return impl
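Taken together, each helper supports LAPACK's two-phase workspace protocol from inside jitted code. A hypothetical, untested sketch of that calling pattern (`_xgeqrf` is the helper defined above):

from numba import njit

@njit
def geqrf_two_phase(A):
    # Phase 1 -- workspace query: the optimal LWORK comes back in WORK[0]
    _, _, work, _ = _xgeqrf(A, False, -1)
    lwork = int(work.item())
    # Phase 2 -- the real factorization with the optimal workspace
    a_out, tau, work, info = _xgeqrf(A, False, lwork)
    return a_out, tau, info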
def _qr_full_pivot(
    x: np.ndarray,
    mode: str = "full",
    pivoting: bool = True,
    overwrite_a: bool = False,
    check_finite: bool = False,
    lwork: int | None = None,
):
    """
    Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same
    script.

    Corresponds to the case where mode is not "r" or "raw" and pivoting is True, resulting in a return of arrays Q, R,
    and P.
    """
    return qr(
        x,
        mode=mode,
        pivoting=pivoting,
        overwrite_a=overwrite_a,
        check_finite=check_finite,
        lwork=lwork,
    )


def _qr_full_no_pivot(
    x: np.ndarray,
    mode: str = "full",
    pivoting: bool = False,
    overwrite_a: bool = False,
    check_finite: bool = False,
    lwork: int | None = None,
):
    """
    Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same
    script.

    Corresponds to the case where mode is not "r" or "raw" and pivoting is False, resulting in a return of arrays Q and R.
    """
    return qr(
        x,
        mode=mode,
        pivoting=pivoting,
        overwrite_a=overwrite_a,
        check_finite=check_finite,
        lwork=lwork,
    )


def _qr_r_pivot(
    x: np.ndarray,
    mode: str = "r",
    pivoting: bool = True,
    overwrite_a: bool = False,
    check_finite: bool = False,
    lwork: int | None = None,
):
    """
    Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same
    script.

    Corresponds to the case where mode is "r" and pivoting is True, resulting in a return of arrays R and P.
    """
    return qr(
        x,
        mode=mode,
        pivoting=pivoting,
        overwrite_a=overwrite_a,
        check_finite=check_finite,
        lwork=lwork,
    )


def _qr_r_no_pivot(
    x: np.ndarray,
    mode: str = "r",
    pivoting: bool = False,
    overwrite_a: bool = False,
    check_finite: bool = False,
    lwork: int | None = None,
):
    """
    Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same
    script.

    Corresponds to the case where mode is "r" and pivoting is False, resulting in a return of array R.
    """
    return qr(
        x,
        mode=mode,
        pivoting=pivoting,
        overwrite_a=overwrite_a,
        check_finite=check_finite,
        lwork=lwork,
    )


def _qr_raw_no_pivot(
    x: np.ndarray,
    mode: str = "raw",
    pivoting: bool = False,
    overwrite_a: bool = False,
    check_finite: bool = False,
    lwork: int | None = None,
):
    """
    Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same
    script.

    Corresponds to the case where mode is "raw" and pivoting is False, resulting in a return of arrays H, tau, and R.
    """
    (H, tau), R = qr(
        x,
        mode=mode,
        pivoting=pivoting,
        overwrite_a=overwrite_a,
        check_finite=check_finite,
        lwork=lwork,
    )
    return H, tau, R


def _qr_raw_pivot(
    x: np.ndarray,
    mode: str = "raw",
    pivoting: bool = True,
    overwrite_a: bool = False,
    check_finite: bool = False,
    lwork: int | None = None,
):
    """
    Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same
    script.

    Corresponds to the case where mode is "raw" and pivoting is True, resulting in a return of arrays H, tau, R, and P.
    """
    (H, tau), R, P = qr(
        x,
        mode=mode,
        pivoting=pivoting,
        overwrite_a=overwrite_a,
        check_finite=check_finite,
        lwork=lwork,
    )
    return H, tau, R, P
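Six wrappers are needed because numba overloads must have a fixed return arity, while scipy.linalg.qr's return structure varies with mode and pivoting. For reference, the scipy behavior being split apart:

import numpy as np
from scipy.linalg import qr

x = np.random.default_rng(0).normal(size=(4, 3))

Q, R = qr(x, mode="economic")                    # Q: (4, 3), R: (3, 3)
Q, R = qr(x, mode="full")                        # Q: (4, 4), R: (4, 3)
(R,) = qr(x, mode="r")                           # 1-tuple holding only R
(H, tau), R = qr(x, mode="raw")                  # Householder data plus R
Q, R, P = qr(x, mode="economic", pivoting=True)  # column permutation P added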
@overload(_qr_full_pivot)
def qr_full_pivot_impl(
    x, mode="full", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
    ensure_lapack()
    dtype = x.dtype
    w_type = _get_underlying_float(dtype)
    geqp3 = _LAPACK().numba_xgeqp3(dtype)
    orgqr = _LAPACK().numba_xorgqr(dtype)

    def impl(
        x,
        mode="full",
        pivoting=True,
        overwrite_a=False,
        check_finite=False,
        lwork=None,
    ):
        M = np.int32(x.shape[0])
        N = np.int32(x.shape[1])
        K = min(M, N)

        if overwrite_a and x.flags.f_contiguous:
            x_copy = x
        else:
            x_copy = _copy_to_fortran_order(x)

        LDA = val_to_int_ptr(M)
        TAU = np.empty(K, dtype=dtype)
        JPVT = np.zeros(N, dtype=np.int32)

        if lwork is None:
            lwork = -1
        if lwork == -1:
            # Workspace query: LAPACK writes the optimal LWORK into WORK[0]
            WORK = np.empty(1, dtype=dtype)
            geqp3(
                val_to_int_ptr(M),
                val_to_int_ptr(N),
                x_copy.view(w_type).ctypes,
                LDA,
                JPVT.ctypes,
                TAU.view(w_type).ctypes,
                WORK.view(w_type).ctypes,
                val_to_int_ptr(-1),
                val_to_int_ptr(1),
            )
            lwork_val = int(WORK.item())
        else:
            lwork_val = lwork

        WORK = np.empty(lwork_val, dtype=dtype)
        INFO = val_to_int_ptr(1)

        geqp3(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            x_copy.view(w_type).ctypes,
            LDA,
            JPVT.ctypes,
            TAU.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            val_to_int_ptr(lwork_val),
            INFO,
        )

        # geqp3 returns 1-based pivot indices; convert to 0-based
        JPVT = (JPVT - 1).astype(np.int32)

        if mode == "full" or M < N:
            R = np.triu(x_copy)
        else:
            R = np.triu(x_copy[:N, :])

        if M < N:
            Q_in = x_copy[:, :M]
        elif M == N or mode == "economic":
            Q_in = x_copy
        else:
            # Transpose to put the matrix into Fortran order
            Q_in = np.empty((M, M), dtype=dtype).T
            Q_in[:, :N] = x_copy

        if lwork == -1:
            WORKQ = np.empty(1, dtype=dtype)
            orgqr(
                val_to_int_ptr(M),
                val_to_int_ptr(Q_in.shape[1]),
                val_to_int_ptr(K),
                Q_in.view(w_type).ctypes,
                val_to_int_ptr(M),
                TAU.view(w_type).ctypes,
                WORKQ.view(w_type).ctypes,
                val_to_int_ptr(-1),
                val_to_int_ptr(1),
            )
            lwork_q = int(WORKQ.item())
        else:
            lwork_q = lwork

        WORKQ = np.empty(lwork_q, dtype=dtype)
        INFOQ = val_to_int_ptr(1)

        orgqr(
            val_to_int_ptr(M),
            val_to_int_ptr(Q_in.shape[1]),
            val_to_int_ptr(K),
            Q_in.view(w_type).ctypes,
            val_to_int_ptr(M),
            TAU.view(w_type).ctypes,
            WORKQ.view(w_type).ctypes,
            val_to_int_ptr(lwork_q),
            INFOQ,
        )
        return Q_in, R, JPVT

    return impl
@overload(_qr_full_no_pivot)
def qr_full_no_pivot_impl(
    x, mode="full", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
):
    ensure_lapack()
    dtype = x.dtype
    w_type = _get_underlying_float(dtype)
    geqrf = _LAPACK().numba_xgeqrf(dtype)
    orgqr = _LAPACK().numba_xorgqr(dtype)

    def impl(
        x,
        mode="full",
        pivoting=False,
        overwrite_a=False,
        check_finite=False,
        lwork=None,
    ):
        M = np.int32(x.shape[0])
        N = np.int32(x.shape[1])
        K = min(M, N)

        if overwrite_a and x.flags.f_contiguous:
            x_copy = x
        else:
            x_copy = _copy_to_fortran_order(x)

        LDA = val_to_int_ptr(M)
        TAU = np.empty(K, dtype=dtype)

        if lwork is None:
            lwork = -1
        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            geqrf(
                val_to_int_ptr(M),
                val_to_int_ptr(N),
                x_copy.view(w_type).ctypes,
                LDA,
                TAU.view(w_type).ctypes,
                WORK.view(w_type).ctypes,
                val_to_int_ptr(-1),
                val_to_int_ptr(1),
            )
            lwork_val = int(WORK.item())
        else:
            lwork_val = lwork

        WORK = np.empty(lwork_val, dtype=dtype)
        INFO = val_to_int_ptr(1)

        geqrf(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            x_copy.view(w_type).ctypes,
            LDA,
            TAU.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            val_to_int_ptr(lwork_val),
            INFO,
        )

        if M < N or mode == "full":
            R = np.triu(x_copy)
        else:
            R = np.triu(x_copy[:N, :])

        if M < N:
            Q_in = x_copy[:, :M]
        elif M == N or mode == "economic":
            Q_in = x_copy
        else:
            # Transpose to put the matrix into Fortran order
            Q_in = np.empty((M, M), dtype=dtype).T
            Q_in[:, :N] = x_copy

        if lwork == -1:
            WORKQ = np.empty(1, dtype=dtype)
            orgqr(
                val_to_int_ptr(M),
                val_to_int_ptr(Q_in.shape[1]),
                val_to_int_ptr(K),
                Q_in.view(w_type).ctypes,
                val_to_int_ptr(M),
                TAU.view(w_type).ctypes,
                WORKQ.view(w_type).ctypes,
                val_to_int_ptr(-1),
                val_to_int_ptr(1),
            )
            lwork_q = int(WORKQ.item())
        else:
            lwork_q = lwork

        WORKQ = np.empty(lwork_q, dtype=dtype)
        INFOQ = val_to_int_ptr(1)

        orgqr(
            val_to_int_ptr(M),  # M
            val_to_int_ptr(Q_in.shape[1]),  # N
            val_to_int_ptr(K),  # K
            Q_in.view(w_type).ctypes,  # A
            val_to_int_ptr(M),  # LDA
            TAU.view(w_type).ctypes,  # TAU
            WORKQ.view(w_type).ctypes,  # WORK
            val_to_int_ptr(lwork_q),  # LWORK
            INFOQ,  # INFO
        )
        return Q_in, R

    return impl
@overload(_qr_r_pivot)
def qr_r_pivot_impl(
    x, mode="r", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
    ensure_lapack()
    dtype = x.dtype
    w_type = _get_underlying_float(dtype)
    geqp3 = _LAPACK().numba_xgeqp3(dtype)

    def impl(
        x,
        mode="r",
        pivoting=True,
        overwrite_a=False,
        check_finite=False,
        lwork=None,
    ):
        M = np.int32(x.shape[0])
        N = np.int32(x.shape[1])

        if overwrite_a and x.flags.f_contiguous:
            x_copy = x
        else:
            x_copy = _copy_to_fortran_order(x)

        LDA = val_to_int_ptr(M)
        K = min(M, N)
        TAU = np.empty(K, dtype=dtype)
        JPVT = np.zeros(N, dtype=np.int32)

        if lwork is None:
            lwork = -1
        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            geqp3(
                val_to_int_ptr(M),
                val_to_int_ptr(N),
                x_copy.view(w_type).ctypes,
                LDA,
                JPVT.ctypes,
                TAU.view(w_type).ctypes,
                WORK.view(w_type).ctypes,
                val_to_int_ptr(-1),
                val_to_int_ptr(1),
            )
            lwork_val = int(WORK.item())
        else:
            lwork_val = lwork

        WORK = np.empty(lwork_val, dtype=dtype)
        INFO = val_to_int_ptr(1)

        geqp3(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            x_copy.view(w_type).ctypes,
            LDA,
            JPVT.ctypes,
            TAU.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            val_to_int_ptr(lwork_val),
            INFO,
        )

        JPVT = (JPVT - 1).astype(np.int32)

        if M < N:
            R = np.triu(x_copy)
        else:
            R = np.triu(x_copy[:N, :])

        return R, JPVT

    return impl
@overload(_qr_r_no_pivot)
def qr_r_no_pivot_impl(
    x, mode="r", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
):
    ensure_lapack()
    dtype = x.dtype
    w_type = _get_underlying_float(dtype)
    geqrf = _LAPACK().numba_xgeqrf(dtype)

    def impl(
        x,
        mode="r",
        pivoting=False,
        overwrite_a=False,
        check_finite=False,
        lwork=None,
    ):
        M = np.int32(x.shape[0])
        N = np.int32(x.shape[1])

        if overwrite_a and x.flags.f_contiguous:
            x_copy = x
        else:
            x_copy = _copy_to_fortran_order(x)

        LDA = val_to_int_ptr(M)
        K = min(M, N)
        TAU = np.empty(K, dtype=dtype)

        if lwork is None:
            lwork = -1
        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            geqrf(
                val_to_int_ptr(M),
                val_to_int_ptr(N),
                x_copy.view(w_type).ctypes,
                LDA,
                TAU.view(w_type).ctypes,
                WORK.view(w_type).ctypes,
                val_to_int_ptr(-1),
                val_to_int_ptr(1),
            )
            lwork_val = int(WORK.item())
        else:
            lwork_val = lwork

        WORK = np.empty(lwork_val, dtype=dtype)
        INFO = val_to_int_ptr(1)

        geqrf(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            x_copy.view(w_type).ctypes,
            LDA,
            TAU.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            val_to_int_ptr(lwork_val),
            INFO,
        )

        if M < N:
            R = np.triu(x_copy)
        else:
            R = np.triu(x_copy[:N, :])

        # Return a tuple with R only to match the scipy qr interface
        return (R,)

    return impl
@overload(_qr_raw_no_pivot)
def qr_raw_no_pivot_impl(
    x, mode="raw", pivoting=False, overwrite_a=False, check_finite=False, lwork=None
):
    ensure_lapack()
    dtype = x.dtype
    w_type = _get_underlying_float(dtype)
    geqrf = _LAPACK().numba_xgeqrf(dtype)

    def impl(
        x,
        mode="raw",
        pivoting=False,
        overwrite_a=False,
        check_finite=False,
        lwork=None,
    ):
        M = np.int32(x.shape[0])
        N = np.int32(x.shape[1])

        if overwrite_a and x.flags.f_contiguous:
            x_copy = x
        else:
            x_copy = _copy_to_fortran_order(x)

        LDA = val_to_int_ptr(M)
        K = min(M, N)
        TAU = np.empty(K, dtype=dtype)

        if lwork is None:
            lwork = -1
        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            geqrf(
                val_to_int_ptr(M),
                val_to_int_ptr(N),
                x_copy.view(w_type).ctypes,
                LDA,
                TAU.view(w_type).ctypes,
                WORK.view(w_type).ctypes,
                val_to_int_ptr(-1),
                val_to_int_ptr(1),
            )
            lwork_val = int(WORK.item())
        else:
            lwork_val = lwork

        WORK = np.empty(lwork_val, dtype=dtype)
        INFO = val_to_int_ptr(1)

        geqrf(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            x_copy.view(w_type).ctypes,
            LDA,
            TAU.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            val_to_int_ptr(lwork_val),
            INFO,
        )

        if M < N:
            R = np.triu(x_copy)
        else:
            R = np.triu(x_copy[:N, :])

        return x_copy, TAU, R

    return impl
@overload(_qr_raw_pivot)
def qr_raw_pivot_impl(
    x, mode="raw", pivoting=True, overwrite_a=False, check_finite=False, lwork=None
):
    ensure_lapack()
    dtype = x.dtype
    w_type = _get_underlying_float(dtype)
    geqp3 = _LAPACK().numba_xgeqp3(dtype)

    def impl(
        x,
        mode="raw",
        pivoting=True,
        overwrite_a=False,
        check_finite=False,
        lwork=None,
    ):
        M = np.int32(x.shape[0])
        N = np.int32(x.shape[1])

        if overwrite_a and x.flags.f_contiguous:
            x_copy = x
        else:
            x_copy = _copy_to_fortran_order(x)

        LDA = val_to_int_ptr(M)
        K = min(M, N)
        TAU = np.empty(K, dtype=dtype)
        JPVT = np.zeros(N, dtype=np.int32)

        if lwork is None:
            lwork = -1
        if lwork == -1:
            WORK = np.empty(1, dtype=dtype)
            geqp3(
                val_to_int_ptr(M),
                val_to_int_ptr(N),
                x_copy.view(w_type).ctypes,
                LDA,
                JPVT.ctypes,
                TAU.view(w_type).ctypes,
                WORK.view(w_type).ctypes,
                val_to_int_ptr(-1),
                val_to_int_ptr(1),
            )
            lwork_val = int(WORK.item())
        else:
            lwork_val = lwork

        WORK = np.empty(lwork_val, dtype=dtype)
        INFO = val_to_int_ptr(1)

        geqp3(
            val_to_int_ptr(M),
            val_to_int_ptr(N),
            x_copy.view(w_type).ctypes,
            LDA,
            JPVT.ctypes,
            TAU.view(w_type).ctypes,
            WORK.view(w_type).ctypes,
            val_to_int_ptr(lwork_val),
            INFO,
        )

        JPVT = (JPVT - 1).astype(np.int32)

        if M < N:
            R = np.triu(x_copy)
        else:
            R = np.triu(x_copy[:N, :])

        return x_copy, TAU, R, JPVT

    return impl
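As a sanity check on the pivoted kernels above: after the 1-based-to-0-based shift of JPVT, the permutation satisfies A[:, P] = Q @ R. Demonstrated here with scipy, whose behavior the numba kernels mirror:

import numpy as np
from scipy.linalg import qr

A = np.random.default_rng(1).normal(size=(5, 4))
Q, R, P = qr(A, mode="economic", pivoting=True)

# P reorders the columns of A to match the pivoted factorization
np.testing.assert_allclose(A[:, P], Q @ R)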
@@ -16,7 +16,6 @@ from pytensor.tensor.nlinalg import (
    Eigh,
    MatrixInverse,
    MatrixPinv,
    QRFull,
    SLogDet,
)
@@ -146,38 +145,3 @@ def numba_funcify_MatrixPinv(op, node, **kwargs):
        return np.linalg.pinv(inputs_cast(x)).astype(out_dtype)

    return matrixpinv
@numba_funcify.register(QRFull)
def numba_funcify_QRFull(op, node, **kwargs):
    mode = op.mode

    if mode != "reduced":
        warnings.warn(
            (
                "Numba will use object mode to allow the "
                "`mode` argument to `numpy.linalg.qr`."
            ),
            UserWarning,
        )

        if len(node.outputs) > 1:
            ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
        else:
            ret_sig = get_numba_type(node.outputs[0].type)

        @numba_basic.numba_njit
        def qr_full(x):
            with numba.objmode(ret=ret_sig):
                ret = np.linalg.qr(x, mode=mode)
            return ret

    else:
        out_dtype = node.outputs[0].type.numpy_dtype
        inputs_cast = int_to_float_fn(node.inputs, out_dtype)

        @numba_basic.numba_njit(inline="always")
        def qr_full(x):
            return np.linalg.qr(inputs_cast(x))

    return qr_full
@@ -2,6 +2,7 @@ import warnings

import numpy as np

from pytensor import config
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky
from pytensor.link.numba.dispatch.linalg.decomposition.lu import (
@@ -11,6 +12,14 @@ from pytensor.link.numba.dispatch.linalg.decomposition.lu import (
    _pivot_to_permutation,
)
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _lu_factor
from pytensor.link.numba.dispatch.linalg.decomposition.qr import (
    _qr_full_no_pivot,
    _qr_full_pivot,
    _qr_r_no_pivot,
    _qr_r_pivot,
    _qr_raw_no_pivot,
    _qr_raw_pivot,
)
from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve
from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen
from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd
@@ -19,6 +28,7 @@ from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangul
from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal
from pytensor.tensor.slinalg import (
    LU,
    QR,
    BlockDiagonal,
    Cholesky,
    CholeskySolve,
@@ -27,7 +37,7 @@ from pytensor.tensor.slinalg import (
    Solve,
    SolveTriangular,
)
from pytensor.tensor.type import complex_dtypes
from pytensor.tensor.type import complex_dtypes, integer_dtypes


_COMPLEX_DTYPE_NOT_SUPPORTED_MSG = (
@@ -311,3 +321,96 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
    )

    return cho_solve
@numba_funcify.register(QR)
def numba_funcify_QR(op, node, **kwargs):
    mode = op.mode
    check_finite = op.check_finite
    pivoting = op.pivoting
    overwrite_a = op.overwrite_a

    dtype = node.inputs[0].dtype
    if dtype in complex_dtypes:
        raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))

    integer_input = dtype in integer_dtypes
    in_dtype = config.floatX if integer_input else dtype

    @numba_njit(cache=False)
    def qr(a):
        if check_finite:
            if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
                raise np.linalg.LinAlgError(
                    "Non-numeric values (nan or inf) found in input to qr"
                )

        if integer_input:
            a = a.astype(in_dtype)

        if (mode == "full" or mode == "economic") and pivoting:
            Q, R, P = _qr_full_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return Q, R, P

        elif (mode == "full" or mode == "economic") and not pivoting:
            Q, R = _qr_full_no_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return Q, R

        elif mode == "r" and pivoting:
            R, P = _qr_r_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return R, P

        elif mode == "r" and not pivoting:
            (R,) = _qr_r_no_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return R

        elif mode == "raw" and pivoting:
            H, tau, R, P = _qr_raw_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return H, tau, R, P

        elif mode == "raw" and not pivoting:
            H, tau, R = _qr_raw_no_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return H, tau, R

        else:
            raise NotImplementedError(
                f"QR mode={mode}, pivoting={pivoting} not supported in numba mode."
            )

    return qr
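With the dispatcher registered, the Op runs end to end through the numba backend. A minimal sketch (assumes numba is installed; `pt.linalg.qr` is the same entry point the new tests use):

import numpy as np
import pytensor
import pytensor.tensor as pt

A = pt.matrix("A")
Q, R = pt.linalg.qr(A, mode="economic")
f = pytensor.function([A], [Q, R], mode="NUMBA")

A_val = np.random.default_rng(0).normal(size=(4, 3)).astype(pytensor.config.floatX)
q, r = f(A_val)
np.testing.assert_allclose(q @ r, A_val, atol=1e-5)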
@@ -8,6 +8,7 @@ import pytensor.link.pytorch.dispatch.elemwise
import pytensor.link.pytorch.dispatch.math
import pytensor.link.pytorch.dispatch.extra_ops
import pytensor.link.pytorch.dispatch.nlinalg
import pytensor.link.pytorch.dispatch.slinalg
import pytensor.link.pytorch.dispatch.shape
import pytensor.link.pytorch.dispatch.sort
import pytensor.link.pytorch.dispatch.subtensor
...
@@ -9,7 +9,6 @@ from pytensor.tensor.nlinalg import (
    KroneckerProduct,
    MatrixInverse,
    MatrixPinv,
    QRFull,
    SLogDet,
)
@@ -70,21 +69,6 @@ def pytorch_funcify_MatrixInverse(op, **kwargs):
    return matrix_inverse
@pytorch_funcify.register(QRFull)
def pytorch_funcify_QRFull(op, **kwargs):
    mode = op.mode
    if mode == "raw":
        raise NotImplementedError("raw mode not implemented in PyTorch")

    def qr_full(x):
        Q, R = torch.linalg.qr(x, mode=mode)
        if mode == "r":
            return R
        return Q, R

    return qr_full
@pytorch_funcify.register(MatrixPinv)
def pytorch_funcify_Pinv(op, **kwargs):
    hermitian = op.hermitian
...
import torch

from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.slinalg import QR


@pytorch_funcify.register(QR)
def pytorch_funcify_QR(op, **kwargs):
    mode = op.mode
    if mode == "raw":
        raise NotImplementedError("raw mode not implemented in PyTorch")
    elif mode == "full":
        mode = "complete"
    elif mode == "economic":
        mode = "reduced"

    def qr(x):
        Q, R = torch.linalg.qr(x, mode=mode)
        if mode == "r":
            return R
        return Q, R

    return qr
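The mode renaming is needed because torch.linalg.qr only accepts numpy-style names ('reduced', 'complete', 'r'). Illustrating the mapping applied above:

import torch

x = torch.randn(4, 3)

Q, R = torch.linalg.qr(x, mode="complete")  # pytensor mode="full"
Q, R = torch.linalg.qr(x, mode="reduced")   # pytensor mode="economic"
_, R = torch.linalg.qr(x, mode="r")         # Q comes back empty in this mode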
@@ -5,15 +5,12 @@ from typing import Literal, cast

import numpy as np

import pytensor.tensor as pt
from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.ifelse import ifelse
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.raise_op import Assert
from pytensor.tensor import TensorLike
from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
@@ -468,173 +465,6 @@ def eigh(a, UPLO="L"):
    return Eigh(UPLO)(a)
class QRFull(Op):
    """
    Full QR Decomposition.

    Computes the QR decomposition of a matrix.
    Factor the matrix a as qr, where q is orthonormal
    and r is upper-triangular.
    """

    __props__ = ("mode",)

    def __init__(self, mode):
        self.mode = mode

    def make_node(self, x):
        x = as_tensor_variable(x)

        assert x.ndim == 2, "The input of qr function should be a matrix."

        in_dtype = x.type.numpy_dtype
        out_dtype = np.dtype(f"f{in_dtype.itemsize}")

        q = matrix(dtype=out_dtype)

        if self.mode != "raw":
            r = matrix(dtype=out_dtype)
        else:
            r = vector(dtype=out_dtype)

        if self.mode != "r":
            q = matrix(dtype=out_dtype)
            outputs = [q, r]
        else:
            outputs = [r]

        return Apply(self, [x], outputs)

    def perform(self, node, inputs, outputs):
        (x,) = inputs
        assert x.ndim == 2, "The input of qr function should be a matrix."

        res = np.linalg.qr(x, self.mode)
        if self.mode != "r":
            outputs[0][0], outputs[1][0] = res
        else:
            outputs[0][0] = res

    def L_op(self, inputs, outputs, output_grads):
        """
        Reverse-mode gradient of the QR function.

        References
        ----------
        .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/
        .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2
        """
        from pytensor.tensor.slinalg import solve_triangular

        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
        m, n = A.shape

        def _H(x: ptb.TensorVariable):
            return x.conj().mT

        def _copyltu(x: ptb.TensorVariable):
            return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1))

        if self.mode == "raw":
            raise NotImplementedError("Gradient of qr not implemented for mode=raw")

        elif self.mode == "r":
            # We need all the components of the QR to compute the gradient of A even if we only
            # use the upper triangular component in the cost function.
            Q, R = qr(A, mode="reduced")
            dQ = Q.zeros_like()
            dR = cast(ptb.TensorVariable, output_grads[0])

        else:
            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
            if self.mode == "complete":
                qr_assert_op = Assert(
                    "Gradient of qr not implemented for m x n matrices with m > n and mode=complete"
                )
                R = qr_assert_op(R, ptm.le(m, n))

            new_output_grads = []
            is_disconnected = [
                isinstance(x.type, DisconnectedType) for x in output_grads
            ]
            if all(is_disconnected):
                # This should never be reached by Pytensor
                return [DisconnectedType()()]  # pragma: no cover

            for disconnected, output_grad, output in zip(
                is_disconnected, output_grads, [Q, R], strict=True
            ):
                if disconnected:
                    new_output_grads.append(output.zeros_like())
                else:
                    new_output_grads.append(output_grad)

            (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)

        # gradient expression when m >= n
        M = R @ _H(dR) - _H(dQ) @ Q
        K = dQ + Q @ _copyltu(M)
        A_bar_m_ge_n = _H(solve_triangular(R, _H(K)))

        # gradient expression when m < n
        Y = A[:, m:]
        U = R[:, :m]
        dU, dV = dR[:, :m], dR[:, m:]
        dQ_Yt_dV = dQ + Y @ _H(dV)
        M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q
        X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M))))
        Y_bar = Q @ dV
        A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1)

        return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]


def qr(a, mode="reduced"):
    """
    Computes the QR decomposition of a matrix.

    Factor the matrix a as qr, where q
    is orthonormal and r is upper-triangular.

    Parameters
    ----------
    a : array_like, shape (M, N)
        Matrix to be factored.

    mode : {'reduced', 'complete', 'r', 'raw'}, optional
        If K = min(M, N), then

        'reduced'
            returns q, r with dimensions (M, K), (K, N)

        'complete'
            returns q, r with dimensions (M, M), (M, N)

        'r'
            returns r only with dimensions (K, N)

        'raw'
            returns h, tau with dimensions (N, M), (K,)

        Note that array h returned in 'raw' mode is
        transposed for calling Fortran.

        Default mode is 'reduced'

    Returns
    -------
    q : matrix of float or complex, optional
        A matrix with orthonormal columns. When mode = 'complete' the
        result is an orthogonal/unitary matrix depending on whether or
        not a is real/complex. The determinant may be either +/- 1 in
        that case.
    r : matrix of float or complex, optional
        The upper-triangular matrix.
    """
    return QRFull(mode)(a)
class SVD(Op):
    """
    Computes singular value decomposition of matrix A, into U, S, V such that A = U @ S @ V
@@ -1291,7 +1121,6 @@ __all__ = [
    "det",
    "eig",
    "eigh",
    "qr",
    "svd",
    "lstsq",
    "matrix_power",
...
@@ -7,16 +7,19 @@ from typing import Literal, cast

import numpy as np
import scipy.linalg as scipy_linalg
from numpy.exceptions import ComplexWarning
from scipy.linalg import get_lapack_funcs

import pytensor
import pytensor.tensor as pt
from pytensor import ifelse
from pytensor import tensor as pt
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.raise_op import Assert
from pytensor.tensor import TensorLike
from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
from pytensor.tensor.basic import diagonal
from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import kron, matrix_dot
from pytensor.tensor.shape import reshape
@@ -1714,6 +1717,376 @@ def block_diag(*matrices: TensorVariable):
    return _block_diagonal_matrix(*matrices)
class QR(Op):
    """
    QR Decomposition
    """

    __props__ = (
        "overwrite_a",
        "mode",
        "pivoting",
        "check_finite",
    )

    def __init__(
        self,
        mode: Literal["full", "r", "economic", "raw"] = "full",
        overwrite_a: bool = False,
        pivoting: bool = False,
        check_finite: bool = False,
    ):
        self.mode = mode
        self.overwrite_a = overwrite_a
        self.pivoting = pivoting
        self.check_finite = check_finite

        self.destroy_map = {}
        if overwrite_a:
            self.destroy_map = {0: [0]}

        match self.mode:
            case "economic":
                self.gufunc_signature = "(m,n)->(m,k),(k,n)"
            case "full":
                self.gufunc_signature = "(m,n)->(m,m),(m,n)"
            case "r":
                self.gufunc_signature = "(m,n)->(m,n)"
            case "raw":
                self.gufunc_signature = "(m,n)->(n,m),(k),(m,n)"
            case _:
                raise ValueError(
                    f"Invalid mode '{mode}'. Supported modes are 'full', 'economic', 'r', and 'raw'."
                )

        if pivoting:
            self.gufunc_signature += ",(n)"
    def make_node(self, x):
        x = as_tensor_variable(x)

        assert x.ndim == 2, "The input of qr function should be a matrix."

        # Preserve static shape information if possible
        M, N = x.type.shape
        if M is not None and N is not None:
            K = min(M, N)
        else:
            K = None

        in_dtype = x.type.numpy_dtype
        out_dtype = np.dtype(f"f{in_dtype.itemsize}")

        match self.mode:
            case "full":
                outputs = [
                    tensor(shape=(M, M), dtype=out_dtype),
                    tensor(shape=(M, N), dtype=out_dtype),
                ]
            case "economic":
                outputs = [
                    tensor(shape=(M, K), dtype=out_dtype),
                    tensor(shape=(K, N), dtype=out_dtype),
                ]
            case "r":
                outputs = [
                    tensor(shape=(M, N), dtype=out_dtype),
                ]
            case "raw":
                outputs = [
                    tensor(shape=(M, M), dtype=out_dtype),
                    tensor(shape=(K,), dtype=out_dtype),
                    tensor(shape=(M, N), dtype=out_dtype),
                ]
            case _:
                raise NotImplementedError

        if self.pivoting:
            outputs = [*outputs, tensor(shape=(N,), dtype="int32")]

        return Apply(self, [x], outputs)
    def infer_shape(self, fgraph, node, shapes):
        (x_shape,) = shapes
        M, N = x_shape
        K = ptm.minimum(M, N)

        Q_shape = None
        R_shape = None
        tau_shape = None
        P_shape = None

        match self.mode:
            case "full":
                Q_shape = (M, M)
                R_shape = (M, N)
            case "economic":
                Q_shape = (M, K)
                R_shape = (K, N)
            case "r":
                R_shape = (M, N)
            case "raw":
                Q_shape = (M, M)  # Actually this is H in this case
                tau_shape = (K,)
                R_shape = (M, N)

        if self.pivoting:
            P_shape = (N,)

        return [
            shape
            for shape in (Q_shape, tau_shape, R_shape, P_shape)
            if shape is not None
        ]
    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
        if not allowed_inplace_inputs:
            return self
        new_props = self._props_dict()  # type: ignore
        new_props["overwrite_a"] = True
        return type(self)(**new_props)

    def _call_and_get_lwork(self, fn, *args, lwork, **kwargs):
        if lwork in [-1, None]:
            *_, work, info = fn(*args, lwork=-1, **kwargs)
            lwork = work.item()

        return fn(*args, lwork=lwork, **kwargs)
    def perform(self, node, inputs, outputs):
        (x,) = inputs
        M, N = x.shape

        if self.pivoting:
            (geqp3,) = get_lapack_funcs(("geqp3",), (x,))
            qr, jpvt, tau, *work_info = self._call_and_get_lwork(
                geqp3, x, lwork=-1, overwrite_a=self.overwrite_a
            )
            jpvt -= 1  # geqp3 returns a 1-based index array, so subtract 1
        else:
            (geqrf,) = get_lapack_funcs(("geqrf",), (x,))
            qr, tau, *work_info = self._call_and_get_lwork(
                geqrf, x, lwork=-1, overwrite_a=self.overwrite_a
            )

        if self.mode not in ["economic", "raw"] or M < N:
            R = np.triu(qr)
        else:
            R = np.triu(qr[:N, :])

        if self.mode == "r" and self.pivoting:
            outputs[0][0] = R
            outputs[1][0] = jpvt
            return

        elif self.mode == "r":
            outputs[0][0] = R
            return

        elif self.mode == "raw" and self.pivoting:
            outputs[0][0] = qr
            outputs[1][0] = tau
            outputs[2][0] = R
            outputs[3][0] = jpvt
            return

        elif self.mode == "raw":
            outputs[0][0] = qr
            outputs[1][0] = tau
            outputs[2][0] = R
            return

        (gor_un_gqr,) = get_lapack_funcs(("orgqr",), (qr,))

        if M < N:
            Q, work, info = self._call_and_get_lwork(
                gor_un_gqr, qr[:, :M], tau, lwork=-1, overwrite_a=1
            )
        elif self.mode == "economic":
            Q, work, info = self._call_and_get_lwork(
                gor_un_gqr, qr, tau, lwork=-1, overwrite_a=1
            )
        else:
            t = qr.dtype.char
            qqr = np.empty((M, M), dtype=t)
            qqr[:, :N] = qr

            # Always overwrite qqr -- it's a meaningless intermediate value
            Q, work, info = self._call_and_get_lwork(
                gor_un_gqr, qqr, tau, lwork=-1, overwrite_a=1
            )

        outputs[0][0] = Q
        outputs[1][0] = R
        if self.pivoting:
            outputs[2][0] = jpvt
    def L_op(self, inputs, outputs, output_grads):
        """
        Reverse-mode gradient of the QR function.

        References
        ----------
        .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/
        .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2
        """
        from pytensor.tensor.slinalg import solve_triangular

        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
        m, n = A.shape

        # Check if we have static shape info; if so we can get a better graph
        # (avoiding the ifelse Op in the output)
        M_static, N_static = A.type.shape
        shapes_unknown = M_static is None or N_static is None

        def _H(x: ptb.TensorVariable):
            return x.conj().mT

        def _copyltu(x: ptb.TensorVariable):
            return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1))

        if self.mode == "raw":
            raise NotImplementedError("Gradient of qr not implemented for mode=raw")

        elif self.mode == "r":
            k = pt.minimum(m, n)

            # We need all the components of the QR to compute the gradient of A even if we only
            # use the upper triangular component in the cost function.
            props_dict = self._props_dict()
            props_dict["mode"] = "economic"
            props_dict["pivoting"] = False

            qr_op = type(self)(**props_dict)

            Q, R = qr_op(A)
            dQ = Q.zeros_like()

            # Unlike numpy.linalg.qr, scipy.linalg.qr returns the full (m,n) matrix when mode='r', *not* the (k,n)
            # matrix that is computed by mode='economic'. The gradient assumes that dR is of shape (k,n), so we need to
            # slice it to the first k rows. Note that if m <= n, then k = m, so this is safe in all cases.
            dR = cast(ptb.TensorVariable, output_grads[0][:k, :])

        else:
            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
            if self.mode == "full":
                qr_assert_op = Assert(
                    "Gradient of qr not implemented for m x n matrices with m > n and mode=full"
                )
                R = qr_assert_op(R, ptm.le(m, n))

            new_output_grads = []
            is_disconnected = [
                isinstance(x.type, DisconnectedType) for x in output_grads
            ]
            if all(is_disconnected):
                # This should never be reached by Pytensor
                return [DisconnectedType()()]  # pragma: no cover

            for disconnected, output_grad, output in zip(
                is_disconnected, output_grads, [Q, R], strict=True
            ):
                if disconnected:
                    new_output_grads.append(output.zeros_like())
                else:
                    new_output_grads.append(output_grad)

            (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)

        if shapes_unknown or M_static >= N_static:
            # gradient expression when m >= n
            M = R @ _H(dR) - _H(dQ) @ Q
            K = dQ + Q @ _copyltu(M)
            A_bar_m_ge_n = _H(solve_triangular(R, _H(K)))

            if not shapes_unknown:
                return [A_bar_m_ge_n]

        # We have to trigger both branches if shapes_unknown is True, so this is purposefully not an elif branch
        if shapes_unknown or M_static < N_static:
            # gradient expression when m < n
            Y = A[:, m:]
            U = R[:, :m]
            dU, dV = dR[:, :m], dR[:, m:]
            dQ_Yt_dV = dQ + Y @ _H(dV)
            M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q
            X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M))))
            Y_bar = Q @ dV
            A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1)

            if not shapes_unknown:
                return [A_bar_m_lt_n]

        return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]
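In math form, the m >= n branch of this gradient computes (restating refs [1] and [2], which the code implements):

M = R \bar{R}^{H} - \bar{Q}^{H} Q, \qquad
\bar{A} = \left( R^{-1} \left( \bar{Q} + Q \,\mathrm{copyltu}(M) \right)^{H} \right)^{H},
\qquad \mathrm{copyltu}(M) = \mathrm{tril}(M) + \mathrm{tril}(M, -1)^{H}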
def qr(
    A: TensorLike,
    mode: Literal["full", "r", "economic", "raw", "complete", "reduced"] = "full",
    overwrite_a: bool = False,
    pivoting: bool = False,
    lwork: int | None = None,
):
    """
    QR Decomposition of input matrix `a`.

    The QR decomposition of a matrix `A` is a factorization of the form :math:`A = QR`, where `Q` is an orthogonal
    matrix (:math:`Q Q^T = I`) and `R` is an upper triangular matrix.

    This decomposition is useful in various numerical methods, including solving linear systems and least squares
    problems.

    Parameters
    ----------
    A: TensorLike
        Input matrix of shape (M, N) to be decomposed.

    mode: str, one of "full", "economic", "r", or "raw"
        How the QR decomposition is computed and returned. Choosing the mode can avoid unnecessary computations,
        depending on which of the return matrices are needed. Given an input matrix with shape (M, N), the choices are:

        - "full" (or "complete"): returns `Q` and `R` with dimensions `(M, M)` and `(M, N)`.
        - "economic" (or "reduced"): returns `Q` and `R` with dimensions `(M, K)` and `(K, N)`,
          where `K = min(M, N)`.
        - "r": returns only `R` with dimensions `(K, N)`.
        - "raw": returns `H` and `tau` with dimensions `(N, M)` and `(K,)`, where `H` is the matrix of
          Householder reflections, and `tau` is the vector of Householder coefficients.

    pivoting: bool, default False
        If True, also return a vector of rank-revealing permutations `P` such that `A[:, P] = QR`.

    overwrite_a: bool, ignored
        Ignored. Included only for consistency with the function signature of `scipy.linalg.qr`. Pytensor will always
        automatically overwrite the input matrix `A` if it is safe to do so.

    lwork: int, ignored
        Ignored. Included only for consistency with the function signature of `scipy.linalg.qr`. Pytensor will
        automatically determine the optimal workspace size for the QR decomposition.

    Returns
    -------
    Q or H: TensorVariable, optional
        A matrix with orthonormal columns. When mode = 'complete', the result is an orthogonal/unitary matrix
        depending on whether a is real/complex. The determinant may be either +/- 1 in that case. If
        mode = 'raw', it is the matrix of Householder reflections. If mode = 'r', Q is not returned.

    R or tau : TensorVariable, optional
        Upper-triangular matrix. If mode = 'raw', it is the vector of Householder coefficients.
    """
    # backwards compatibility from the numpy API
    if mode == "complete":
        mode = "full"
    elif mode == "reduced":
        mode = "economic"

    return Blockwise(QR(mode=mode, pivoting=pivoting, overwrite_a=False))(A)
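A minimal usage sketch of the new entry point, checking the pivoted identity documented above:

import numpy as np
import pytensor
from pytensor import tensor as pt
from pytensor.tensor.slinalg import qr

A = pt.matrix("A")
Q, R, P = qr(A, mode="economic", pivoting=True)
f = pytensor.function([A], [Q, R, P])

A_val = np.random.default_rng(0).normal(size=(5, 3)).astype(pytensor.config.floatX)
q, r, p = f(A_val)
np.testing.assert_allclose(A_val[:, p], q @ r, atol=1e-5)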
__all__ = [
    "cholesky",
    "solve",
@@ -1728,4 +2101,5 @@ __all__ = [
    "lu",
    "lu_factor",
    "lu_solve",
    "qr",
]
@@ -29,12 +29,6 @@ def test_jax_basic_multiout():
    outs = pt_nlinalg.eigh(x)
    compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)

    outs = pt_nlinalg.qr(x, mode="full")
    compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)

    outs = pt_nlinalg.qr(x, mode="reduced")
    compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)

    outs = pt_nlinalg.svd(x)
    compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
...
@@ -103,6 +103,18 @@ def test_jax_basic():
        ],
    )

    def assert_fn(x, y):
        np.testing.assert_allclose(x.astype(config.floatX), y, rtol=1e-3)

    M = rng.normal(size=(3, 3))
    X = M.dot(M.T)

    outs = pt_slinalg.qr(x, mode="full")
    compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)

    outs = pt_slinalg.qr(x, mode="economic")
    compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)


def test_jax_solve():
    rng = np.random.default_rng(utt.fetch_seed())
...
@@ -186,60 +186,6 @@ def test_matrix_inverses(op, x, exc, op_args):
)
@pytest.mark.parametrize(
    "x, mode, exc",
    [
        (
            (
                pt.dmatrix(),
                (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
            ),
            "reduced",
            None,
        ),
        (
            (
                pt.dmatrix(),
                (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
            ),
            "r",
            None,
        ),
        (
            (
                pt.lmatrix(),
                (lambda x: x.T.dot(x))(
                    rng.integers(1, 10, size=(3, 3)).astype("int64")
                ),
            ),
            "reduced",
            None,
        ),
        (
            (
                pt.lmatrix(),
                (lambda x: x.T.dot(x))(
                    rng.integers(1, 10, size=(3, 3)).astype("int64")
                ),
            ),
            "complete",
            UserWarning,
        ),
    ],
)
def test_QRFull(x, mode, exc):
    x, test_x = x
    g = nlinalg.QRFull(mode)(x)

    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
    with cm:
        compare_numba_and_py(
            [x],
            g,
            [test_x],
        )
@pytest.mark.parametrize(
    "x, full_matrices, compute_uv, exc",
    [
...
@@ -10,6 +10,7 @@ import pytensor.tensor as pt
from pytensor import In, config
from pytensor.tensor.slinalg import (
    LU,
    QR,
    Cholesky,
    CholeskySolve,
    LUFactor,
@@ -720,3 +721,70 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo
    # Can never destroy non-contiguous inputs
    np.testing.assert_allclose(b_val_not_contig, b_val)
@pytest.mark.parametrize(
    "mode, pivoting",
    [("economic", False), ("full", True), ("r", False), ("raw", True)],
    ids=["economic", "full_pivot", "r", "raw_pivot"],
)
@pytest.mark.parametrize(
    "overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
)
def test_qr(mode, pivoting, overwrite_a):
    shape = (5, 5)
    rng = np.random.default_rng()
    A = pt.tensor(
        "A",
        shape=shape,
        dtype=config.floatX,
    )
    A_val = rng.normal(size=shape).astype(config.floatX)

    qr_outputs = pt.linalg.qr(A, mode=mode, pivoting=pivoting)

    fn, res = compare_numba_and_py(
        [In(A, mutable=overwrite_a)],
        qr_outputs,
        [A_val],
        numba_mode=numba_inplace_mode,
        inplace=True,
    )

    op = fn.maker.fgraph.outputs[0].owner.op
    assert isinstance(op, QR)

    destroy_map = op.destroy_map
    if overwrite_a:
        assert destroy_map == {0: [0]}
    else:
        assert destroy_map == {}

    # Test F-contiguous input
    val_f_contig = np.copy(A_val, order="F")
    res_f_contig = fn(val_f_contig)

    for x, x_f_contig in zip(res, res_f_contig, strict=True):
        np.testing.assert_allclose(x, x_f_contig)

    # Should always be destroyable
    assert (A_val == val_f_contig).all() == (not overwrite_a)

    # Test C-contiguous input
    val_c_contig = np.copy(A_val, order="C")
    res_c_contig = fn(val_c_contig)
    for x, x_c_contig in zip(res, res_c_contig, strict=True):
        np.testing.assert_allclose(x, x_c_contig)

    # Cannot destroy C-contiguous input
    np.testing.assert_allclose(val_c_contig, A_val)

    # Test non-contiguous input
    val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
    res_not_contig = fn(val_not_contig)
    for x, x_not_contig in zip(res, res_not_contig, strict=True):
        np.testing.assert_allclose(x, x_not_contig)

    # Cannot destroy non-contiguous input
    np.testing.assert_allclose(val_not_contig, A_val)
import numpy as np
import pytest

from pytensor import config
from pytensor.tensor.type import matrix


@pytest.fixture
def matrix_test():
    rng = np.random.default_rng(213234)

    M = rng.normal(size=(3, 3))
    test_value = M.dot(M.T).astype(config.floatX)

    x = matrix("x")
    return x, test_value
@@ -8,17 +8,6 @@ from pytensor.tensor.type import matrix
from tests.link.pytorch.test_basic import compare_pytorch_and_py


@pytest.fixture
def matrix_test():
    rng = np.random.default_rng(213234)

    M = rng.normal(size=(3, 3))
    test_value = M.dot(M.T).astype(config.floatX)

    x = matrix("x")
    return (x, test_value)
@pytest.mark.parametrize(
    "func",
    (pt_nla.eig, pt_nla.eigh, pt_nla.SLogDet(), pt_nla.inv, pt_nla.det),
@@ -34,22 +23,6 @@ def test_lin_alg_no_params(func, matrix_test):
    compare_pytorch_and_py([x], outs, [test_value], assert_fn=assert_fn)
@pytest.mark.parametrize(
    "mode",
    (
        "complete",
        "reduced",
        "r",
        pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)),
    ),
)
def test_qr(mode, matrix_test):
    x, test_value = matrix_test
    outs = pt_nla.qr(x, mode=mode)
    compare_pytorch_and_py([x], outs, [test_value])
@pytest.mark.parametrize("compute_uv", [True, False]) @pytest.mark.parametrize("compute_uv", [True, False])
@pytest.mark.parametrize("full_matrices", [True, False]) @pytest.mark.parametrize("full_matrices", [True, False])
def test_svd(compute_uv, full_matrices, matrix_test): def test_svd(compute_uv, full_matrices, matrix_test):
......
import pytest

import pytensor

from tests.link.pytorch.test_basic import compare_pytorch_and_py


@pytest.mark.parametrize(
    "mode",
    (
        "complete",
        "reduced",
        "r",
        pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)),
    ),
)
def test_qr(mode, matrix_test):
    x, test_value = matrix_test
    outs = pytensor.tensor.slinalg.qr(x, mode=mode)
    compare_pytorch_and_py([x], outs, [test_value])
from functools import partial

import numpy as np
import numpy.linalg
import pytest
from numpy.testing import assert_array_almost_equal
@@ -25,7 +24,6 @@ from pytensor.tensor.nlinalg import (
    matrix_power,
    norm,
    pinv,
    qr,
    slogdet,
    svd,
    tensorinv,
@@ -122,102 +120,6 @@ def test_matrix_dot():
    assert _allclose(numpy_sol, pytensor_sol)
def test_qr_modes():
    rng = np.random.default_rng(utt.fetch_seed())

    A = matrix("A", dtype=config.floatX)
    a = rng.random((4, 4)).astype(config.floatX)

    f = function([A], qr(A))
    t_qr = f(a)
    n_qr = np.linalg.qr(a)
    assert _allclose(n_qr, t_qr)

    for mode in ["reduced", "r", "raw"]:
        f = function([A], qr(A, mode))
        t_qr = f(a)
        n_qr = np.linalg.qr(a, mode)
        if isinstance(n_qr, list | tuple):
            assert _allclose(n_qr[0], t_qr[0])
            assert _allclose(n_qr[1], t_qr[1])
        else:
            assert _allclose(n_qr, t_qr)

    try:
        n_qr = np.linalg.qr(a, "complete")
        f = function([A], qr(A, "complete"))
        t_qr = f(a)
        assert _allclose(n_qr, t_qr)
    except TypeError as e:
        assert "name 'complete' is not defined" in str(e)


@pytest.mark.parametrize(
    "shape, gradient_test_case, mode",
    (
        [(s, c, "reduced") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
        + [(s, c, "complete") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
        + [(s, 0, "r") for s in [(3, 3), (6, 3), (3, 6)]]
        + [((3, 3), 0, "raw")]
    ),
    ids=(
        [
            f"shape={s}, gradient_test_case={c}, mode=reduced"
            for s in [(3, 3), (6, 3), (3, 6)]
            for c in ["Q", "R", "both"]
        ]
        + [
            f"shape={s}, gradient_test_case={c}, mode=complete"
            for s in [(3, 3), (6, 3), (3, 6)]
            for c in ["Q", "R", "both"]
        ]
        + [f"shape={s}, gradient_test_case=R, mode=r" for s in [(3, 3), (6, 3), (3, 6)]]
        + ["shape=(3, 3), gradient_test_case=Q, mode=raw"]
    ),
)
@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
def test_qr_grad(shape, gradient_test_case, mode, is_complex):
    rng = np.random.default_rng(utt.fetch_seed())

    def _test_fn(x, case=2, mode="reduced"):
        if case == 0:
            return qr(x, mode=mode)[0].sum()
        elif case == 1:
            return qr(x, mode=mode)[1].sum()
        elif case == 2:
            Q, R = qr(x, mode=mode)
            return Q.sum() + R.sum()

    if is_complex:
        pytest.xfail("Complex inputs currently not supported by verify_grad")

    m, n = shape
    a = rng.standard_normal(shape).astype(config.floatX)
    if is_complex:
        a += 1j * rng.standard_normal(shape).astype(config.floatX)

    if mode == "raw":
        with pytest.raises(NotImplementedError):
            utt.verify_grad(
                partial(_test_fn, case=gradient_test_case, mode=mode),
                [a],
                rng=np.random,
            )

    elif mode == "complete" and m > n:
        with pytest.raises(AssertionError):
            utt.verify_grad(
                partial(_test_fn, case=gradient_test_case, mode=mode),
                [a],
                rng=np.random,
            )

    else:
        utt.verify_grad(
            partial(_test_fn, case=gradient_test_case, mode=mode), [a], rng=np.random
        )
class TestSvd(utt.InferShapeTester):
    op_class = SVD
...
import functools
import itertools
from functools import partial
from typing import Literal

import numpy as np
import pytest
import scipy
from scipy import linalg as scipy_linalg

from pytensor import function, grad
from pytensor import tensor as pt
@@ -26,6 +28,7 @@ from pytensor.tensor.slinalg import (
    lu_factor,
    lu_solve,
    pivot_to_permutation,
    qr,
    solve,
    solve_continuous_lyapunov,
    solve_discrete_are,
@@ -1088,3 +1091,104 @@ def test_block_diagonal_blockwise():
    B = np.random.normal(size=(1, batch_size, 4, 4)).astype(config.floatX)
    result = block_diag(A, B).eval()
    assert result.shape == (10, batch_size, 6, 6)
@pytest.mark.parametrize(
    "mode, names",
    [
        ("economic", ["Q", "R"]),
        ("full", ["Q", "R"]),
        ("r", ["R"]),
        ("raw", ["H", "tau", "R"]),
    ],
)
@pytest.mark.parametrize("pivoting", [True, False])
def test_qr_modes(mode, names, pivoting):
    rng = np.random.default_rng(utt.fetch_seed())
    A_val = rng.random((4, 4)).astype(config.floatX)

    if pivoting:
        names = [*names, "pivots"]

    A = tensor("A", dtype=config.floatX, shape=(None, None))

    f = function([A], qr(A, mode=mode, pivoting=pivoting))
    outputs_pt = f(A_val)
    outputs_sp = scipy_linalg.qr(A_val, mode=mode, pivoting=pivoting)

    if mode == "raw":
        # The first output of scipy's qr is a tuple when mode is raw; flatten it for easier iteration
        outputs_sp = (*outputs_sp[0], *outputs_sp[1:])
    elif mode == "r" and not pivoting:
        # Here there's only one output from the pytensor function; wrap it in a list for iteration
        outputs_pt = [outputs_pt]

    for out_pt, out_sp, name in zip(outputs_pt, outputs_sp, names):
        np.testing.assert_allclose(out_pt, out_sp, err_msg=f"{name} disagrees")
@pytest.mark.parametrize(
    "shape, gradient_test_case, mode",
    (
        [(s, c, "economic") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
        + [(s, c, "full") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
        + [(s, 0, "r") for s in [(3, 3), (6, 3), (3, 6)]]
        + [((3, 3), 0, "raw")]
    ),
    ids=(
        [
            f"shape={s}, gradient_test_case={c}, mode=economic"
            for s in [(3, 3), (6, 3), (3, 6)]
            for c in ["Q", "R", "both"]
        ]
        + [
            f"shape={s}, gradient_test_case={c}, mode=full"
            for s in [(3, 3), (6, 3), (3, 6)]
            for c in ["Q", "R", "both"]
        ]
        + [f"shape={s}, gradient_test_case=R, mode=r" for s in [(3, 3), (6, 3), (3, 6)]]
        + ["shape=(3, 3), gradient_test_case=Q, mode=raw"]
    ),
)
@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
def test_qr_grad(shape, gradient_test_case, mode, is_complex):
    rng = np.random.default_rng(utt.fetch_seed())

    def _test_fn(x, case=2, mode="reduced"):
        if case == 0:
            return qr(x, mode=mode)[0].sum()
        elif case == 1:
            return qr(x, mode=mode)[1].sum()
        elif case == 2:
            Q, R = qr(x, mode=mode)
            return Q.sum() + R.sum()

    if is_complex:
        pytest.xfail("Complex inputs currently not supported by verify_grad")

    m, n = shape
    a = rng.standard_normal(shape).astype(config.floatX)
    if is_complex:
        a += 1j * rng.standard_normal(shape).astype(config.floatX)

    if mode == "raw":
        with pytest.raises(NotImplementedError):
            utt.verify_grad(
                partial(_test_fn, case=gradient_test_case, mode=mode),
                [a],
                rng=np.random,
            )

    elif mode == "full" and m > n:
        with pytest.raises(AssertionError):
            utt.verify_grad(
                partial(_test_fn, case=gradient_test_case, mode=mode),
                [a],
                rng=np.random,
            )

    else:
        utt.verify_grad(
            partial(_test_fn, case=gradient_test_case, mode=mode), [a], rng=np.random
        )