提交 a00d8924 authored 作者: abalkin's avatar abalkin 提交者: Frederic

Eliminated pinv from grad eigh calculations.

上级 143103d3
......@@ -1026,21 +1026,20 @@ class EighGrad(Op):
"""
x, w, v, W, V = inputs
N = x.shape[0]
if imported_scipy:
pinv = scipy.linalg.pinv
else:
pinv = numpy.linalg.pinv
diag = numpy.diag
outer = numpy.outer
I = numpy.eye(x.shape[0])
I = numpy.eye(N)
def Wterm(n):
return numpy.outer(v[:,n],v[:,n])*W[n]
return outer(v[:,n], v[:,n])*W[n]
def Vterm(n):
return numpy.outer(v[:,n],numpy.linalg.pinv(w[n]*I-x).dot(V[:,n]))
outputs[0][0] = numpy.sum(Wterm(n) + Vterm(n)for n in range(v.shape[1]))
G = sum(v[:,m]*V.T[n].dot(v[:,m])/(w[n]-w[m])
for m in xrange(N) if m != n)
return outer(v[:,n], G)
outputs[0][0] = sum(Wterm(n) + Vterm(n)
for n in range(N))
def infer_shape(self, node, shapes):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论