提交 1219334d authored 作者: abalkin's avatar abalkin 提交者: Frederic

Implemented UPLO parameter for eigh.

上级 b9a6f0b7
...@@ -933,6 +933,15 @@ class Eigh(Eig): ...@@ -933,6 +933,15 @@ class Eigh(Eig):
""" """
_numop = staticmethod(numpy.linalg.eigh) _numop = staticmethod(numpy.linalg.eigh)
def __init__(self, UPLO='L'):
self.UPLO = UPLO
def __str__(self):
return 'Eigh{%s}' % self.UPLO
def props(self):
return self.UPLO,
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
w = theano.tensor.vector(dtype='float64') w = theano.tensor.vector(dtype='float64')
...@@ -941,7 +950,7 @@ class Eigh(Eig): ...@@ -941,7 +950,7 @@ class Eigh(Eig):
def perform(self, node, (x,), (w, v)): def perform(self, node, (x,), (w, v)):
try: try:
w[0], v[0] = self._numop(x) w[0], v[0] = self._numop(x, self.UPLO)
except numpy.linalg.LinAlgError: except numpy.linalg.LinAlgError:
logger.debug('Failed to find %s of %s' % (node.inputs[0], logger.debug('Failed to find %s of %s' % (node.inputs[0],
self._numop.__name__)) self._numop.__name__))
...@@ -971,14 +980,18 @@ class Eigh(Eig): ...@@ -971,14 +980,18 @@ class Eigh(Eig):
# Replace gradients wrt disconnected variables with # Replace gradients wrt disconnected variables with
# zeros. This is a work-around for issue #1063. # zeros. This is a work-around for issue #1063.
gw, gv = _zero_disconnected([w, v], g_outputs) gw, gv = _zero_disconnected([w, v], g_outputs)
return [EighGrad()(x, w, v, gw, gv)] return [EighGrad(self.UPLO)(x, w, v, gw, gv)]
eigh = Eigh() def eigh(a, UPLO='L'):
return Eigh(UPLO)(a)
class EighGrad(Op): class EighGrad(Op):
"""Gradient of an eigensystem of a Hermitian matrix. """Gradient of an eigensystem of a Hermitian matrix.
""" """
def __init__(self, UPLO='L'):
self.UPLO = UPLO
def props(self): def props(self):
return () return ()
...@@ -989,7 +1002,7 @@ class EighGrad(Op): ...@@ -989,7 +1002,7 @@ class EighGrad(Op):
return (type(self) == type(other) and self.props() == other.props()) return (type(self) == type(other) and self.props() == other.props())
def __str__(self): def __str__(self):
return 'EigGrad' return 'EighGrad{%s}' % self.UPLO
def make_node(self, x, w, v, gw, gv): def make_node(self, x, w, v, gw, gv):
...@@ -1030,8 +1043,11 @@ class EighGrad(Op): ...@@ -1030,8 +1043,11 @@ class EighGrad(Op):
G = lambda n: sum(v[:,m]*V.T[n].dot(v[:,m])/(w[n]-w[m]) G = lambda n: sum(v[:,m]*V.T[n].dot(v[:,m])/(w[n]-w[m])
for m in xrange(N) if m != n) for m in xrange(N) if m != n)
tri = numpy.tri(N)
if self.UPLO == 'U':
tri = tri.T
outputs[0][0] = sum(outer(v[:,n], v[:,n]*W[n] + G(n)) outputs[0][0] = sum(outer(v[:,n], v[:,n]*W[n] + G(n))
for n in xrange(N)) for n in xrange(N))#*tri
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
return [shapes[0]] return [shapes[0]]
......
...@@ -499,7 +499,16 @@ class test_Eig(utt.InferShapeTester): ...@@ -499,7 +499,16 @@ class test_Eig(utt.InferShapeTester):
assert_array_almost_equal(numpy.dot(x,v), w * v) assert_array_almost_equal(numpy.dot(x,v), w * v)
class test_Eigh(test_Eig): class test_Eigh(test_Eig):
op = eigh op = staticmethod(eigh)
def test_uplo(self):
S = self.S
a = theano.tensor.matrix()
wu, vu = [out.eval({a: S}) for out in self.op(a, 'U')]
wl, vl = [out.eval({a: S}) for out in self.op(a, 'L')]
assert_array_almost_equal(wu, wl)
assert_array_almost_equal(vu*numpy.sign(vu[0,:]),
vl*numpy.sign(vl[0,:]))
def test_grad(self): def test_grad(self):
S = self.S S = self.S
utt.verify_grad(lambda x: self.op(x + x.T)[0], [S], rng=self.rng) utt.verify_grad(lambda x: self.op(x + x.T)[0], [S], rng=self.rng)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论