提交 4cda7c7e authored 作者: Robert McGibbon's avatar Robert McGibbon

change the easy stuff

上级 1ed1c595
...@@ -1128,12 +1128,7 @@ class GEigvalsh(Op): ...@@ -1128,12 +1128,7 @@ class GEigvalsh(Op):
return Apply(self, [a, b], [w]) return Apply(self, [a, b], [w])
def perform(self, node, (a, b), (w,)): def perform(self, node, (a, b), (w,)):
try: w[0] = scipy.linalg.eigvalsh(a=a, b=b, lower=self.lower)
w[0] = scipy.linalg.eigvalsh(a=a, b=b, lower=self.lower)
except numpy.linalg.LinAlgError:
logger.debug('Failed to find generalized eigs of %s' % (
node.inputs[0]))
raise
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
a, b = inputs a, b = inputs
...@@ -1183,9 +1178,7 @@ class GEigvalshGrad(Op): ...@@ -1183,9 +1178,7 @@ class GEigvalshGrad(Op):
return Apply(self, [a, b, gw], [out1, out2]) return Apply(self, [a, b, gw], [out1, out2])
def perform(self, node, (a, b, gw), outputs): def perform(self, node, (a, b, gw), outputs):
N = a.shape[0] w, v = scipy.linalg.eigh(a, b, lower=self.lower)
w, v = scipy.linalg.eigh(a, b, lower=True)
gA = v.dot(numpy.diag(gw).dot(v.T)) gA = v.dot(numpy.diag(gw).dot(v.T))
gB = - v.dot(numpy.diag(gw*w).dot(v.T)) gB = - v.dot(numpy.diag(gw*w).dot(v.T))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论