Unverified commit 2774599e authored by Etienne Duchesne, committed by GitHub

Implement gradient for QR decomposition (#1303)

Parent commit: 8a7356ce
...@@ -5,12 +5,15 @@ from typing import Literal, cast ...@@ -5,12 +5,15 @@ from typing import Literal, cast
import numpy as np import numpy as np
import pytensor.tensor as pt
from pytensor import scalar as ps from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.ifelse import ifelse
from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.raise_op import Assert
from pytensor.tensor import TensorLike from pytensor.tensor import TensorLike
from pytensor.tensor import basic as ptb from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm from pytensor.tensor import math as ptm
...@@ -512,6 +515,80 @@ class QRFull(Op): ...@@ -512,6 +515,80 @@ class QRFull(Op):
else: else:
outputs[0][0] = res outputs[0][0] = res
def L_op(self, inputs, outputs, output_grads):
    """
    Reverse-mode gradient of the QR function.

    Parameters
    ----------
    inputs : list of TensorVariable
        The single input matrix ``A`` (shape ``(m, n)``) that was decomposed.
    outputs : list of TensorVariable
        The outputs of this Op; for modes other than ``"r"``/``"raw"`` these
        are ``(Q, R)``.
    output_grads : list of Variable
        Gradients of the cost with respect to each output; entries may be
        ``DisconnectedType`` when an output does not influence the cost.

    Returns
    -------
    list of TensorVariable
        The gradient of the cost with respect to ``A``.

    Raises
    ------
    NotImplementedError
        If ``mode == "raw"``.

    References
    ----------
    .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/
    .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2
    """
    # Local import to avoid a circular import at module load time
    # (presumably slinalg imports from this module — TODO confirm).
    from pytensor.tensor.slinalg import solve_triangular

    (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
    m, n = A.shape

    def _H(x: ptb.TensorVariable):
        # Conjugate transpose (Hermitian adjoint) of the last two axes.
        return x.conj().mT

    def _copyltu(x: ptb.TensorVariable):
        # Symmetrize from the lower triangle: keep the lower-triangular part
        # (including the diagonal) and mirror its strict lower part, conjugated,
        # into the upper triangle.
        return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1))

    if self.mode == "raw":
        raise NotImplementedError("Gradient of qr not implemented for mode=raw")

    elif self.mode == "r":
        # We need all the components of the QR to compute the gradient of A even if we only
        # use the upper triangular component in the cost function.
        Q, R = qr(A, mode="reduced")
        # Only R is an output in this mode, so Q receives a zero cotangent.
        dQ = Q.zeros_like()
        dR = cast(ptb.TensorVariable, output_grads[0])

    else:
        Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
        if self.mode == "complete":
            # The complete-mode gradient below is only derived for m <= n;
            # guard against the unsupported wide-Q case at runtime.
            qr_assert_op = Assert(
                "Gradient of qr not implemented for m x n matrices with m > n and mode=complete"
            )
            R = qr_assert_op(R, ptm.le(m, n))

        # Replace disconnected output gradients with zeros so the algebra
        # below can treat dQ and dR uniformly.
        new_output_grads = []
        is_disconnected = [
            isinstance(x.type, DisconnectedType) for x in output_grads
        ]
        if all(is_disconnected):
            # This should never be reached by Pytensor
            return [DisconnectedType()()]  # pragma: no cover

        for disconnected, output_grad, output in zip(
            is_disconnected, output_grads, [Q, R], strict=True
        ):
            if disconnected:
                new_output_grads.append(output.zeros_like())
            else:
                new_output_grads.append(output_grad)

        (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)

    # gradient expression when m >= n (see refs [1], [2])
    M = R @ _H(dR) - _H(dQ) @ Q
    K = dQ + Q @ _copyltu(M)
    A_bar_m_ge_n = _H(solve_triangular(R, _H(K)))

    # gradient expression when m < n: split A = [X | Y] with X the leading
    # m x m block, and back-propagate through each part separately.
    Y = A[:, m:]
    U = R[:, :m]
    dU, dV = dR[:, :m], dR[:, m:]
    dQ_Yt_dV = dQ + Y @ _H(dV)
    M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q
    X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M))))
    Y_bar = Q @ dV
    A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1)

    # Select the applicable expression based on the (symbolic) matrix shape.
    return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]
def qr(a, mode="reduced"): def qr(a, mode="reduced"):
""" """
......
...@@ -152,6 +152,72 @@ def test_qr_modes(): ...@@ -152,6 +152,72 @@ def test_qr_modes():
assert "name 'complete' is not defined" in str(e) assert "name 'complete' is not defined" in str(e)
@pytest.mark.parametrize(
    "shape, gradient_test_case, mode",
    (
        [(s, c, "reduced") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
        + [(s, c, "complete") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
        + [(s, 0, "r") for s in [(3, 3), (6, 3), (3, 6)]]
        + [((3, 3), 0, "raw")]
    ),
    ids=(
        [
            f"shape={s}, gradient_test_case={c}, mode=reduced"
            for s in [(3, 3), (6, 3), (3, 6)]
            for c in ["Q", "R", "both"]
        ]
        + [
            f"shape={s}, gradient_test_case={c}, mode=complete"
            for s in [(3, 3), (6, 3), (3, 6)]
            for c in ["Q", "R", "both"]
        ]
        + [f"shape={s}, gradient_test_case=R, mode=r" for s in [(3, 3), (6, 3), (3, 6)]]
        + ["shape=(3, 3), gradient_test_case=Q, mode=raw"]
    ),
)
@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"])
def test_qr_grad(shape, gradient_test_case, mode, is_complex):
    """Check the QR gradient via finite differences for every supported mode.

    ``gradient_test_case`` selects which output(s) enter the scalar cost:
    0 -> Q only, 1 -> R only, 2 -> both. Unsupported configurations
    (mode="raw"; mode="complete" with m > n) must raise.
    """
    rng = np.random.default_rng(utt.fetch_seed())

    def cost_fn(x, case=2, mode="reduced"):
        # Scalar cost built from the selected QR output(s).
        outs = qr(x, mode=mode)
        if case == 0:
            return outs[0].sum()
        if case == 1:
            return outs[1].sum()
        return outs[0].sum() + outs[1].sum()

    if is_complex:
        pytest.xfail("Complex inputs currently not supported by verify_grad")

    m, n = shape
    a = rng.standard_normal(shape).astype(config.floatX)
    if is_complex:
        a += 1j * rng.standard_normal(shape).astype(config.floatX)

    grad_target = partial(cost_fn, case=gradient_test_case, mode=mode)

    if mode == "raw":
        # Gradient is explicitly not implemented for the raw mode.
        with pytest.raises(NotImplementedError):
            utt.verify_grad(grad_target, [a], rng=np.random)
    elif mode == "complete" and m > n:
        # complete-mode gradient guards against wide Q with an Assert.
        with pytest.raises(AssertionError):
            utt.verify_grad(grad_target, [a], rng=np.random)
    else:
        utt.verify_grad(grad_target, [a], rng=np.random)
class TestSvd(utt.InferShapeTester): class TestSvd(utt.InferShapeTester):
op_class = SVD op_class = SVD
......
Markdown formatting supported
0%
You added 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment