提交 4a1a016d authored 作者: wonghang's avatar wonghang

Add L_op to GpuCublasTriangularSolve

上级 53799c53
...@@ -290,7 +290,6 @@ class GpuCusolverSolve(Op): ...@@ -290,7 +290,6 @@ class GpuCusolverSolve(Op):
# no need to handle A_structure like slinalg.py? # no need to handle A_structure like slinalg.py?
trans_solve_op = GpuCusolverSolve('general') trans_solve_op = GpuCusolverSolve('general')
b_bar = trans_solve_op(A.T, c_bar) b_bar = trans_solve_op(A.T, c_bar)
# force outer product if vector second input
A_bar = -tensor.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T) A_bar = -tensor.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
return [A_bar, b_bar] return [A_bar, b_bar]
...@@ -414,6 +413,24 @@ class GpuCublasTriangularSolve(Op): ...@@ -414,6 +413,24 @@ class GpuCublasTriangularSolve(Op):
x[0] = b x[0] = b
def L_op(self, inputs, outputs, output_gradients):
"""
Modified from theano/tensor/slinalg.py
"""
A, b = inputs
c = outputs[0]
c_bar = output_gradients[0]
trans_solve_op = GpuCublasTriangularSolve(not self.lower)
b_bar = trans_solve_op(A.T, c_bar)
A_bar = -tensor.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
if self.lower:
A_bar = tensor.tril(A_bar)
else:
A_bar = tensor.triu(A_bar)
return [A_bar, b_bar]
def gpu_solve(A, b, A_structure='general', trans='N'): def gpu_solve(A, b, A_structure='general', trans='N'):
if A_structure == 'lower': if A_structure == 'lower':
return GpuCublasTriangularSolve(True, trans)(A, b) return GpuCublasTriangularSolve(True, trans)(A, b)
......
...@@ -140,7 +140,13 @@ class TestCusolver(unittest.TestCase): ...@@ -140,7 +140,13 @@ class TestCusolver(unittest.TestCase):
eps = None eps = None
if config.floatX == "float64": if config.floatX == "float64":
eps = 2e-8 eps = 2e-8
solve_op = GpuCusolverSolve(A_structure=A_structure)
if A_structure == 'lower_triangular':
solve_op = GpuCublasTriangularSolve(lower=False)
elif A_structure == 'lower_triangular':
solve_op = GpuCublasTriangularSolve(lower=True)
else:
solve_op = GpuCusolverSolve(A_structure="general")
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps) utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
def test_solve_grad(self): def test_solve_grad(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论