Unverified 提交 617964ff authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Refactor and update QR Op (#1518)

* Refactor QR * Update JAX QR dispatch * Update Torch QR dispatch * Update numba QR dispatch
上级 5024d54e
...@@ -9,7 +9,6 @@ from pytensor.tensor.nlinalg import ( ...@@ -9,7 +9,6 @@ from pytensor.tensor.nlinalg import (
KroneckerProduct, KroneckerProduct,
MatrixInverse, MatrixInverse,
MatrixPinv, MatrixPinv,
QRFull,
SLogDet, SLogDet,
) )
...@@ -67,16 +66,6 @@ def jax_funcify_MatrixInverse(op, **kwargs): ...@@ -67,16 +66,6 @@ def jax_funcify_MatrixInverse(op, **kwargs):
return matrix_inverse return matrix_inverse
@jax_funcify.register(QRFull)
def jax_funcify_QRFull(op, **kwargs):
    """Build the JAX implementation for a ``QRFull`` Op.

    The Op's ``mode`` is captured as a default argument so the returned
    function closes over a fixed, hashable value.
    """
    qr_mode = op.mode

    def qr_full(x, mode=qr_mode):
        # numpy-compatible QR from jax.numpy
        return jnp.linalg.qr(x, mode=mode)

    return qr_full
@jax_funcify.register(MatrixPinv) @jax_funcify.register(MatrixPinv)
def jax_funcify_Pinv(op, **kwargs): def jax_funcify_Pinv(op, **kwargs):
def pinv(x): def pinv(x):
......
...@@ -5,6 +5,7 @@ import jax ...@@ -5,6 +5,7 @@ import jax
from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
LU, LU,
QR,
BlockDiagonal, BlockDiagonal,
Cholesky, Cholesky,
CholeskySolve, CholeskySolve,
...@@ -168,3 +169,13 @@ def jax_funcify_ChoSolve(op, **kwargs): ...@@ -168,3 +169,13 @@ def jax_funcify_ChoSolve(op, **kwargs):
) )
return cho_solve return cho_solve
@jax_funcify.register(QR)
def jax_funcify_QR(op, **kwargs):
    """Build the JAX implementation for the scipy-style ``QR`` Op.

    ``op.mode`` is forwarded unchanged to ``jax.scipy.linalg.qr``.
    """
    op_mode = op.mode

    def qr(x, mode=op_mode):
        return jax.scipy.linalg.qr(x, mode=mode)

    return qr
...@@ -283,7 +283,6 @@ class _LAPACK: ...@@ -283,7 +283,6 @@ class _LAPACK:
Called by scipy.linalg.lu_solve Called by scipy.linalg.lu_solve
""" """
...
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs") lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs")
functype = ctypes.CFUNCTYPE( functype = ctypes.CFUNCTYPE(
None, None,
...@@ -457,3 +456,90 @@ class _LAPACK: ...@@ -457,3 +456,90 @@ class _LAPACK:
_ptr_int, # INFO _ptr_int, # INFO
) )
return functype(lapack_ptr) return functype(lapack_ptr)
@classmethod
def numba_xgeqrf(cls, dtype):
    """
    Compute the QR factorization of a general M-by-N matrix A.

    Used in QR decomposition (no pivoting).

    Wraps the LAPACK ``?geqrf`` routine via ctypes; the CFUNCTYPE
    argument order below must match the Fortran signature exactly.
    """
    lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqrf")
    # LAPACK routines return through pointer arguments, so the C return type is None.
    functype = ctypes.CFUNCTYPE(
        None,
        _ptr_int,  # M
        _ptr_int,  # N
        float_pointer,  # A
        _ptr_int,  # LDA
        float_pointer,  # TAU
        float_pointer,  # WORK
        _ptr_int,  # LWORK
        _ptr_int,  # INFO
    )
    return functype(lapack_ptr)
@classmethod
def numba_xgeqp3(cls, dtype):
    """
    Compute the QR factorization with column pivoting of a general M-by-N matrix A.

    Used in QR decomposition with pivoting.

    Wraps the LAPACK ``?geqp3`` routine via ctypes; compared to ``geqrf``
    it takes an extra integer JPVT array for the column pivots.
    """
    lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqp3")
    functype = ctypes.CFUNCTYPE(
        None,
        _ptr_int,  # M
        _ptr_int,  # N
        float_pointer,  # A
        _ptr_int,  # LDA
        _ptr_int,  # JPVT
        float_pointer,  # TAU
        float_pointer,  # WORK
        _ptr_int,  # LWORK
        _ptr_int,  # INFO
    )
    return functype(lapack_ptr)
@classmethod
def numba_xorgqr(cls, dtype):
    """
    Generate the orthogonal matrix Q from a QR factorization (real types).

    Used in QR decomposition to form Q.

    Wraps the LAPACK ``?orgqr`` routine via ctypes (real dtypes only; the
    complex counterpart is ``ungqr`` below).
    """
    lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "orgqr")
    functype = ctypes.CFUNCTYPE(
        None,
        _ptr_int,  # M
        _ptr_int,  # N
        _ptr_int,  # K
        float_pointer,  # A
        _ptr_int,  # LDA
        float_pointer,  # TAU
        float_pointer,  # WORK
        _ptr_int,  # LWORK
        _ptr_int,  # INFO
    )
    return functype(lapack_ptr)
