提交 1219334d authored 作者: abalkin's avatar abalkin 提交者: Frederic

Implemented UPLO parameter for eigh.

上级 b9a6f0b7
......@@ -933,6 +933,15 @@ class Eigh(Eig):
"""
_numop = staticmethod(numpy.linalg.eigh)
def __init__(self, UPLO='L'):
self.UPLO = UPLO
def __str__(self):
return 'Eigh{%s}' % self.UPLO
def props(self):
return self.UPLO,
def make_node(self, x):
x = as_tensor_variable(x)
w = theano.tensor.vector(dtype='float64')
......@@ -941,7 +950,7 @@ class Eigh(Eig):
def perform(self, node, (x,), (w, v)):
try:
w[0], v[0] = self._numop(x)
w[0], v[0] = self._numop(x, self.UPLO)
except numpy.linalg.LinAlgError:
logger.debug('Failed to find %s of %s' % (node.inputs[0],
self._numop.__name__))
......@@ -971,14 +980,18 @@ class Eigh(Eig):
# Replace gradients wrt disconnected variables with
# zeros. This is a work-around for issue #1063.
gw, gv = _zero_disconnected([w, v], g_outputs)
return [EighGrad()(x, w, v, gw, gv)]
return [EighGrad(self.UPLO)(x, w, v, gw, gv)]
eigh = Eigh()
def eigh(a, UPLO='L'):
return Eigh(UPLO)(a)
class EighGrad(Op):
"""Gradient of an eigensystem of a Hermitian matrix.
"""
def __init__(self, UPLO='L'):
self.UPLO = UPLO
def props(self):
return ()
......@@ -989,7 +1002,7 @@ class EighGrad(Op):
return (type(self) == type(other) and self.props() == other.props())
def __str__(self):
return 'EigGrad'
return 'EighGrad{%s}' % self.UPLO
def make_node(self, x, w, v, gw, gv):
......@@ -1030,8 +1043,11 @@ class EighGrad(Op):
G = lambda n: sum(v[:,m]*V.T[n].dot(v[:,m])/(w[n]-w[m])
for m in xrange(N) if m != n)
tri = numpy.tri(N)
if self.UPLO == 'U':
tri = tri.T
outputs[0][0] = sum(outer(v[:,n], v[:,n]*W[n] + G(n))
for n in xrange(N))
for n in xrange(N))#*tri
def infer_shape(self, node, shapes):
return [shapes[0]]
......
......@@ -499,7 +499,16 @@ class test_Eig(utt.InferShapeTester):
assert_array_almost_equal(numpy.dot(x,v), w * v)
class test_Eigh(test_Eig):
op = eigh
op = staticmethod(eigh)
def test_uplo(self):
S = self.S
a = theano.tensor.matrix()
wu, vu = [out.eval({a: S}) for out in self.op(a, 'U')]
wl, vl = [out.eval({a: S}) for out in self.op(a, 'L')]
assert_array_almost_equal(wu, wl)
assert_array_almost_equal(vu*numpy.sign(vu[0,:]),
vl*numpy.sign(vl[0,:]))
def test_grad(self):
S = self.S
utt.verify_grad(lambda x: self.op(x + x.T)[0], [S], rng=self.rng)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论