提交 b9a6f0b7 authored 作者: abalkin's avatar abalkin 提交者: Frederic

Simplified grad eigh calculations.

上级 a00d8924
......@@ -1026,21 +1026,12 @@ class EighGrad(Op):
"""
x, w, v, W, V = inputs
N = x.shape[0]
diag = numpy.diag
outer = numpy.outer
I = numpy.eye(N)
def Wterm(n):
return outer(v[:,n], v[:,n])*W[n]
def Vterm(n):
G = sum(v[:,m]*V.T[n].dot(v[:,m])/(w[n]-w[m])
for m in xrange(N) if m != n)
return outer(v[:,n], G)
outputs[0][0] = sum(Wterm(n) + Vterm(n)
for n in range(N))
G = lambda n: sum(v[:,m]*V.T[n].dot(v[:,m])/(w[n]-w[m])
for m in xrange(N) if m != n)
outputs[0][0] = sum(outer(v[:,n], v[:,n]*W[n] + G(n))
for n in xrange(N))
def infer_shape(self, node, shapes):
return [shapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论