@classmethod
def numba_xungqr(cls, dtype):
    """
    Generate the unitary matrix Q from a QR factorization (complex types).

    Used in QR decomposition to form Q for complex types.

    Wraps the LAPACK ``?ungqr`` routine via ctypes; same argument layout
    as ``orgqr`` but for complex dtypes.
    """
    lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "ungqr")
    functype = ctypes.CFUNCTYPE(
        None,
        _ptr_int,  # M
        _ptr_int,  # N
        _ptr_int,  # K
        float_pointer,  # A
        _ptr_int,  # LDA
        float_pointer,  # TAU
        float_pointer,  # WORK
        _ptr_int,  # LWORK
        _ptr_int,  # INFO
    )
    return functype(lapack_ptr)
...@@ -16,7 +16,6 @@ from pytensor.tensor.nlinalg import ( ...@@ -16,7 +16,6 @@ from pytensor.tensor.nlinalg import (
Eigh, Eigh,
MatrixInverse, MatrixInverse,
MatrixPinv, MatrixPinv,
QRFull,
SLogDet, SLogDet,
) )
...@@ -146,38 +145,3 @@ def numba_funcify_MatrixPinv(op, node, **kwargs): ...@@ -146,38 +145,3 @@ def numba_funcify_MatrixPinv(op, node, **kwargs):
return np.linalg.pinv(inputs_cast(x)).astype(out_dtype) return np.linalg.pinv(inputs_cast(x)).astype(out_dtype)
return matrixpinv return matrixpinv
@numba_funcify.register(QRFull)
def numba_funcify_QRFull(op, node, **kwargs):
    """Build the numba implementation for a ``QRFull`` Op.

    Only ``mode="reduced"`` can run in nopython mode; any other mode
    falls back to numba object mode around ``numpy.linalg.qr``.
    """
    mode = op.mode

    if mode != "reduced":
        warnings.warn(
            (
                "Numba will use object mode to allow the "
                "`mode` argument to `numpy.linalg.qr`."
            ),
            UserWarning,
        )

        # objmode requires an explicit return signature: a Tuple type when
        # the Op has several outputs, a single type otherwise.
        if len(node.outputs) > 1:
            ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
        else:
            ret_sig = get_numba_type(node.outputs[0].type)

        @numba_basic.numba_njit
        def qr_full(x):
            with numba.objmode(ret=ret_sig):
                ret = np.linalg.qr(x, mode=mode)
            return ret

    else:
        # nopython path: cast integer inputs to the output float dtype first.
        out_dtype = node.outputs[0].type.numpy_dtype
        inputs_cast = int_to_float_fn(node.inputs, out_dtype)

        @numba_basic.numba_njit(inline="always")
        def qr_full(x):
            return np.linalg.qr(inputs_cast(x))

    return qr_full
