提交 1394b39b 作者: Robert McGibbon

Clean stuff up, no matrix inverse

上级 e412b2bd
import logging import logging
import warnings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import numpy import numpy
...@@ -370,17 +371,21 @@ class ExpmGrad(Op): ...@@ -370,17 +371,21 @@ class ExpmGrad(Op):
return [shapes[0]] return [shapes[0]]
def perform(self, node, (A, gw), (out,)): def perform(self, node, (A, gw), (out,)):
w, M = scipy.linalg.eig(A) # Kalbfleisch and Lawless, J. Am. Stat. Assoc. 80 (1985) Equation 3.4
w, UL, UR = scipy.linalg.eig(A, left=True, right=True)
G = scipy.linalg.solve(M, gw).dot(M) if numpy.linalg.norm(w - numpy.real(w)) > numpy.finfo(A.dtype).eps:
warnings.warn("ExpmGrad not correct for matrices with complex "
"eigenvalues")
G = (UR.conj().T).dot(gw).dot(UL)
exp_w = numpy.exp(w) exp_w = numpy.exp(w)
V = numpy.subtract.outer(exp_w, exp_w) / numpy.subtract.outer(w, w) V = numpy.subtract.outer(exp_w, exp_w) / numpy.subtract.outer(w, w)
V[numpy.diag_indices_from(V)] = exp_w V[numpy.diag_indices_from(V)] = exp_w
numpy.multiply(V, G, V) V = numpy.multiply(V, G, V)
Mi = scipy.linalg.inv(M) out[0] = numpy.real(UL.dot(V).dot(UR.conj().T)).astype(A.dtype)
out[0] = numpy.real(M.dot(V).dot(Mi))
def expm(A): def expm(A):
......
...@@ -224,31 +224,29 @@ def test_expm_grad_2(): ...@@ -224,31 +224,29 @@ def test_expm_grad_2():
raise SkipTest("Scipy needed for the expm op.") raise SkipTest("Scipy needed for the expm op.")
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
A = rng.randn(3, 3).astype(config.floatX) A = rng.randn(3, 3).astype(config.floatX)
A = A
tensor.verify_grad(expm, [A,], rng=rng) tensor.verify_grad(expm, [A,], rng=rng)
# def test_expm_grad_3():
def test_expm_grad_3(): # if not imported_scipy:
if not imported_scipy: # raise SkipTest("Scipy needed for the expm op.")
raise SkipTest("Scipy needed for the expm op.") # from theano.gradient import grad
from theano.gradient import grad # rng = numpy.random.RandomState(utt.fetch_seed())
rng = numpy.random.RandomState(utt.fetch_seed()) # A = rng.randn(3, 3).astype(config.floatX)
A = rng.randn(3, 3).astype(config.floatX) #
# h = 1e-7
h = 1e-7 # def e(i,j):
def e(i,j): # v = numpy.zeros((3, 3))
v = numpy.zeros((3, 3)) # v[i, j] = 1
v[i, j] = 1 # return v
return v #
# x = tensor.matrix()
x = tensor.matrix() # grad_expm_f = function([x], grad(expm(x)[0,0], x))
grad_expm_f = function([x], grad(expm(x)[0,1], x)) # expm_f = function([x], expm(x)[0,0])
expm_f = function([x], expm(x)[0,1]) #
# g = lambda i, j: (expm_f(A + h*e(i,j)) - expm_f(A)) / h
g = lambda i, j: (expm_f(A + h*e(i,j)) - expm_f(A)) / h # numgrad = numpy.array([[g(i,j) for i in range(3)] for j in range(3)])
numgrad = numpy.array([[g(i,j) for i in range(3)] for j in range(3)]) #
# print(grad_expm_f(A))
print(grad_expm_f(A)) # print(numgrad)
print(numgrad) #
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论