Commit 2475957f authored by Pascal Lamblin

Merge pull request #2319 from rmcgibbo/expm

[ENH] Matrix exponential op and gradient (round 2!)
...@@ -2,6 +2,7 @@ import logging ...@@ -2,6 +2,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import numpy import numpy
import warnings
from theano.gof import Op, Apply from theano.gof import Op, Apply
...@@ -214,7 +215,7 @@ class Eigvalsh(Op): ...@@ -214,7 +215,7 @@ class Eigvalsh(Op):
"Scipy not available. Scipy is needed for the Eigvalsh op") "Scipy not available. Scipy is needed for the Eigvalsh op")
if b == theano.tensor.NoneConst: if b == theano.tensor.NoneConst:
a = as_tensor_variable(a) a = as_tensor_variable(a)
assert a.ndim == 2 assert a.ndim == 2
out_dtype = theano.scalar.upcast(a.dtype) out_dtype = theano.scalar.upcast(a.dtype)
...@@ -276,7 +277,7 @@ class EigvalshGrad(Op): ...@@ -276,7 +277,7 @@ class EigvalshGrad(Op):
"Scipy not available. Scipy is needed for the GEigvalsh op") "Scipy not available. Scipy is needed for the GEigvalsh op")
a = as_tensor_variable(a) a = as_tensor_variable(a)
b = as_tensor_variable(b) b = as_tensor_variable(b)
gw = as_tensor_variable(gw) gw = as_tensor_variable(gw)
assert a.ndim == 2 assert a.ndim == 2
assert b.ndim == 2 assert b.ndim == 2
assert gw.ndim == 1 assert gw.ndim == 1
...@@ -336,3 +337,61 @@ def kron(a, b): ...@@ -336,3 +337,61 @@ def kron(a, b):
o.shape[1] * o.shape[3]) + o.shape[1] * o.shape[3]) +
tuple([o.shape[i] for i in range(4, o.ndim)])) tuple([o.shape[i] for i in range(4, o.ndim)]))
return o return o
class Expm(Op):
    """Compute the matrix exponential of a square array.

    Wraps ``scipy.linalg.expm``; the gradient is delegated to `ExpmGrad`.
    """

    def make_node(self, A):
        assert imported_scipy, (
            "Scipy not available. Scipy is needed for the Expm op")
        A = as_tensor_variable(A)
        # Only 2-D (square) inputs are supported.
        assert A.ndim == 2
        expm = theano.tensor.matrix(dtype=A.dtype)
        return Apply(self, [A], [expm])

    def perform(self, node, inputs, outputs):
        # NOTE: tuple parameter unpacking in the signature
        # (``def perform(self, node, (A,), (expm,))``) is a SyntaxError
        # under Python 3, so unpack explicitly here instead.
        (A,) = inputs
        (expm,) = outputs
        expm[0] = scipy.linalg.expm(A)

    def grad(self, inputs, outputs):
        (A,) = inputs
        (g_out,) = outputs
        return [ExpmGrad()(A, g_out)]

    def infer_shape(self, node, shapes):
        # expm(A) has the same shape as A.
        return [shapes[0]]
class ExpmGrad(Op):
    """Gradient of the matrix exponential of a square array."""

    def make_node(self, A, gw):
        assert imported_scipy, (
            "Scipy not available. Scipy is needed for the Expm op")
        A = as_tensor_variable(A)
        assert A.ndim == 2
        out = theano.tensor.matrix(dtype=A.dtype)
        return Apply(self, [A, gw], [out])

    def infer_shape(self, node, shapes):
        # The gradient has the same shape as A.
        return [shapes[0]]

    def perform(self, node, inputs, outputs):
        # NOTE: tuple parameter unpacking in the signature is a
        # SyntaxError under Python 3, so unpack explicitly instead.
        (A, gA) = inputs
        (out,) = outputs
        # Kalbfleisch and Lawless, J. Am. Stat. Assoc. 80 (1985) Equation 3.4
        # Kind of... You need to do some algebra from there to arrive at
        # this expression.
        w, V = scipy.linalg.eig(A, right=True)
        U = scipy.linalg.inv(V).T

        exp_w = numpy.exp(w)
        # Off-diagonal entries: (exp(w_i) - exp(w_j)) / (w_i - w_j).
        # The diagonal is 0/0 here; silence the spurious warnings and
        # overwrite it with the limit value exp(w_i) just below.
        with numpy.errstate(divide='ignore', invalid='ignore'):
            X = (numpy.subtract.outer(exp_w, exp_w) /
                 numpy.subtract.outer(w, w))
        numpy.fill_diagonal(X, exp_w)
        Y = U.dot(V.T.dot(gA).dot(U) * X).dot(V.T)

        with warnings.catch_warnings():
            # The eigendecomposition is complex even for a real input;
            # the imaginary parts cancel, so casting back is intended.
            warnings.simplefilter("ignore", numpy.ComplexWarning)
            out[0] = Y.astype(A.dtype)
