提交 70794024 authored 作者: Robert McGibbon's avatar Robert McGibbon

Found the right expression

上级 193cb321
......@@ -2,6 +2,7 @@ import logging
logger = logging.getLogger(__name__)
import numpy
import warnings
from theano.gof import Op, Apply
......@@ -371,20 +372,20 @@ class ExpmGrad(Op):
return [shapes[0]]
def perform(self, node, (A, gA), (out,)):
if not numpy.allclose(A, A.T):
raise NotImplementedError(
"ExpmGrad is only implemented for symmetric matrices")
# Kalbfleisch and Lawless, J. Am. Stat. Assoc. 80 (1985) Equation 3.4
w, U = scipy.linalg.eigh(A)
# Kind of... You need to do some algebra from there to arrive at
# this expression.
w, V = scipy.linalg.eig(A, right=True)
U = scipy.linalg.inv(V).T
G = (U.T).dot(gA).dot(U)
exp_w = numpy.exp(w)
V = numpy.subtract.outer(exp_w, exp_w) / numpy.subtract.outer(w, w)
numpy.fill_diagonal(V, exp_w)
V = numpy.multiply(V, G, V)
X = numpy.subtract.outer(exp_w, exp_w) / numpy.subtract.outer(w, w)
numpy.fill_diagonal(X, exp_w)
Y = U.dot(V.T.dot(gA).dot(U) * X).dot(V.T)
out[0] = (U.dot(V).dot(U.T)).astype(A.dtype)
with warnings.catch_warnings():
warnings.simplefilter("ignore", numpy.ComplexWarning)
out[0] = Y.astype(A.dtype)
expm = Expm()
......@@ -213,7 +213,30 @@ def test_expm_grad_1():
if not imported_scipy:
raise SkipTest("Scipy needed for the expm op.")
rng = numpy.random.RandomState(utt.fetch_seed())
A = rng.randn(3, 3).astype(config.floatX)
A = rng.randn(5, 5).astype(config.floatX)
A = A + A.T
tensor.verify_grad(expm, [A,], rng=rng)
def test_expm_grad_2():
# with non-symmetric matrix with real eigenspecta
if not imported_scipy:
raise SkipTest("Scipy needed for the expm op.")
rng = numpy.random.RandomState(utt.fetch_seed())
A = rng.randn(5, 5).astype(config.floatX)
w = (rng.randn(5).astype(config.floatX))**2
A = (numpy.diag(w**0.5)).dot(A + A.T).dot(numpy.diag(w**(-0.5)))
assert not numpy.allclose(A, A.T)
tensor.verify_grad(expm, [A,], rng=rng)
def test_expm_grad_3():
# with non-symmetric matrix (complex eigenvectors)
if not imported_scipy:
raise SkipTest("Scipy needed for the expm op.")
rng = numpy.random.RandomState(utt.fetch_seed())
A = rng.randn(5, 5).astype(config.floatX)
tensor.verify_grad(expm, [A,], rng=rng)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论