提交 7176a3d1 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #4153 from matt-graham/solve_grad

Gradient implementation for Solve op
......@@ -229,6 +229,40 @@ class Solve(Op):
cols = Bshape[1] # b is a Matrix
return [(rows, cols)]
def grad(self, inputs, output_gradients):
"""
Reverse-mode gradient updates for matrix solve operation c = A \ b.
Symbolic expression for updates taken from [1]_.
References
----------
..[1] M. B. Giles, "An extended collection of matrix derivative results
for forward and reverse mode automatic differentiation",
http://eprints.maths.ox.ac.uk/1079/
"""
A, b = inputs
c = self(A, b)
c_bar = output_gradients[0]
trans_map = {
'lower_triangular': 'upper_triangular',
'upper_triangular': 'lower_triangular'
}
trans_solve_op = Solve(
# update A_structure and lower to account for a transpose operation
A_structure=trans_map.get(self.A_structure, self.A_structure),
lower=not self.lower
)
b_bar = trans_solve_op(A.T, c_bar)
# force outer product if vector second input
A_bar = -tensor.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
if self.A_structure == 'lower_triangular':
A_bar = tensor.tril(A_bar)
elif self.A_structure == 'upper_triangular':
A_bar = tensor.triu(A_bar)
return [A_bar, b_bar]
solve = Solve() # general solve
# lower and upper triangular solves
solve_lower_triangular = Solve(A_structure='lower_triangular', lower=True)
......
......@@ -165,7 +165,7 @@ class test_Solve(utt.InferShapeTester):
def test_infer_shape(self):
if not imported_scipy:
raise SkipTest("Scipy needed for the Cholesky op.")
raise SkipTest("Scipy needed for the Solve op.")
rng = numpy.random.RandomState(utt.fetch_seed())
A = theano.tensor.matrix()
b = theano.tensor.matrix()
......@@ -193,7 +193,7 @@ class test_Solve(utt.InferShapeTester):
def test_solve_correctness(self):
if not imported_scipy:
raise SkipTest("Scipy needed for the Cholesky op.")
raise SkipTest("Scipy needed for the Cholesky and Solve ops.")
rng = numpy.random.RandomState(utt.fetch_seed())
A = theano.tensor.matrix()
b = theano.tensor.matrix()
......@@ -211,7 +211,7 @@ class test_Solve(utt.InferShapeTester):
upper_solve_func = theano.function([U, b], y_upper)
b_val = numpy.asarray(rng.rand(5, 1), dtype=config.floatX)
# 1-test general case
A_val = numpy.asarray(rng.rand(5, 5), dtype=config.floatX)
# positive definite matrix:
......@@ -229,6 +229,39 @@ class test_Solve(utt.InferShapeTester):
assert numpy.allclose(scipy.linalg.solve_triangular(U_val, b_val, lower=False),
upper_solve_func(U_val, b_val))
def verify_solve_grad(self, m, n, A_structure, lower, rng):
# ensure diagonal elements of A relatively large to avoid numerical
# precision issues
A_val = (rng.normal(size=(m, m)) * 0.5 +
numpy.eye(m)).astype(config.floatX)
if A_structure == 'lower_triangular':
A_val = numpy.tril(A_val)
elif A_structure == 'upper_triangular':
A_val = numpy.triu(A_val)
if n is None:
b_val = rng.normal(size=m).astype(config.floatX)
else:
b_val = rng.normal(size=(m, n)).astype(config.floatX)
eps = None
if config.floatX == "float64":
eps = 2e-8
solve_op = Solve(A_structure=A_structure, lower=lower)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
def test_solve_grad(self):
if not imported_scipy:
raise SkipTest("Scipy needed for the Solve op.")
rng = numpy.random.RandomState(utt.fetch_seed())
structures = ['general', 'lower_triangular', 'upper_triangular']
for A_structure in structures:
lower = (A_structure == 'lower_triangular')
self.verify_solve_grad(5, None, A_structure, lower, rng)
self.verify_solve_grad(6, 1, A_structure, lower, rng)
self.verify_solve_grad(4, 3, A_structure, lower, rng)
# lower should have no effect for A_structure == 'general' so also
# check lower=True case
self.verify_solve_grad(4, 3, 'general', lower=True, rng=rng)
def test_expm():
if not imported_scipy:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论