...@@ -2,6 +2,7 @@ import warnings ...@@ -2,6 +2,7 @@ import warnings
import numpy as np import numpy as np
from pytensor import config
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky
from pytensor.link.numba.dispatch.linalg.decomposition.lu import ( from pytensor.link.numba.dispatch.linalg.decomposition.lu import (
...@@ -11,6 +12,14 @@ from pytensor.link.numba.dispatch.linalg.decomposition.lu import ( ...@@ -11,6 +12,14 @@ from pytensor.link.numba.dispatch.linalg.decomposition.lu import (
_pivot_to_permutation, _pivot_to_permutation,
) )
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _lu_factor from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _lu_factor
from pytensor.link.numba.dispatch.linalg.decomposition.qr import (
_qr_full_no_pivot,
_qr_full_pivot,
_qr_r_no_pivot,
_qr_r_pivot,
_qr_raw_no_pivot,
_qr_raw_pivot,
)
from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve
from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen
from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd
...@@ -19,6 +28,7 @@ from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangul ...@@ -19,6 +28,7 @@ from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangul
from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
LU, LU,
QR,
BlockDiagonal, BlockDiagonal,
Cholesky, Cholesky,
CholeskySolve, CholeskySolve,
...@@ -27,7 +37,7 @@ from pytensor.tensor.slinalg import ( ...@@ -27,7 +37,7 @@ from pytensor.tensor.slinalg import (
Solve, Solve,
SolveTriangular, SolveTriangular,
) )
from pytensor.tensor.type import complex_dtypes from pytensor.tensor.type import complex_dtypes, integer_dtypes
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG = ( _COMPLEX_DTYPE_NOT_SUPPORTED_MSG = (
...@@ -311,3 +321,96 @@ def numba_funcify_CholeskySolve(op, node, **kwargs): ...@@ -311,3 +321,96 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
) )
return cho_solve return cho_solve
@numba_funcify.register(QR)
def numba_funcify_QR(op, node, **kwargs):
    """Build the numba implementation for the scipy-style ``QR`` Op.

    Dispatches to a specialized helper per (mode, pivoting) combination.
    ``mode``, ``pivoting`` etc. are compile-time constants closed over by
    the njit function, so numba prunes the dead branches.
    """
    mode = op.mode
    check_finite = op.check_finite
    pivoting = op.pivoting
    overwrite_a = op.overwrite_a

    dtype = node.inputs[0].dtype
    if dtype in complex_dtypes:
        raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))

    # Integer inputs are upcast to floatX before factorization.
    integer_input = dtype in integer_dtypes
    in_dtype = config.floatX if integer_input else dtype

    @numba_njit(cache=False)
    def qr(a):
        if check_finite:
            if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
                raise np.linalg.LinAlgError(
                    "Non-numeric values (nan or inf) found in input to qr"
                )

        if integer_input:
            a = a.astype(in_dtype)

        # One branch per (mode, pivoting) pair; each helper has a distinct
        # arity, so the branches cannot share a single call site.
        if (mode == "full" or mode == "economic") and pivoting:
            Q, R, P = _qr_full_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return Q, R, P
        elif (mode == "full" or mode == "economic") and not pivoting:
            Q, R = _qr_full_no_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return Q, R
        elif mode == "r" and pivoting:
            R, P = _qr_r_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return R, P
        elif mode == "r" and not pivoting:
            (R,) = _qr_r_no_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return R
        elif mode == "raw" and pivoting:
            H, tau, R, P = _qr_raw_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return H, tau, R, P
        elif mode == "raw" and not pivoting:
            H, tau, R = _qr_raw_no_pivot(
                a,
                mode=mode,
                pivoting=pivoting,
                overwrite_a=overwrite_a,
                check_finite=check_finite,
            )
            return H, tau, R
        else:
            raise NotImplementedError(
                f"QR mode={mode}, pivoting={pivoting} not supported in numba mode."
            )

    return qr
...@@ -8,6 +8,7 @@ import pytensor.link.pytorch.dispatch.elemwise ...@@ -8,6 +8,7 @@ import pytensor.link.pytorch.dispatch.elemwise
import pytensor.link.pytorch.dispatch.math import pytensor.link.pytorch.dispatch.math
import pytensor.link.pytorch.dispatch.extra_ops import pytensor.link.pytorch.dispatch.extra_ops
import pytensor.link.pytorch.dispatch.nlinalg import pytensor.link.pytorch.dispatch.nlinalg
import pytensor.link.pytorch.dispatch.slinalg
import pytensor.link.pytorch.dispatch.shape import pytensor.link.pytorch.dispatch.shape
import pytensor.link.pytorch.dispatch.sort import pytensor.link.pytorch.dispatch.sort
import pytensor.link.pytorch.dispatch.subtensor import pytensor.link.pytorch.dispatch.subtensor
......
...@@ -9,7 +9,6 @@ from pytensor.tensor.nlinalg import ( ...@@ -9,7 +9,6 @@ from pytensor.tensor.nlinalg import (
KroneckerProduct, KroneckerProduct,
MatrixInverse, MatrixInverse,
MatrixPinv, MatrixPinv,
QRFull,
SLogDet, SLogDet,
) )
...@@ -70,21 +69,6 @@ def pytorch_funcify_MatrixInverse(op, **kwargs): ...@@ -70,21 +69,6 @@ def pytorch_funcify_MatrixInverse(op, **kwargs):
return matrix_inverse return matrix_inverse
@pytorch_funcify.register(QRFull)
def pytorch_funcify_QRFull(op, **kwargs):
    """Build the PyTorch implementation for a ``QRFull`` Op.

    ``raw`` mode has no torch equivalent and is rejected at build time.
    """
    qr_mode = op.mode
    if qr_mode == "raw":
        raise NotImplementedError("raw mode not implemented in PyTorch")

    def qr_full(x):
        Q, R = torch.linalg.qr(x, mode=qr_mode)
        # "r" mode only produces the triangular factor
        return R if qr_mode == "r" else (Q, R)

    return qr_full
@pytorch_funcify.register(MatrixPinv) @pytorch_funcify.register(MatrixPinv)
def pytorch_funcify_Pinv(op, **kwargs): def pytorch_funcify_Pinv(op, **kwargs):
hermitian = op.hermitian hermitian = op.hermitian
......
import torch
from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.slinalg import QR
@pytorch_funcify.register(QR)
def pytorch_funcify_QR(op, **kwargs):
    """Build the PyTorch implementation for the scipy-style ``QR`` Op.

    scipy mode names are translated to the ones ``torch.linalg.qr``
    accepts; ``raw`` has no torch equivalent and is rejected up front.
    """
    mode = op.mode
    if mode == "raw":
        raise NotImplementedError("raw mode not implemented in PyTorch")
    # scipy -> torch naming ("r" passes through unchanged)
    mode = {"full": "complete", "economic": "reduced"}.get(mode, mode)

    def qr(x):
        Q, R = torch.linalg.qr(x, mode=mode)
        return R if mode == "r" else (Q, R)

    return qr
