提交 9e03a2e2 authored 作者: Matt Graham's avatar Matt Graham

Initial implementation of Solve grad and unit test.

上级 4e7f550d
......@@ -193,6 +193,40 @@ class Solve(Op):
cols = Bshape[1] # b is a Matrix
return [(rows, cols)]
def grad(self, inputs, output_gradients):
"""
Reverse-mode gradient updates for matrix solve operation c = A \ b.
Symbolic expression for updates taken from [1]_.
References
----------
..[1] M. B. Giles, "An extended collection of matrix derivative results
for forward and reverse mode automatic differentiation",
http://eprints.maths.ox.ac.uk/1079/
"""
A, b = inputs
c = self(A, b)
c_bar = output_gradients[0]
trans_map = {
'lower_triangular': 'upper_triangular',
'upper_triangular': 'lower_triangular'
}
trans_solve_op = Solve(
# update A_structure and lower to account for a transpose operation
A_structure=trans_map.get(self.A_structure, self.A_structure),
lower=not self.lower
)
b_bar = trans_solve_op(A.T, c_bar)
# force outer product if vector second input
A_bar = -tensor.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
if self.A_structure == 'lower_triangular':
A_bar = tensor.tril(A_bar)
elif self.A_structure == 'upper_triangular':
A_bar = tensor.triu(A_bar)
return [A_bar, b_bar]
solve = Solve() # general solve
# TODO : SolveTriangular
......
......@@ -164,7 +164,7 @@ class test_Solve(utt.InferShapeTester):
def test_infer_shape(self):
if not imported_scipy:
raise SkipTest("Scipy needed for the Cholesky op.")
raise SkipTest("Scipy needed for the Solve op.")
rng = numpy.random.RandomState(utt.fetch_seed())
A = theano.tensor.matrix()
b = theano.tensor.matrix()
......@@ -192,7 +192,7 @@ class test_Solve(utt.InferShapeTester):
def test_solve_correctness(self):
if not imported_scipy:
raise SkipTest("Scipy needed for the Cholesky op.")
raise SkipTest("Scipy needed for the Cholesky and Solve ops.")
rng = numpy.random.RandomState(utt.fetch_seed())
A = theano.tensor.matrix()
b = theano.tensor.matrix()
......@@ -210,7 +210,7 @@ class test_Solve(utt.InferShapeTester):
upper_solve_func = theano.function([U, b], y_upper)
b_val = numpy.asarray(rng.rand(5, 1), dtype=config.floatX)
# 1-test general case
A_val = numpy.asarray(rng.rand(5, 5), dtype=config.floatX)
# positive definite matrix:
......@@ -228,6 +228,33 @@ class test_Solve(utt.InferShapeTester):
assert numpy.allclose(scipy.linalg.solve_triangular(U_val, b_val, lower=False),
upper_solve_func(U_val, b_val))
def verify_solve_grad(self, m, n, A_structure, lower, rng):
A_val = rng.normal(size=(m, m))
if A_structure == 'lower_triangular':
A_val = numpy.tril(A_val)
elif A_structure == 'upper_triangular':
A_val = numpy.triu(A_val)
if n is None:
b_val = rng.normal(size=m).astype(config.floatX)
else:
b_val = rng.normal(size=(m, n)).astype(config.floatX)
eps = None
if config.floatX == "float64":
eps = 2e-8
solve_op = Solve(A_structure=A_structure, lower=lower)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
def test_solve_grad(self):
if not imported_scipy:
raise SkipTest("Scipy needed for the Solve op.")
rng = numpy.random.RandomState(utt.fetch_seed())
structures = ['general', 'lower_triangular', 'upper_triangular']
for A_structure in structures:
lower = (A_structure == 'lower_triangular')
self.verify_solve_grad(5, None, A_structure, lower, rng)
self.verify_solve_grad(6, 1, A_structure, lower, rng)
self.verify_solve_grad(4, 3, A_structure, lower, rng)
def test_expm():
if not imported_scipy:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论