提交 6171e71f authored 作者: Frederic's avatar Frederic

pep8

上级 1d63fa25
...@@ -662,10 +662,10 @@ class Solve(Op): ...@@ -662,10 +662,10 @@ class Solve(Op):
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
Ashape, Bshape = shapes Ashape, Bshape = shapes
rows = Ashape[1] rows = Ashape[1]
if len(Bshape) == 1: # b is a Vector if len(Bshape) == 1: # b is a Vector
return [(rows,)] return [(rows,)]
else: else:
cols = Bshape[1] # b is a Matrix cols = Bshape[1] # b is a Matrix
return [(rows, cols)] return [(rows, cols)]
solve = Solve() # general solve solve = Solve() # general solve
...@@ -879,6 +879,7 @@ class A_Xinv_b(Op): ...@@ -879,6 +879,7 @@ class A_Xinv_b(Op):
gb = matrix_dot(ix.T, a.T, gz) gb = matrix_dot(ix.T, a.T, gz)
return [ga, gX, gb] return [ga, gX, gb]
class Eig(Op): class Eig(Op):
"""Compute the eigenvalues and right eigenvectors of a square array. """Compute the eigenvalues and right eigenvectors of a square array.
...@@ -916,24 +917,27 @@ class Eig(Op): ...@@ -916,24 +917,27 @@ class Eig(Op):
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
n = shapes[0][0] n = shapes[0][0]
return [(n,), (n,n)] return [(n,), (n, n)]
def __str__(self): def __str__(self):
return self._numop.__name__.capitalize() return self._numop.__name__.capitalize()
eig = Eig() eig = Eig()
def _zero_disconnected(outputs, grads): def _zero_disconnected(outputs, grads):
return [o.zeros_like() return [o.zeros_like()
if isinstance(g.type, DisconnectedType) else g if isinstance(g.type, DisconnectedType) else g
for o, g in zip(outputs, grads)] for o, g in zip(outputs, grads)]
class Eigh(Eig): class Eigh(Eig):
""" """
Return the eigenvalues and eigenvectors of a Hermitian or symmetric matrix. Return the eigenvalues and eigenvectors of a Hermitian or symmetric matrix.
""" """
_numop = staticmethod(numpy.linalg.eigh) _numop = staticmethod(numpy.linalg.eigh)
def __init__(self, UPLO='L'): def __init__(self, UPLO='L'):
self.UPLO = UPLO self.UPLO = UPLO
...@@ -962,9 +966,8 @@ class Eigh(Eig): ...@@ -962,9 +966,8 @@ class Eigh(Eig):
except numpy.linalg.LinAlgError: except numpy.linalg.LinAlgError:
logger.debug('Failed to find %s of %s' % (self._numop.__name__, logger.debug('Failed to find %s of %s' % (self._numop.__name__,
node.inputs[0])) node.inputs[0]))
raise raise
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
r"""The gradient function should return r"""The gradient function should return
...@@ -972,7 +975,7 @@ class Eigh(Eig): ...@@ -972,7 +975,7 @@ class Eigh(Eig):
{\partial a_{ij}} + {\partial a_{ij}} +
\sum_k V_{nk}\frac{\partial\,v_{nk}} \sum_k V_{nk}\frac{\partial\,v_{nk}}
{\partial a_{ij}}\right), {\partial a_{ij}}\right),
where [:math:`W`, :math:`V`] corresponds to ``g_outputs``, where [:math:`W`, :math:`V`] corresponds to ``g_outputs``,
:math:`a` to ``inputs``, and :math:`(w, v)=\mbox{eig}(a)`. :math:`a` to ``inputs``, and :math:`(w, v)=\mbox{eig}(a)`.
...@@ -984,7 +987,7 @@ class Eigh(Eig): ...@@ -984,7 +987,7 @@ class Eigh(Eig):
.. math:: \frac{\partial\,v_{kn}} .. math:: \frac{\partial\,v_{kn}}
{\partial a_{ij}} = {\partial a_{ij}} =
\sum_{m\ne n}\frac{v_{km}v_{jn}}{w_n-w_m} \sum_{m\ne n}\frac{v_{km}v_{jn}}{w_n-w_m}
""" """
x, = inputs x, = inputs
...@@ -994,8 +997,10 @@ class Eigh(Eig): ...@@ -994,8 +997,10 @@ class Eigh(Eig):
gw, gv = _zero_disconnected([w, v], g_outputs) gw, gv = _zero_disconnected([w, v], g_outputs)
return [EighGrad(self.UPLO)(x, w, v, gw, gv)] return [EighGrad(self.UPLO)(x, w, v, gw, gv)]
def eigh(a, UPLO='L'): def eigh(a, UPLO='L'):
return Eigh(UPLO)(a) return Eigh(UPLO)(a)
class EighGrad(Op): class EighGrad(Op):
"""Gradient of an eigensystem of a Hermitian matrix. """Gradient of an eigensystem of a Hermitian matrix.
...@@ -1009,7 +1014,7 @@ class EighGrad(Op): ...@@ -1009,7 +1014,7 @@ class EighGrad(Op):
else: else:
self.tri0 = numpy.triu self.tri0 = numpy.triu
self.tri1 = lambda a: numpy.tril(a, -1) self.tri1 = lambda a: numpy.tril(a, -1)
def props(self): def props(self):
return () return ()
...@@ -1021,7 +1026,6 @@ class EighGrad(Op): ...@@ -1021,7 +1026,6 @@ class EighGrad(Op):
def __str__(self): def __str__(self):
return 'EighGrad{%s}' % self.UPLO return 'EighGrad{%s}' % self.UPLO
def make_node(self, x, w, v, gw, gv): def make_node(self, x, w, v, gw, gv):
x, w, v, gw, gv = map(as_tensor_variable, (x, w, v, gw, gv)) x, w, v, gw, gv = map(as_tensor_variable, (x, w, v, gw, gv))
...@@ -1036,9 +1040,9 @@ class EighGrad(Op): ...@@ -1036,9 +1040,9 @@ class EighGrad(Op):
N = x.shape[0] N = x.shape[0]
outer = numpy.outer outer = numpy.outer
G = lambda n: sum(v[:,m]*V.T[n].dot(v[:,m])/(w[n]-w[m]) G = lambda n: sum(v[:, m] * V.T[n].dot(v[:, m]) / (w[n] - w[m])
for m in xrange(N) if m != n) for m in xrange(N) if m != n)
g = sum(outer(v[:,n], v[:,n]*W[n] + G(n)) g = sum(outer(v[:, n], v[:, n] * W[n] + G(n))
for n in xrange(N)) for n in xrange(N))
# Numpy's eigh(a, 'L') (eigh(a, 'U')) is a function of tril(a) # Numpy's eigh(a, 'L') (eigh(a, 'U')) is a function of tril(a)
...@@ -1053,4 +1057,3 @@ class EighGrad(Op): ...@@ -1053,4 +1057,3 @@ class EighGrad(Op):
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
return [shapes[0]] return [shapes[0]]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论