提交 d26ae898 authored 作者: Robert McGibbon's avatar Robert McGibbon

Change name to eigvalsh

上级 4cda7c7e
...@@ -2,5 +2,5 @@ ...@@ -2,5 +2,5 @@
from kron import kron from kron import kron
from ops import (cholesky, matrix_inverse, solve, from ops import (cholesky, matrix_inverse, solve,
diag, extract_diag, alloc_diag, diag, extract_diag, alloc_diag,
det, psd, eig, eigh, geigvalsh, det, psd, eig, eigh, eigvalsh,
trace, spectral_radius_bound) trace, spectral_radius_bound)
...@@ -1099,7 +1099,7 @@ class EighGrad(Op): ...@@ -1099,7 +1099,7 @@ class EighGrad(Op):
return [shapes[0]] return [shapes[0]]
class GEigvalsh(Op): class Eigvalsh(Op):
"""Generalized eigenvalues of a Hermetian positive definite eigensystem """Generalized eigenvalues of a Hermetian positive definite eigensystem
""" """
...@@ -1133,14 +1133,14 @@ class GEigvalsh(Op): ...@@ -1133,14 +1133,14 @@ class GEigvalsh(Op):
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
a, b = inputs a, b = inputs
gw, = g_outputs gw, = g_outputs
return GEigvalshGrad(self.lower)(a, b, gw) return EigvalshGrad(self.lower)(a, b, gw)
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
n = shapes[0][0] n = shapes[0][0]
return [(n,)] return [(n,)]
class GEigvalshGrad(Op): class EigvalshGrad(Op):
"""Gradient of generalized eigenvalues of a Hermetian positive definite """Gradient of generalized eigenvalues of a Hermetian positive definite
eigensystem eigensystem
""" """
...@@ -1192,5 +1192,5 @@ class GEigvalshGrad(Op): ...@@ -1192,5 +1192,5 @@ class GEigvalshGrad(Op):
return [shapes[0], shapes[1]] return [shapes[0], shapes[1]]
def geigvalsh(a, b, lower=True): def eigvalsh(a, b, lower=True):
return GEigvalsh(lower)(a, b) return Eigvalsh(lower)(a, b)
...@@ -32,7 +32,7 @@ from theano.sandbox.linalg.ops import (cholesky, ...@@ -32,7 +32,7 @@ from theano.sandbox.linalg.ops import (cholesky,
Eig, Eig,
inv_as_solve, inv_as_solve,
) )
from theano.sandbox.linalg import eig, eigh, geigvalsh from theano.sandbox.linalg import eig, eigh, eigvalsh
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
...@@ -575,29 +575,28 @@ def test_matrix_inverse_solve(): ...@@ -575,29 +575,28 @@ def test_matrix_inverse_solve():
assert isinstance(out.owner.op, Solve) assert isinstance(out.owner.op, Solve)
def test_geigvalsh(): def test_eigvalsh():
if not imported_scipy: if not imported_scipy:
raise SkipTest("Scipy needed for the geigvalsh op.") raise SkipTest("Scipy needed for the geigvalsh op.")
import scipy.linalg import scipy.linalg
A = theano.tensor.dmatrix('a') A = theano.tensor.dmatrix('a')
B = theano.tensor.dmatrix('b') B = theano.tensor.dmatrix('b')
f = function([A, B], geigvalsh(A, B)) f = function([A, B], eigvalsh(A, B))
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
a = rng.randn(5, 5) a = rng.randn(5, 5)
a = a + a.T a = a + a.T
b = 10 * numpy.eye(5, 5) + rng.randn(5, 5) for b in [10 * numpy.eye(5, 5) + rng.randn(5, 5), None]:
w = f(a, b)
w = f(a, b) refw = scipy.linalg.eigvalsh(a, b)
refw = scipy.linalg.eigvalsh(a, b) numpy.testing.assert_array_almost_equal(w, refw)
numpy.testing.assert_array_almost_equal(w, refw)
def test_geigvalsh_grad(): def test_eigvalsh_grad():
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
a = rng.randn(5, 5) a = rng.randn(5, 5)
a = a + a.T a = a + a.T
b = 10 * numpy.eye(5, 5) + rng.randn(5, 5) b = 10 * numpy.eye(5, 5) + rng.randn(5, 5)
tensor.verify_grad(lambda a, b: geigvalsh(a, b).dot([1, 2, 3, 4, 5]), tensor.verify_grad(lambda a, b: eigvalsh(a, b).dot([1, 2, 3, 4, 5]),
[a, b], rng=numpy.random) [a, b], rng=numpy.random)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论