...@@ -5,15 +5,12 @@ from typing import Literal, cast ...@@ -5,15 +5,12 @@ from typing import Literal, cast
import numpy as np import numpy as np
import pytensor.tensor as pt
from pytensor import scalar as ps from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.ifelse import ifelse
from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.raise_op import Assert
from pytensor.tensor import TensorLike from pytensor.tensor import TensorLike
from pytensor.tensor import basic as ptb from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm from pytensor.tensor import math as ptm
...@@ -468,173 +465,6 @@ def eigh(a, UPLO="L"): ...@@ -468,173 +465,6 @@ def eigh(a, UPLO="L"):
return Eigh(UPLO)(a) return Eigh(UPLO)(a)
class QRFull(Op):
    """
    Full QR Decomposition.

    Computes the QR decomposition of a matrix.
    Factor the matrix a as qr, where q is orthonormal
    and r is upper-triangular.

    Supported modes (numpy semantics): "reduced", "complete", "r", "raw".
    """

    __props__ = ("mode",)

    def __init__(self, mode):
        # mode is forwarded verbatim to np.linalg.qr in perform()
        self.mode = mode

    def make_node(self, x):
        """Create the Apply node; output count depends on ``self.mode``."""
        x = as_tensor_variable(x)

        assert x.ndim == 2, "The input of qr function should be a matrix."

        # Outputs are always floating point, matching the input's itemsize.
        in_dtype = x.type.numpy_dtype
        out_dtype = np.dtype(f"f{in_dtype.itemsize}")

        # NOTE(review): this assignment is dead — q is unconditionally
        # reassigned below whenever mode != "r", and unused otherwise.
        q = matrix(dtype=out_dtype)

        # In "raw" mode the second output is tau, a vector.
        if self.mode != "raw":
            r = matrix(dtype=out_dtype)
        else:
            r = vector(dtype=out_dtype)

        # "r" mode returns only the triangular factor.
        if self.mode != "r":
            q = matrix(dtype=out_dtype)
            outputs = [q, r]
        else:
            outputs = [r]

        return Apply(self, [x], outputs)

    def perform(self, node, inputs, outputs):
        """Compute the QR factorization with numpy and store the result(s)."""
        (x,) = inputs
        assert x.ndim == 2, "The input of qr function should be a matrix."
        res = np.linalg.qr(x, self.mode)
        if self.mode != "r":
            outputs[0][0], outputs[1][0] = res
        else:
            outputs[0][0] = res

    def L_op(self, inputs, outputs, output_grads):
        """
        Reverse-mode gradient of the QR function.

        References
        ----------
        .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/
        .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2
        """
        from pytensor.tensor.slinalg import solve_triangular

        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
        m, n = A.shape

        def _H(x: ptb.TensorVariable):
            # Conjugate transpose
            return x.conj().mT

        def _copyltu(x: ptb.TensorVariable):
            # Lower triangle plus conjugate-transposed strict lower triangle
            return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1))

        if self.mode == "raw":
            raise NotImplementedError("Gradient of qr not implemented for mode=raw")

        elif self.mode == "r":
            # We need all the components of the QR to compute the gradient of A even if we only
            # use the upper triangular component in the cost function.
            Q, R = qr(A, mode="reduced")
            dQ = Q.zeros_like()
            dR = cast(ptb.TensorVariable, output_grads[0])

        else:
            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
            if self.mode == "complete":
                qr_assert_op = Assert(
                    "Gradient of qr not implemented for m x n matrices with m > n and mode=complete"
                )
                R = qr_assert_op(R, ptm.le(m, n))

            # Replace disconnected output gradients with zeros of the right shape.
            new_output_grads = []
            is_disconnected = [
                isinstance(x.type, DisconnectedType) for x in output_grads
            ]
            if all(is_disconnected):
                # This should never be reached by Pytensor
                return [DisconnectedType()()]  # pragma: no cover

            for disconnected, output_grad, output in zip(
                is_disconnected, output_grads, [Q, R], strict=True
            ):
                if disconnected:
                    new_output_grads.append(output.zeros_like())
                else:
                    new_output_grads.append(output_grad)

            (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)

        # gradient expression when m >= n
        M = R @ _H(dR) - _H(dQ) @ Q
        K = dQ + Q @ _copyltu(M)
        A_bar_m_ge_n = _H(solve_triangular(R, _H(K)))

        # gradient expression when m < n
        Y = A[:, m:]
        U = R[:, :m]
        dU, dV = dR[:, :m], dR[:, m:]
        dQ_Yt_dV = dQ + Y @ _H(dV)
        M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q
        X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M))))
        Y_bar = Q @ dV
        A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1)

        # Select the correct expression at runtime based on the matrix shape.
        return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]
