Commit b7c952bf authored by jessegrabowski, committed by Ricardo Vieira

Add gradient for `SVD`

Parent eb18f0ea
import warnings
from collections.abc import Callable, Sequence
from functools import partial
from typing import Literal, cast

import numpy as np
from numpy.core.numeric import normalize_axis_tuple  # type: ignore

@@ -15,7 +15,7 @@ from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector


class MatrixPinv(Op):

@@ -597,6 +597,121 @@ class SVD(Op):
        else:
            return [s_shape]

    def L_op(
        self,
        inputs: Sequence[Variable],
        outputs: Sequence[Variable],
        output_grads: Sequence[Variable],
    ) -> list[Variable]:
        """
        Reverse-mode gradient of the SVD function.

        Adapted from the autograd implementation here:
        https://github.com/HIPS/autograd/blob/01eacff7a4f12e6f7aebde7c4cb4c1c2633f217d/autograd/numpy/linalg.py#L194
        and from the mxnet implementation described in [1]_.

        References
        ----------
        .. [1] Seeger, Matthias, et al. "Auto-differentiating linear algebra." arXiv preprint arXiv:1710.08717 (2017).
        """

        def s_grad_only(
            U: ptb.TensorVariable, VT: ptb.TensorVariable, ds: ptb.TensorVariable
        ) -> list[Variable]:
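            # When the cost depends only on the singular values, the gradient reduces to
            # dA = U @ diag(ds) @ VT. Broadcasting ds over the columns of U applies diag(ds)
            # without materializing the diagonal matrix.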
            A_bar = (U.conj() * ds[..., None, :]) @ VT
            return [A_bar]

        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)

        if not self.compute_uv:
            # We need all the components of the SVD to compute the gradient of A even if we
            # only use the singular values in the cost function.
            U, _, VT = svd(A, full_matrices=False, compute_uv=True)
            ds = cast(ptb.TensorVariable, output_grads[0])
            return s_grad_only(U, VT, ds)

        elif self.full_matrices:
            raise NotImplementedError(
                "Gradient of svd not implemented for full_matrices=True"
            )

        else:
            U, s, VT = (cast(ptb.TensorVariable, x) for x in outputs)

            # Handle disconnected outputs
            # If a user asked for all the matrices but then only used a subset in the cost
            # function, the unused outputs will be DisconnectedType. We replace
            # DisconnectedTypes with zero matrices of the correct shapes.
            new_output_grads = []
            is_disconnected = [
                isinstance(x.type, DisconnectedType) for x in output_grads
            ]
            if all(is_disconnected):
                # This should never actually be reached by Pytensor -- the SVD Op should be
                # pruned from the gradient graph if it's fully disconnected. It is included
                # for completeness.
                return [DisconnectedType()()]  # pragma: no cover
            elif is_disconnected == [True, False, True]:
                # This is the same as the compute_uv = False case, so we can fall back to that
                # simpler computation without needing to re-compute U and VT
                ds = cast(ptb.TensorVariable, output_grads[1])
                return s_grad_only(U, VT, ds)

            for disconnected, output_grad, output in zip(
                is_disconnected, output_grads, [U, s, VT]
            ):
                if disconnected:
                    new_output_grads.append(output.zeros_like())
                else:
                    new_output_grads.append(output_grad)

            (dU, ds, dVT) = (cast(ptb.TensorVariable, x) for x in new_output_grads)

            V = VT.T
            dV = dVT.T

            m, n = A.shape[-2:]

            k = ptm.min((m, n))
            eye = ptb.eye(k)

            def h(t):
                """
                Regularized helper used to form the s_i ** 2 - s_j ** 2 differences, from [1]_.

                Clamps ``|t|`` away from zero while preserving the sign, making the result
                robust to identical singular values (singular matrix input), although the
                gradients are still wrong in this case.
                """
                eps = 1e-8

                # sign(0) = 0 in pytensor, which defeats the whole purpose of this function
                sign_t = ptb.where(ptm.eq(t, 0), 1, ptm.sign(t))
                return ptm.maximum(ptm.abs(t), eps) * sign_t
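
            # E[i, j] is (approximately) 1 / (s_j**2 - s_i**2) off the diagonal and 0 on the
            # diagonal; h() regularizes the differences of (near-)identical singular values.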
            numer = ptb.ones((k, k)) - eye
            denom = h(s[None] - s[:, None]) * h(s[None] + s[:, None])
            E = numer / denom

            utgu = U.T @ dU
            vtgv = VT @ dV
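
            # Assemble the core (square/thin) term:
            #   inner = (E * (U^T dU - dU^T U)) diag(s) + diag(ds) + diag(s) (E * (V^T dV - dV^T V))
            #   A_bar = U @ inner @ V^T
            # Broadcasting against s[..., None, :] and s[..., :, None] performs the right- and
            # left-multiplication by diag(s) without building the diagonal matrices.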
            A_bar = (E * (utgu - utgu.conj().T)) * s[..., None, :]
            A_bar = A_bar + eye * ds[..., :, None]
            A_bar = A_bar + s[..., :, None] * (E * (vtgv - vtgv.conj().T))
            A_bar = U.conj() @ A_bar @ VT
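
            # For non-square A, the thin U (or V) does not span the full column (or row) space,
            # so an extra term propagates the gradient component lying in the orthogonal
            # complement, via the projectors (I - V V^H) and (I - U U^H).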
            A_bar = ptb.switch(
                ptm.eq(m, n),
                A_bar,
                ptb.switch(
                    ptm.lt(m, n),
                    A_bar
                    + (
                        U / s[..., None, :] @ dVT @ (ptb.eye(n) - V @ V.conj().T)
                    ).conj(),
                    A_bar
                    + (V / s[..., None, :] @ dU.T @ (ptb.eye(m) - U @ U.conj().T)).T,
                ),
            )
            return [A_bar]


def svd(a, full_matrices: bool = True, compute_uv: bool = True):
    """
...
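
As a quick illustration (not part of this diff): with `SVD.L_op` defined above, `pytensor.grad` can now differentiate a cost built from the singular values. For a full-rank matrix with distinct singular values, the analytic gradient of `s.sum()` is `U @ VT`, which gives a simple numerical sanity check. The snippet below is only a sketch of how the public `svd` helper would be used.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import svd

# Gradient of the sum of singular values with respect to the input matrix
A = pt.dmatrix("A")
U, s, VT = svd(A, full_matrices=False, compute_uv=True)
grad_A = pytensor.grad(s.sum(), A)

# For full-rank A with distinct singular values, d(sum(s))/dA == U @ VT
f = pytensor.function([A], [grad_A, U @ VT])
A_val = np.random.default_rng(0).normal(size=(4, 3))
g_val, uvt_val = f(A_val)
np.testing.assert_allclose(g_val, uvt_val)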

@@ -8,6 +8,7 @@ from numpy.testing import assert_array_almost_equal
import pytensor
from pytensor import function
from pytensor.configdefaults import config
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.math import _allclose
from pytensor.tensor.nlinalg import (
    SVD,
...

@@ -215,6 +216,80 @@ class TestSvd(utt.InferShapeTester):
            outputs = [outputs]
        self._compile_and_check([A], outputs, [A_v], self.op_class, warn=False)

    @pytest.mark.parametrize(
        "compute_uv, full_matrices, gradient_test_case",
        [(False, False, 0)]
        + [(True, False, i) for i in range(8)]
        + [(True, True, i) for i in range(8)],
        ids=(
            ["compute_uv=False, full_matrices=False"]
            + [
                f"compute_uv=True, full_matrices=False, gradient={grad}"
                for grad in ["U", "s", "V", "U+s", "s+V", "U+V", "U+s+V", "None"]
            ]
            + [
                f"compute_uv=True, full_matrices=True, gradient={grad}"
                for grad in ["U", "s", "V", "U+s", "s+V", "U+V", "U+s+V", "None"]
            ]
        ),
    )
    @pytest.mark.parametrize(
        "shape", [(3, 3), (4, 3), (3, 4)], ids=["(3,3)", "(4,3)", "(3,4)"]
    )
    @pytest.mark.parametrize(
        "batched", [True, False], ids=["batched=True", "batched=False"]
    )
    def test_grad(self, compute_uv, full_matrices, gradient_test_case, shape, batched):
        rng = np.random.default_rng(utt.fetch_seed())
        if batched:
            shape = (4, *shape)

        A_v = self.rng.normal(size=shape).astype(config.floatX)
        if full_matrices:
            with pytest.raises(
                NotImplementedError,
                match="Gradient of svd not implemented for full_matrices=True",
            ):
                U, s, V = svd(
                    self.A, compute_uv=compute_uv, full_matrices=full_matrices
                )
                pytensor.grad(s.sum(), self.A)

        elif compute_uv:

            def svd_fn(A, case=0):
                U, s, V = svd(A, compute_uv=compute_uv, full_matrices=full_matrices)
                if case == 0:
                    return U.sum()
                elif case == 1:
                    return s.sum()
                elif case == 2:
                    return V.sum()
                elif case == 3:
                    return U.sum() + s.sum()
                elif case == 4:
                    return s.sum() + V.sum()
                elif case == 5:
                    return U.sum() + V.sum()
                elif case == 6:
                    return U.sum() + s.sum() + V.sum()
                elif case == 7:
                    # All inputs disconnected
                    return as_tensor_variable(3.0)

            utt.verify_grad(
                partial(svd_fn, case=gradient_test_case),
                [A_v],
                rng=rng,
            )

        else:
            utt.verify_grad(
                partial(svd, compute_uv=compute_uv, full_matrices=full_matrices),
                [A_v],
                rng=rng,
            )


def test_tensorsolve():
    rng = np.random.default_rng(utt.fetch_seed())
...