提交 c9c8017b authored 作者: abalkin's avatar abalkin 提交者: Frederic

Debugging math formulae.

上级 25538faa
...@@ -928,6 +928,20 @@ class Eigh(Eig): ...@@ -928,6 +928,20 @@ class Eigh(Eig):
""" """
_numop = staticmethod(numpy.linalg.eigh) _numop = staticmethod(numpy.linalg.eigh)
def make_node(self, x):
x = as_tensor_variable(x)
w = theano.tensor.vector(dtype='float64')
v = theano.tensor.matrix(dtype=x.dtype)
return Apply(self, [x], [w, v])
def perform(self, node, (x,), (w, v)):
try:
w[0], v[0] = self._numop(x)
except numpy.linalg.LinAlgError:
logger.debug('Failed to find %s of %s' % (node.inputs[0],
self._numop.__name__))
raise
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
r"""The gradient function should return r"""The gradient function should return
...@@ -1003,7 +1017,7 @@ class EighGrad(Op): ...@@ -1003,7 +1017,7 @@ class EighGrad(Op):
.. math:: \frac{\partial w_n}{\partial x_{ij}} = v_{in}\,v_{jn} .. math:: \frac{\partial w_n}{\partial x_{ij}} = v_{in}\,v_{jn}
""" """
x, w, v, gw, gv = inputs x, w, v, W, V = inputs
N = x.shape[0] N = x.shape[0]
if imported_scipy: if imported_scipy:
pinv = scipy.linalg.pinv pinv = scipy.linalg.pinv
...@@ -1011,10 +1025,16 @@ class EighGrad(Op): ...@@ -1011,10 +1025,16 @@ class EighGrad(Op):
pinv = numpy.linalg.pinv pinv = numpy.linalg.pinv
diag = numpy.diag diag = numpy.diag
outer = numpy.outer outer = numpy.outer
gx = sum(gw[n]*outer(v[:,n], v[:,n]) I = numpy.eye(x.shape[0])
+ outer(gv[:,n], pinv(diag(w)-x).dot(v[:,n]))
for n in xrange(N)) def Wterm(n):
outputs[0][0] = gx return numpy.outer(v[:,n],v[:,n])*W[n]
def Vterm(n):
return numpy.outer(v[:,n],numpy.linalg.pinv(w[n]*I-x).dot(V[:,n]))
outputs[0][0] = numpy.sum(Wterm(n) + Vterm(n)for n in range(v.shape[1]))
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
return [shapes[0]] return [shapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论