Flake8 format

上级 6b8db8e8
...@@ -74,12 +74,14 @@ def attach_cusolver_handle_to_context(ctx): ...@@ -74,12 +74,14 @@ def attach_cusolver_handle_to_context(ctx):
with ctx: with ctx:
ctx.cusolver_handle = cusolver.cusolverDnCreate() ctx.cusolver_handle = cusolver.cusolverDnCreate()
def attach_cublas_handle_to_context(ctx): def attach_cublas_handle_to_context(ctx):
handle = getattr(ctx, 'cublas_handle', None) handle = getattr(ctx, 'cublas_handle', None)
if handle is None: if handle is None:
with ctx: with ctx:
ctx.cublas_handle = cublas.cublasCreate() ctx.cublas_handle = cublas.cublasCreate()
# it is a subset of all cases available in slinalg's MATRIX_STRUCTURE # it is a subset of all cases available in slinalg's MATRIX_STRUCTURE
MATRIX_STRUCTURES_SOLVE = ( MATRIX_STRUCTURES_SOLVE = (
'general', 'general',
...@@ -243,6 +245,7 @@ class GpuCusolverSolve(Op): ...@@ -243,6 +245,7 @@ class GpuCusolverSolve(Op):
z[0] = b z[0] = b
class GpuCublasTriangularSolve(Op): class GpuCublasTriangularSolve(Op):
""" """
CUBLAS GPU Triangular Solve Op. CUBLAS GPU Triangular Solve Op.
...@@ -357,6 +360,7 @@ class GpuCublasTriangularSolve(Op): ...@@ -357,6 +360,7 @@ class GpuCublasTriangularSolve(Op):
x[0] = b x[0] = b
def gpu_solve(A, b, A_structure='general', trans='N'): def gpu_solve(A, b, A_structure='general', trans='N'):
if A_structure == 'lower': if A_structure == 'lower':
return GpuCublasTriangularSolve(True, trans)(A, b) return GpuCublasTriangularSolve(True, trans)(A, b)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论