def qr(a, mode="reduced"):
    """Compute the QR decomposition of a matrix.

    Factors the matrix ``a`` as ``q @ r``, where ``q`` is orthonormal
    and ``r`` is upper-triangular.

    Parameters
    ----------
    a : array_like, shape (M, N)
        Matrix to be factored.
    mode : {'reduced', 'complete', 'r', 'raw'}, optional
        With K = min(M, N):

        - 'reduced' : return q, r with dimensions (M, K), (K, N)
        - 'complete' : return q, r with dimensions (M, M), (M, N)
        - 'r' : return r only, with dimensions (K, N)
        - 'raw' : return h, tau with dimensions (N, M), (K,);
          h is transposed for calling Fortran.

        Default mode is 'reduced'.

    Returns
    -------
    q : matrix of float or complex, optional
        A matrix with orthonormal columns. When mode = 'complete' the
        result is an orthogonal/unitary matrix depending on whether or
        not a is real/complex. The determinant may be either +/- 1 in
        that case.
    r : matrix of float or complex, optional
        The upper-triangular matrix.
    """
    op = QRFull(mode)
    return op(a)
class SVD(Op): class SVD(Op):
""" """
Computes singular value decomposition of matrix A, into U, S, V such that A = U @ S @ V Computes singular value decomposition of matrix A, into U, S, V such that A = U @ S @ V
...@@ -1291,7 +1121,6 @@ __all__ = [ ...@@ -1291,7 +1121,6 @@ __all__ = [
"det", "det",
"eig", "eig",
"eigh", "eigh",
"qr",
"svd", "svd",
"lstsq", "lstsq",
"matrix_power", "matrix_power",
......
差异被折叠。
...@@ -29,12 +29,6 @@ def test_jax_basic_multiout(): ...@@ -29,12 +29,6 @@ def test_jax_basic_multiout():
outs = pt_nlinalg.eigh(x) outs = pt_nlinalg.eigh(x)
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn) compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.qr(x, mode="full")
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.qr(x, mode="reduced")
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.svd(x) outs = pt_nlinalg.svd(x)
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn) compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
......
...@@ -103,6 +103,18 @@ def test_jax_basic(): ...@@ -103,6 +103,18 @@ def test_jax_basic():
], ],
) )
def assert_fn(x, y):
np.testing.assert_allclose(x.astype(config.floatX), y, rtol=1e-3)
M = rng.normal(size=(3, 3))
X = M.dot(M.T)
outs = pt_slinalg.qr(x, mode="full")
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_slinalg.qr(x, mode="economic")
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
def test_jax_solve(): def test_jax_solve():
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
......
...@@ -186,60 +186,6 @@ def test_matrix_inverses(op, x, exc, op_args): ...@@ -186,60 +186,6 @@ def test_matrix_inverses(op, x, exc, op_args):
) )
@pytest.mark.parametrize(
    "x, mode, exc",
    [
        (
            (
                pt.dmatrix(),
                (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
            ),
            "reduced",
            None,
        ),
        (
            (
                pt.dmatrix(),
                (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
            ),
            "r",
            None,
        ),
        (
            (
                pt.lmatrix(),
                (lambda x: x.T.dot(x))(
                    rng.integers(1, 10, size=(3, 3)).astype("int64")
                ),
            ),
            "reduced",
            None,
        ),
        (
            (
                pt.lmatrix(),
                (lambda x: x.T.dot(x))(
                    rng.integers(1, 10, size=(3, 3)).astype("int64")
                ),
            ),
            "complete",
            UserWarning,
        ),
    ],
)
def test_QRFull(x, mode, exc):
    """Compare the numba QRFull implementation against the Python one.

    ``exc`` is the expected warning class (object-mode fallback for
    non-"reduced" modes), or None when no warning is expected.
    """
    x, test_x = x
    g = nlinalg.QRFull(mode)(x)

    # Suppress nothing when exc is None; otherwise require the warning.
    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
    with cm:
        compare_numba_and_py(
            [x],
            g,
            [test_x],
        )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, full_matrices, compute_uv, exc", "x, full_matrices, compute_uv, exc",
[ [
......
...@@ -10,6 +10,7 @@ import pytensor.tensor as pt ...@@ -10,6 +10,7 @@ import pytensor.tensor as pt
from pytensor import In, config from pytensor import In, config
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
LU, LU,
QR,
Cholesky, Cholesky,
CholeskySolve, CholeskySolve,
LUFactor, LUFactor,
...@@ -720,3 +721,70 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo ...@@ -720,3 +721,70 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo
# Can never destroy non-contiguous inputs # Can never destroy non-contiguous inputs
np.testing.assert_allclose(b_val_not_contig, b_val) np.testing.assert_allclose(b_val_not_contig, b_val)
@pytest.mark.parametrize(
    "mode, names",
    [("economic", False), ("full", True), ("r", False), ("raw", True)],
    ids=["economic", "full_pivot", "r", "raw_pivot"],
)
@pytest.mark.parametrize(
    "overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
)
def test_qr(mode, pivoting, overwrite_a):
    """Test the numba QR dispatch, including in-place (destroy_map) behavior.

    Checks that only F-contiguous inputs can actually be destroyed when
    ``overwrite_a`` is requested.
    """
    shape = (5, 5)
    rng = np.random.default_rng()
    A = pt.tensor(
        "A",
        shape=shape,
        dtype=config.floatX,
    )
    A_val = rng.normal(size=shape).astype(config.floatX)

    qr_outputs = pt.linalg.qr(A, mode=mode, pivoting=pivoting)

    fn, res = compare_numba_and_py(
        [In(A, mutable=overwrite_a)],
        qr_outputs,
        [A_val],
        numba_mode=numba_inplace_mode,
        inplace=True,
    )

    op = fn.maker.fgraph.outputs[0].owner.op
    assert isinstance(op, QR)
    destroy_map = op.destroy_map

    # destroy_map {0: [0]} means output 0 overwrites input 0.
    if overwrite_a:
        assert destroy_map == {0: [0]}
    else:
        assert destroy_map == {}

    # Test F-contiguous input
    val_f_contig = np.copy(A_val, order="F")
    res_f_contig = fn(val_f_contig)

    for x, x_f_contig in zip(res, res_f_contig, strict=True):
        np.testing.assert_allclose(x, x_f_contig)

    # Should always be destroyable
    assert (A_val == val_f_contig).all() == (not overwrite_a)

    # Test C-contiguous input
    val_c_contig = np.copy(A_val, order="C")
    res_c_contig = fn(val_c_contig)
    for x, x_c_contig in zip(res, res_c_contig, strict=True):
        np.testing.assert_allclose(x, x_c_contig)

    # Cannot destroy C-contiguous input
    np.testing.assert_allclose(val_c_contig, A_val)

    # Test non-contiguous input
    val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
    res_not_contig = fn(val_not_contig)
    for x, x_not_contig in zip(res, res_not_contig, strict=True):
        np.testing.assert_allclose(x, x_not_contig)

    # Cannot destroy non-contiguous input
    np.testing.assert_allclose(val_not_contig, A_val)
import numpy as np
import pytest
from pytensor import config
from pytensor.tensor.type import matrix
@pytest.fixture
def matrix_test():
    """Provide a symbolic matrix plus a symmetric test value for it."""
    generator = np.random.default_rng(213234)
    base = generator.normal(size=(3, 3))
    # base @ base.T is symmetric positive semi-definite
    value = base.dot(base.T).astype(config.floatX)
    return matrix("x"), value
...@@ -8,17 +8,6 @@ from pytensor.tensor.type import matrix ...@@ -8,17 +8,6 @@ from pytensor.tensor.type import matrix
from tests.link.pytorch.test_basic import compare_pytorch_and_py from tests.link.pytorch.test_basic import compare_pytorch_and_py
@pytest.fixture
def matrix_test():
    """Provide a symbolic matrix plus a symmetric test value for it."""
    generator = np.random.default_rng(213234)
    base = generator.normal(size=(3, 3))
    # base @ base.T is symmetric positive semi-definite
    value = base.dot(base.T).astype(config.floatX)
    return (matrix("x"), value)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"func", "func",
(pt_nla.eig, pt_nla.eigh, pt_nla.SLogDet(), pt_nla.inv, pt_nla.det), (pt_nla.eig, pt_nla.eigh, pt_nla.SLogDet(), pt_nla.inv, pt_nla.det),
...@@ -34,22 +23,6 @@ def test_lin_alg_no_params(func, matrix_test): ...@@ -34,22 +23,6 @@ def test_lin_alg_no_params(func, matrix_test):
compare_pytorch_and_py([x], outs, [test_value], assert_fn=assert_fn) compare_pytorch_and_py([x], outs, [test_value], assert_fn=assert_fn)
@pytest.mark.parametrize(
    "mode",
    (
        "complete",
        "reduced",
        "r",
        pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)),
    ),
)
def test_qr(mode, matrix_test):
    """Compare the PyTorch QR against the Python reference for each mode."""
    x, test_value = matrix_test
    compare_pytorch_and_py([x], pt_nla.qr(x, mode=mode), [test_value])
@pytest.mark.parametrize("compute_uv", [True, False]) @pytest.mark.parametrize("compute_uv", [True, False])
@pytest.mark.parametrize("full_matrices", [True, False]) @pytest.mark.parametrize("full_matrices", [True, False])
def test_svd(compute_uv, full_matrices, matrix_test): def test_svd(compute_uv, full_matrices, matrix_test):
......
import pytest
import pytensor
from tests.link.pytorch.test_basic import compare_pytorch_and_py
@pytest.mark.parametrize(
    "mode",
    (
        "complete",
        "reduced",
        "r",
        pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)),
    ),
)
def test_qr(mode, matrix_test):
    """Compare the PyTorch slinalg QR against the Python reference."""
    x, test_value = matrix_test
    compare_pytorch_and_py([x], pytensor.tensor.slinalg.qr(x, mode=mode), [test_value])
