提交 6fda60f1 作者: Pascal Lamblin

Merge pull request #4148 from matt-graham/cholesky_grad

Symbolic Cholesky gradient implementation
...@@ -63,8 +63,43 @@ class Cholesky(Op): ...@@ -63,8 +63,43 @@ class Cholesky(Op):
z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype) z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype)
def grad(self, inputs, gradients): def grad(self, inputs, gradients):
return [CholeskyGrad(self.lower)(inputs[0], self(inputs[0]), """
gradients[0])] Cholesky decomposition reverse-mode gradient update.
Symbolic expression for reverse-mode Cholesky gradient taken from [0]_
References
----------
.. [0] I. Murray, "Differentiation of the Cholesky decomposition",
http://arxiv.org/abs/1602.07527
"""
x = inputs[0]
dz = gradients[0]
chol_x = self(x)
# deal with upper triangular by converting to lower triangular
if not self.lower:
chol_x = chol_x.T
dz = dz.T
def tril_and_halve_diagonal(mtx):
"""Extracts lower triangle of square matrix and halves diagonal."""
return tensor.tril(mtx) - tensor.diag(tensor.diagonal(mtx) / 2.)
def conjugate_solve_triangular(outer, inner):
"""Computes L^{-T} P L^{-1} for lower-triangular L."""
return solve_upper_triangular(
outer.T, solve_upper_triangular(outer.T, inner.T).T)
s = conjugate_solve_triangular(
chol_x, tril_and_halve_diagonal(chol_x.T.dot(dz)))
if self.lower:
return [tensor.tril(s + s.T) - tensor.diag(tensor.diagonal(s))]
else:
return [tensor.triu(s + s.T) - tensor.diag(tensor.diagonal(s))]
cholesky = Cholesky() cholesky = Cholesky()
...@@ -195,8 +230,9 @@ class Solve(Op): ...@@ -195,8 +230,9 @@ class Solve(Op):
return [(rows, cols)] return [(rows, cols)]
solve = Solve() # general solve solve = Solve() # general solve
# lower and upper triangular solves
# TODO : SolveTriangular solve_lower_triangular = Solve(A_structure='lower_triangular', lower=True)
solve_upper_triangular = Solve(A_structure='upper_triangular', lower=False)
# TODO: Optimizations to replace multiplication by matrix inverse # TODO: Optimizations to replace multiplication by matrix inverse
# with solve() Op (still unwritten) # with solve() Op (still unwritten)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论