提交 cd478ba2 authored 作者: abalkin's avatar abalkin 提交者: Frederic

Corected use of UPLO parameter in Eigh.grad.

Numpy's eigh(a, 'L') (eigh(a, 'U')) is a function of tril(a) (triu(a)) only. This means that partial derivative of eigh(a, 'L') (eigh(a, 'U')) with respect to a[i,j] is zero for i < j (i > j). At the same time, non-zero components of the gradient must account for the fact that variation of the opposite triangle contributes to variation of two elements of Hermitian (symmetric) matrix.
上级 1219334d
......@@ -991,7 +991,13 @@ class EighGrad(Op):
"""
def __init__(self, UPLO='L'):
self.UPLO = UPLO
if UPLO == 'L':
self.tri0 = numpy.tril
self.tri1 = lambda a: numpy.triu(a, 1)
else:
self.tri0 = numpy.triu
self.tri1 = lambda a: numpy.tril(a, -1)
def props(self):
return ()
......@@ -1043,11 +1049,9 @@ class EighGrad(Op):
G = lambda n: sum(v[:,m]*V.T[n].dot(v[:,m])/(w[n]-w[m])
for m in xrange(N) if m != n)
tri = numpy.tri(N)
if self.UPLO == 'U':
tri = tri.T
outputs[0][0] = sum(outer(v[:,n], v[:,n]*W[n] + G(n))
for n in xrange(N))#*tri
g = sum(outer(v[:,n], v[:,n]*W[n] + G(n))
for n in xrange(N))
outputs[0][0] = self.tri0(g) + self.tri1(g).T
def infer_shape(self, node, shapes):
return [shapes[0]]
......
......@@ -511,5 +511,7 @@ class test_Eigh(test_Eig):
def test_grad(self):
S = self.S
utt.verify_grad(lambda x: self.op(x + x.T)[0], [S], rng=self.rng)
utt.verify_grad(lambda x: self.op(x + x.T)[1], [S], rng=self.rng)
utt.verify_grad(lambda x: self.op(x)[0], [S], rng=self.rng)
utt.verify_grad(lambda x: self.op(x)[1], [S], rng=self.rng)
utt.verify_grad(lambda x: self.op(x, 'U')[0], [S], rng=self.rng)
utt.verify_grad(lambda x: self.op(x, 'U')[1], [S], rng=self.rng)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论