from functools import partial from functools import partial
import numpy as np import numpy as np
import numpy.linalg
import pytest import pytest
from numpy.testing import assert_array_almost_equal from numpy.testing import assert_array_almost_equal
...@@ -25,7 +24,6 @@ from pytensor.tensor.nlinalg import ( ...@@ -25,7 +24,6 @@ from pytensor.tensor.nlinalg import (
matrix_power, matrix_power,
norm, norm,
pinv, pinv,
qr,
slogdet, slogdet,
svd, svd,
tensorinv, tensorinv,
...@@ -122,102 +120,6 @@ def test_matrix_dot(): ...@@ -122,102 +120,6 @@ def test_matrix_dot():
assert _allclose(numpy_sol, pytensor_sol) assert _allclose(numpy_sol, pytensor_sol)
def test_qr_modes():
    """Check pytensor's numpy-backed qr against np.linalg.qr for each mode."""
    rng = np.random.default_rng(utt.fetch_seed())

    A = matrix("A", dtype=config.floatX)
    a = rng.random((4, 4)).astype(config.floatX)

    # Default mode ("reduced")
    f = function([A], qr(A))
    t_qr = f(a)
    n_qr = np.linalg.qr(a)
    assert _allclose(n_qr, t_qr)

    for mode in ["reduced", "r", "raw"]:
        f = function([A], qr(A, mode))
        t_qr = f(a)
        n_qr = np.linalg.qr(a, mode)
        # Multi-output modes are compared factor by factor.
        if isinstance(n_qr, list | tuple):
            assert _allclose(n_qr[0], t_qr[0])
            assert _allclose(n_qr[1], t_qr[1])
        else:
            assert _allclose(n_qr, t_qr)

    # "complete" may be unsupported by older numpy; in that case the
    # TypeError message is checked instead of the result.
    try:
        n_qr = np.linalg.qr(a, "complete")
        f = function([A], qr(A, "complete"))
        t_qr = f(a)
        assert _allclose(n_qr, t_qr)
    except TypeError as e:
        assert "name 'complete' is not defined" in str(e)
@pytest.mark.parametrize(
    "shape, gradient_test_case, mode",
    (
        [(s, c, "reduced") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
        + [(s, c, "complete") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
        + [(s, 0, "r") for s in [(3, 3), (6, 3), (3, 6)]]
        + [((3, 3), 0, "raw")]
    ),
    ids=(
        [
            f"shape={s}, gradient_test_case={c}, mode=reduced"
            for s in [(3, 3), (6, 3), (3, 6)]
            for c in ["Q", "R", "both"]
        ]
        + [
            f"shape={s}, gradient_test_case={c}, mode=complete"
            for s in [(3, 3), (6, 3), (3, 6)]
            for c in ["Q", "R", "both"]
        ]
        + [f"shape={s}, gradient_test_case=R, mode=r" for s in [(3, 3), (6, 3), (3, 6)]]
        + ["shape=(3, 3), gradient_test_case=Q, mode=raw"]
    ),
)
@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
def test_qr_grad(shape, gradient_test_case, mode, is_complex):
    """Verify QRFull gradients numerically for each mode and cost case.

    ``gradient_test_case``: 0 -> cost uses Q only, 1 -> R only, 2 -> both.
    """
    rng = np.random.default_rng(utt.fetch_seed())

    def _test_fn(x, case=2, mode="reduced"):
        if case == 0:
            return qr(x, mode=mode)[0].sum()
        elif case == 1:
            return qr(x, mode=mode)[1].sum()
        elif case == 2:
            Q, R = qr(x, mode=mode)
            return Q.sum() + R.sum()

    if is_complex:
        pytest.xfail("Complex inputs currently not supported by verify_grad")

    m, n = shape
    a = rng.standard_normal(shape).astype(config.floatX)
    if is_complex:
        a += 1j * rng.standard_normal(shape).astype(config.floatX)

    if mode == "raw":
        # raw mode has no gradient implementation
        with pytest.raises(NotImplementedError):
            utt.verify_grad(
                partial(_test_fn, case=gradient_test_case, mode=mode),
                [a],
                rng=np.random,
            )

    elif mode == "complete" and m > n:
        # complete-mode gradient is unsupported for tall matrices
        with pytest.raises(AssertionError):
            utt.verify_grad(
                partial(_test_fn, case=gradient_test_case, mode=mode),
                [a],
                rng=np.random,
            )

    else:
        utt.verify_grad(
            partial(_test_fn, case=gradient_test_case, mode=mode), [a], rng=np.random
        )
class TestSvd(utt.InferShapeTester): class TestSvd(utt.InferShapeTester):
op_class = SVD op_class = SVD
......
import functools import functools
import itertools import itertools
from functools import partial
from typing import Literal from typing import Literal
import numpy as np import numpy as np
import pytest import pytest
import scipy import scipy
from scipy import linalg as scipy_linalg
from pytensor import function, grad from pytensor import function, grad
from pytensor import tensor as pt from pytensor import tensor as pt
...@@ -26,6 +28,7 @@ from pytensor.tensor.slinalg import ( ...@@ -26,6 +28,7 @@ from pytensor.tensor.slinalg import (
lu_factor, lu_factor,
lu_solve, lu_solve,
pivot_to_permutation, pivot_to_permutation,
qr,
solve, solve,
solve_continuous_lyapunov, solve_continuous_lyapunov,
solve_discrete_are, solve_discrete_are,
...@@ -1088,3 +1091,104 @@ def test_block_diagonal_blockwise(): ...@@ -1088,3 +1091,104 @@ def test_block_diagonal_blockwise():
B = np.random.normal(size=(1, batch_size, 4, 4)).astype(config.floatX) B = np.random.normal(size=(1, batch_size, 4, 4)).astype(config.floatX)
result = block_diag(A, B).eval() result = block_diag(A, B).eval()
assert result.shape == (10, batch_size, 6, 6) assert result.shape == (10, batch_size, 6, 6)
@pytest.mark.parametrize(
    "mode, names",
    [
        ("economic", ["Q", "R"]),
        ("full", ["Q", "R"]),
        ("r", ["R"]),
        ("raw", ["H", "tau", "R"]),
    ],
)
@pytest.mark.parametrize("pivoting", [True, False])
def test_qr_modes(mode, names, pivoting):
    """Check pytensor's scipy-backed qr against scipy.linalg.qr per mode.

    ``names`` labels each expected output for clearer assertion messages.
    """
    rng = np.random.default_rng(utt.fetch_seed())
    A_val = rng.random((4, 4)).astype(config.floatX)

    if pivoting:
        names = [*names, "pivots"]

    A = tensor("A", dtype=config.floatX, shape=(None, None))

    f = function([A], qr(A, mode=mode, pivoting=pivoting))
    outputs_pt = f(A_val)
    outputs_sp = scipy_linalg.qr(A_val, mode=mode, pivoting=pivoting)

    if mode == "raw":
        # The first output of scipy's qr is a tuple when mode is raw; flatten it for easier iteration
        outputs_sp = (*outputs_sp[0], *outputs_sp[1:])
    elif mode == "r" and not pivoting:
        # Here there's only one output from the pytensor function; wrap it in a list for iteration
        outputs_pt = [outputs_pt]

    for out_pt, out_sp, name in zip(outputs_pt, outputs_sp, names):
        np.testing.assert_allclose(out_pt, out_sp, err_msg=f"{name} disagrees")
@pytest.mark.parametrize(
    "shape, gradient_test_case, mode",
    (
        [(s, c, "economic") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
        + [(s, c, "full") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
        + [(s, 0, "r") for s in [(3, 3), (6, 3), (3, 6)]]
        + [((3, 3), 0, "raw")]
    ),
    ids=(
        [
            f"shape={s}, gradient_test_case={c}, mode=economic"
            for s in [(3, 3), (6, 3), (3, 6)]
            for c in ["Q", "R", "both"]
        ]
        + [
            f"shape={s}, gradient_test_case={c}, mode=full"
            for s in [(3, 3), (6, 3), (3, 6)]
            for c in ["Q", "R", "both"]
        ]
        + [f"shape={s}, gradient_test_case=R, mode=r" for s in [(3, 3), (6, 3), (3, 6)]]
        + ["shape=(3, 3), gradient_test_case=Q, mode=raw"]
    ),
)
@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
def test_qr_grad(shape, gradient_test_case, mode, is_complex):
    """Verify QR gradients numerically for each scipy mode and cost case.

    ``gradient_test_case``: 0 -> cost uses Q only, 1 -> R only, 2 -> both.
    """
    rng = np.random.default_rng(utt.fetch_seed())

    def _test_fn(x, case=2, mode="reduced"):
        # NOTE(review): the "reduced" default is never used here — mode is
        # always passed explicitly below; "economic" would match this file.
        if case == 0:
            return qr(x, mode=mode)[0].sum()
        elif case == 1:
            return qr(x, mode=mode)[1].sum()
        elif case == 2:
            Q, R = qr(x, mode=mode)
            return Q.sum() + R.sum()

    if is_complex:
        pytest.xfail("Complex inputs currently not supported by verify_grad")

    m, n = shape
    a = rng.standard_normal(shape).astype(config.floatX)
    if is_complex:
        a += 1j * rng.standard_normal(shape).astype(config.floatX)

    if mode == "raw":
        # raw mode has no gradient implementation
        with pytest.raises(NotImplementedError):
            utt.verify_grad(
                partial(_test_fn, case=gradient_test_case, mode=mode),
                [a],
                rng=np.random,
            )

    elif mode == "full" and m > n:
        # full-mode gradient is unsupported for tall matrices
        with pytest.raises(AssertionError):
            utt.verify_grad(
                partial(_test_fn, case=gradient_test_case, mode=mode),
                [a],
                rng=np.random,
            )

    else:
        utt.verify_grad(
            partial(_test_fn, case=gradient_test_case, mode=mode), [a], rng=np.random
        )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论