提交 7e3caf1e authored 作者: abalkin's avatar abalkin 提交者: Frederic

Implemented Eig.grad(). Need to debug eigenvectors' component.

上级 e4316b31
......@@ -938,6 +938,51 @@ class Eig(Op):
{\partial X_{jk}} =
\left((X-\lambda_n)^{+}\right)_{ij}\Psi_{nk}
"""
return [grad_not_implemented(self, 0, x, "Work in progress.")]
x, = inputs
w, v = self(x)
gw, gv = g_outputs
return [EigGrad()(x, w, v, gw, gv)]
eig = Eig()
class EigGrad(Op):
"""Gradient of an eigensystem.
"""
def props(self):
return ()
def __hash__(self):
return hash((type(self), self.props()))
def __eq__(self, other):
return (type(self) == type(other) and self.props() == other.props())
def __str__(self):
return 'EigGrad'
def make_node(self, x, w, v, gw, gv):
x, w, v, gw, gv = map(as_tensor_variable, (x, w, v, gw, gv))
return Apply(self, [x, w, v, gw, gv], [x.type()])
def perform(self, node, inputs, outputs):
"""
"""
x, w, v, gw, gv = inputs
N = x.shape[0]
if imported_scipy:
pinv = scipy.linalg.pinv
else:
pinv = numpy.linalg.pinv
diag = numpy.diag
outer = numpy.outer
gx = sum(gw[n]*outer(v[:,n], v[:,n]) +
sum(gv[m,n]*outer(pinv(diag(w)-x)[m,:],v[:,n])
for m in xrange(N))
for n in xrange(N))
outputs[0][0] = gx
def infer_shape(self, node, shapes):
return [shapes[0]]
......@@ -29,7 +29,7 @@ from theano.sandbox.linalg.ops import (cholesky,
imported_scipy,
Eig,
)
from theano.sandbox.linalg import eig
from nose.plugins.skip import SkipTest
......@@ -475,16 +475,19 @@ class test_Eig(utt.InferShapeTester):
super(test_Eig, self).setUp()
self.op_class = Eig
self.op = Eig()
def test_infer_shape(self):
rng = numpy.random.RandomState(utt.fetch_seed())
A = theano.tensor.matrix()
X = numpy.asarray(rng.rand(5, 5),
self.rng = numpy.random.RandomState(utt.fetch_seed())
self.A = theano.tensor.matrix()
X = numpy.asarray(self.rng.rand(5, 5),
dtype=config.floatX)
self.S = X.dot(X.T)
def test_infer_shape(self):
A = self.A
S = self.S
self._compile_and_check([A], # theano.function inputs
self.op(A), # theano.function outputs
# A must be square
[X.dot(X.T)],
# S must be square
[S],
self.op_class)
def test_eval(self):
import math
......@@ -497,3 +500,10 @@ class test_Eig(utt.InferShapeTester):
assert_array_almost_equal(w, [1, -1])
x = math.sqrt(2)/2
assert_array_almost_equal(v, [[x, -x], [x, x]])
def test_grad(self):
S = self.S
def fun(x):
w, v = eig(x)
return w.sum() + v.sum()*0
utt.verify_grad(fun, [S], rng=self.rng)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论