提交 8b215550 authored 作者: abalkin's avatar abalkin 提交者: Frederic

Cleaned up docstrings and added and explanation of how we deal with UPLO.

上级 cd478ba2
......@@ -959,21 +959,24 @@ class Eigh(Eig):
def grad(self, inputs, g_outputs):
r"""The gradient function should return
.. math:: \sum_n\left( W_n\frac{\partial\,\lambda_n}
{\partial X} +
\sum_k V_{nk}\frac{\partial\,\Psi_{nk}}
{\partial X}\right),
.. math:: \sum_n\left(W_n\frac{\partial\,w_n}
{\partial a_{ij}} +
\sum_k V_{nk}\frac{\partial\,v_{nk}}
{\partial a_{ij}}\right),
where [:math:`W`, :math:`V`] corresponds to ``g_outputs``,
:math:`X` to ``inputs``, and :math:`(\lambda, \Psi)=\mbox{eig}(X)`.
:math:`X` to ``inputs``, and :math:`(w, v)=\mbox{eig}(a)`.
.. math:: \frac{\partial\,\lambda_n}
{\partial X_{ij}} = \Psi_{ni}\,\Psi_{nj}
Analytic formulae for eigensystem gradients are well-known in
perturbation theory:
.. math:: \frac{\partial\,w_n}
{\partial a_{ij}} = v_{in}\,v_{jn}
.. math:: \frac{\partial\,\Psi_{ni}}
{\partial X_{jk}} =
\left((X-\lambda_n)^{+}\right)_{ij}\Psi_{nk}
.. math:: \frac{\partial\,v_{kn}}
{\partial a_{ij}} =
\sum_{m\ne n}\frac{v_{km}v_{jn}}{w_n-w_m}
"""
x, = inputs
w, v = self(x)
......@@ -1019,29 +1022,6 @@ class EighGrad(Op):
r"""
Implements the "reverse-mode" gradient for the eigensystem of
a square matrix.
Let
.. math:: w, v = \mbox{eig}(x).
By definition of the eigensystem,
.. math:: x\,v_n = w_n\,v_n.
.. math:: v_m^\dagger\,v_n = \delta_{mn}
Differentiating these equations we get:
.. math:: v_n + x \frac{\partial v_n}{\partial x}
= \frac{\partial w_n}{\partial x}\,v_n +
w_n\frac{\partial v_n}{\partial x}.
.. math:: v_m^\dagger\,\frac{\partial v_n}{\partial x} = 0
Multiplying both sides by :math:`v^\dagger` and using orthogonality of
eigenvectors, we find:
.. math:: \frac{\partial w_n}{\partial x_{ij}} = v_{in}\,v_{jn}
"""
x, w, v, W, V = inputs
N = x.shape[0]
......@@ -1051,6 +1031,15 @@ class EighGrad(Op):
for m in xrange(N) if m != n)
g = sum(outer(v[:,n], v[:,n]*W[n] + G(n))
for n in xrange(N))
# Numpy's eigh(a, 'L') (eigh(a, 'U')) is a function of tril(a)
# (triu(a)) only. This means that partial derivative of
# eigh(a, 'L') (eigh(a, 'U')) with respect to a[i,j] is zero
# for i < j (i > j). At the same time, non-zero components of
# the gradient must account for the fact that variation of the
# opposite triangle contributes to variation of two elements
# of Hermitian (symmetric) matrix. The following line
# implements the necessary logic.
outputs[0][0] = self.tri0(g) + self.tri1(g).T
def infer_shape(self, node, shapes):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论