提交 53799c53 作者: wonghang

Add L_op for GpuCusolverSolve

上级 b1f09ae4
...@@ -286,31 +286,12 @@ class GpuCusolverSolve(Op): ...@@ -286,31 +286,12 @@ class GpuCusolverSolve(Op):
A, b = inputs A, b = inputs
c = outputs[0] c = outputs[0]
c_bar = output_gradients[0] c_bar = output_gradients[0]
trans_map = { # FIXME: triangular structure would use GpuCublasTriangularsolve?
'lower_triangular': 'upper_triangular', # no need to handle A_structure like slinalg.py?
'upper_triangular': 'lower_triangular', trans_solve_op = GpuCusolverSolve('general')
}
trans_map2 = {
'N': 'T',
'T': 'N',
}
# if self.A_structure == 'lower_triangular':
# trans_solve_op = GpuCublasTriangularSolve(lower=False)
# elif self.A_structure == 'upper_triangular':
# trans_solve_op = GpuCublasTriangularSolve(lower=True)
# else:
trans_solve_op = GpuCusolverSolve(
# update A_structure and lower to account for a transpose operation
A_structure=trans_map.get(self.A_structure,self.A_structure),
# trans=trans_map2[self.trans],
)
b_bar = trans_solve_op(A.T, c_bar) b_bar = trans_solve_op(A.T, c_bar)
# force outer product if vector second input # force outer product if vector second input
A_bar = -tensor.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T) A_bar = -tensor.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
if self.A_structure == 'lower_triangular':
A_bar = tensor.tril(A_bar)
elif self.A_structure == 'upper_triangular':
A_bar = tensor.triu(A_bar)
return [A_bar, b_bar] return [A_bar, b_bar]
class GpuCublasTriangularSolve(Op): class GpuCublasTriangularSolve(Op):
......
...@@ -145,8 +145,7 @@ class TestCusolver(unittest.TestCase): ...@@ -145,8 +145,7 @@ class TestCusolver(unittest.TestCase):
def test_solve_grad(self): def test_solve_grad(self):
rng = np.random.RandomState(utt.fetch_seed()) rng = np.random.RandomState(utt.fetch_seed())
# structures = ['general', 'lower_triangular', 'upper_triangular'] structures = ['general', 'lower_triangular', 'upper_triangular']
structures = ['general']
for A_structure in structures: for A_structure in structures:
lower = (A_structure == 'lower_triangular') lower = (A_structure == 'lower_triangular')
# self.verify_solve_grad(5, None, A_structure, lower, rng) # self.verify_solve_grad(5, None, A_structure, lower, rng)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论