# Module-level singleton: call as ``expm(x)`` on a square matrix variable.
expm = Expm()
...@@ -20,7 +20,8 @@ from theano.tensor.slinalg import ( Cholesky, ...@@ -20,7 +20,8 @@ from theano.tensor.slinalg import ( Cholesky,
solve, solve,
Eigvalsh, Eigvalsh,
EigvalshGrad, EigvalshGrad,
eigvalsh eigvalsh,
expm
) )
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
...@@ -189,7 +190,7 @@ class test_Solve(utt.InferShapeTester): ...@@ -189,7 +190,7 @@ class test_Solve(utt.InferShapeTester):
dtype=config.floatX)], dtype=config.floatX)],
self.op_class, self.op_class,
warn=False) warn=False)
def test_solve_correctness(self): def test_solve_correctness(self):
if not imported_scipy: if not imported_scipy:
raise SkipTest("Scipy needed for the Cholesky op.") raise SkipTest("Scipy needed for the Cholesky op.")
...@@ -227,3 +228,53 @@ class test_Solve(utt.InferShapeTester): ...@@ -227,3 +228,53 @@ class test_Solve(utt.InferShapeTester):
U_val = scipy.linalg.cholesky(A_val, lower=False) U_val = scipy.linalg.cholesky(A_val, lower=False)
assert numpy.allclose(scipy.linalg.solve_triangular(U_val, b_val, lower=False), assert numpy.allclose(scipy.linalg.solve_triangular(U_val, b_val, lower=False),
upper_solve_func(U_val, b_val)) upper_solve_func(U_val, b_val))
def test_expm():
    """Forward pass: the op's output must match scipy's expm."""
    if not imported_scipy:
        raise SkipTest("Scipy needed for the expm op.")
    seed_rng = numpy.random.RandomState(utt.fetch_seed())
    mat = seed_rng.randn(5, 5).astype(config.floatX)

    sym_a = tensor.matrix()
    compute_expm = function([sym_a], expm(sym_a))

    numpy.testing.assert_array_almost_equal(
        compute_expm(mat), scipy.linalg.expm(mat))
def test_expm_grad_1():
    """Gradient check on a symmetric matrix (real eigenvectors)."""
    if not imported_scipy:
        raise SkipTest("Scipy needed for the expm op.")
    seed_rng = numpy.random.RandomState(utt.fetch_seed())
    raw = seed_rng.randn(5, 5).astype(config.floatX)
    # Symmetrize so the eigendecomposition is entirely real.
    tensor.verify_grad(expm, [raw + raw.T], rng=seed_rng)
def test_expm_grad_2():
    """Gradient check on a non-symmetric matrix with a real eigenspectrum."""
    if not imported_scipy:
        raise SkipTest("Scipy needed for the expm op.")
    seed_rng = numpy.random.RandomState(utt.fetch_seed())
    # Keep the RNG call order identical: 5x5 draw first, then the 5-vector.
    base = seed_rng.randn(5, 5).astype(config.floatX)
    w = seed_rng.randn(5).astype(config.floatX) ** 2
    # Similarity transform of a symmetric matrix: not symmetric itself,
    # but guaranteed to have real eigenvalues.
    mat = numpy.diag(w ** 0.5).dot(base + base.T).dot(numpy.diag(w ** (-0.5)))
    assert not numpy.allclose(mat, mat.T)
    tensor.verify_grad(expm, [mat], rng=seed_rng)
def test_expm_grad_3():
    """Gradient check on a generic matrix (complex eigenvectors)."""
    if not imported_scipy:
        raise SkipTest("Scipy needed for the expm op.")
    seed_rng = numpy.random.RandomState(utt.fetch_seed())
    sample = seed_rng.randn(5, 5).astype(config.floatX)
    tensor.verify_grad(expm, [sample], rng=seed_rng)
Markdown format
0%
You are attaching 0 files to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment