Flake8 format

上级 6b8db8e8
......@@ -74,12 +74,14 @@ def attach_cusolver_handle_to_context(ctx):
with ctx:
ctx.cusolver_handle = cusolver.cusolverDnCreate()
def attach_cublas_handle_to_context(ctx):
handle = getattr(ctx, 'cublas_handle', None)
if handle is None:
with ctx:
ctx.cublas_handle = cublas.cublasCreate()
# it is a subset of all cases available in slinalg's MATRIX_STRUCTURE
MATRIX_STRUCTURES_SOLVE = (
'general',
......@@ -243,6 +245,7 @@ class GpuCusolverSolve(Op):
z[0] = b
class GpuCublasTriangularSolve(Op):
"""
CUBLAS GPU Triangular Solve Op.
......@@ -357,6 +360,7 @@ class GpuCublasTriangularSolve(Op):
x[0] = b
def gpu_solve(A, b, A_structure='general', trans='N'):
if A_structure == 'lower':
return GpuCublasTriangularSolve(True, trans)(A, b)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论