提交 25538faa authored 作者: abalkin's avatar abalkin 提交者: Frederic

Split Eig and Eigh into different classes.

上级 940bd5a8
...@@ -396,6 +396,8 @@ cholesky = Cholesky() ...@@ -396,6 +396,8 @@ cholesky = Cholesky()
class CholeskyGrad(Op): class CholeskyGrad(Op):
"""
"""
def __init__(self, lower=True): def __init__(self, lower=True):
self.lower = lower self.lower = lower
self.destructive = False self.destructive = False
...@@ -488,7 +490,7 @@ class MatrixPinv(Op): ...@@ -488,7 +490,7 @@ class MatrixPinv(Op):
This method is not faster then `matrix_inverse`. Its strength comes from This method is not faster then `matrix_inverse`. Its strength comes from
that it works for non-square matrices. that it works for non-square matrices.
If you have a square matrix though, `matrix_inverse` can be both more If you have a square matrix though, `matrix_inverse` can be both more
exact and faster to compute. Aslo this op does not get optimized into a exact and faster to compute. Also this op does not get optimized into a
solve op. solve op.
""" """
def __init__(self): def __init__(self):
...@@ -881,9 +883,7 @@ class Eig(Op): ...@@ -881,9 +883,7 @@ class Eig(Op):
"""Compute the eigenvalues and right eigenvectors of a square array. """Compute the eigenvalues and right eigenvectors of a square array.
""" """
_numop = staticmethod(numpy.linalg.eig)
def __init__(self, numop):
self._numop = numop
def props(self): def props(self):
"""Function exposing different properties of each instance of the """Function exposing different properties of each instance of the
...@@ -920,6 +920,14 @@ class Eig(Op): ...@@ -920,6 +920,14 @@ class Eig(Op):
def __str__(self): def __str__(self):
return self._numop.__name__.capitalize() return self._numop.__name__.capitalize()
eig = Eig()
class Eigh(Eig):
"""
Return the eigenvalues and eigenvectors of a Hermitian or symmetric matrix.
"""
_numop = staticmethod(numpy.linalg.eigh)
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
r"""The gradient function should return r"""The gradient function should return
...@@ -942,13 +950,12 @@ class Eig(Op): ...@@ -942,13 +950,12 @@ class Eig(Op):
x, = inputs x, = inputs
w, v = self(x) w, v = self(x)
gw, gv = g_outputs gw, gv = g_outputs
return [EigGrad()(x, w, v, gw, gv)] return [EighGrad()(x, w, v, gw, gv)]
eig = Eig(numpy.linalg.eig) eigh = Eigh()
eigh = Eig(numpy.linalg.eigh)
class EigGrad(Op): class EighGrad(Op):
"""Gradient of an eigensystem. """Gradient of an eigensystem of a Hermitian matrix.
""" """
def props(self): def props(self):
...@@ -969,7 +976,7 @@ class EigGrad(Op): ...@@ -969,7 +976,7 @@ class EigGrad(Op):
return Apply(self, [x, w, v, gw, gv], [x.type()]) return Apply(self, [x, w, v, gw, gv], [x.type()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
""" r"""
Implements the "reverse-mode" gradient for the eigensystem of Implements the "reverse-mode" gradient for the eigensystem of
a square matrix. a square matrix.
...@@ -979,7 +986,22 @@ class EigGrad(Op): ...@@ -979,7 +986,22 @@ class EigGrad(Op):
By definition of the eigensystem, By definition of the eigensystem,
.. math:: \sum_j x_{ij}\,v_{jn} = w_n\,v_{in}. .. math:: x\,v_n = w_n\,v_n.
.. math:: v_m^\dagger\,v_n = \delta_{mn}
Differentiating these equations we get:
.. math:: v_n + x \frac{\partial v_n}{\partial x}
= \frac{\partial w_n}{\partial x}\,v_n +
w_n\frac{\partial v_n}{\partial x}.
.. math:: v_m^\dagger\,\frac{\partial v_n}{\partial x} = 0
Multiplying both sides by :math:`v^\dagger` and using orthogonality of
eigenvectors, we find:
.. math:: \frac{\partial w_n}{\partial x_{ij}} = v_{in}\,v_{jn}
""" """
x, w, v, gw, gv = inputs x, w, v, gw, gv = inputs
N = x.shape[0] N = x.shape[0]
...@@ -989,9 +1011,8 @@ class EigGrad(Op): ...@@ -989,9 +1011,8 @@ class EigGrad(Op):
pinv = numpy.linalg.pinv pinv = numpy.linalg.pinv
diag = numpy.diag diag = numpy.diag
outer = numpy.outer outer = numpy.outer
gx = sum(gw[n]*outer(v[:,n], v[:,n]) + gx = sum(gw[n]*outer(v[:,n], v[:,n])
sum(gv[m,n]*outer(pinv(diag(w)-x)[m,:],v[:,n]) + outer(gv[:,n], pinv(diag(w)-x).dot(v[:,n]))
for m in xrange(N))
for n in xrange(N)) for n in xrange(N))
outputs[0][0] = gx outputs[0][0] = gx
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论