提交 1c8f8d6a authored 作者: Matt Graham's avatar Matt Graham

Adding symbolic definition of Cholesky gradient.

上级 d94cdf4f
...@@ -62,8 +62,27 @@ class Cholesky(Op): ...@@ -62,8 +62,27 @@ class Cholesky(Op):
z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype) z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype)
def grad(self, inputs, gradients): def grad(self, inputs, gradients):
return [CholeskyGrad(self.lower)(inputs[0], self(inputs[0]),
gradients[0])] x = inputs[0]
dz = gradients[0]
chol_x = self(x)
chol_x = theano.printing.Print('Cholesky:')(chol_x)
def tril_and_halve_diagonal(mtx):
"""Extracts lower triangle of square matrix and halves diagonal."""
return tensor.tril(mtx) - tensor.diag(tensor.diagonal(mtx) / 2.)
def conjugate_solve_triangular(outer, inner):
"""Computes P^{-T} Q P^{-1} for lower-triangular P."""
return solve_upper_triangular(
outer.T, solve_upper_triangular(outer.T, inner.T).T)
s = conjugate_solve_triangular(
chol_x, tril_and_halve_diagonal(chol_x.T.dot(dz)))
s = theano.printing.Print('S:')(s)
return [tensor.tril(s + s.T) - tensor.diag(tensor.diagonal(s))]
cholesky = Cholesky() cholesky = Cholesky()
...@@ -194,8 +213,9 @@ class Solve(Op): ...@@ -194,8 +213,9 @@ class Solve(Op):
return [(rows, cols)] return [(rows, cols)]
solve = Solve() # general solve solve = Solve() # general solve
# lower and upper triangular solves
# TODO : SolveTriangular solve_lower_triangular = Solve(A_structure='lower_triangular', lower=True)
solve_upper_triangular = Solve(A_structure='upper_triangular', lower=False)
# TODO: Optimizations to replace multiplication by matrix inverse # TODO: Optimizations to replace multiplication by matrix inverse
# with solve() Op (still unwritten) # with solve() Op (still unwritten)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论