提交 80fc28a5 authored 作者: wonghang's avatar wonghang

Add float64 support to opt.py

上级 627b63af
......@@ -66,6 +66,26 @@ if cusolver_available:
ldb, int(devInfo))
cusolver.cusolverCheckStatus(status)
# DPOTRS
# TODO: Are they still missing in skcuda?
# Declare the ctypes prototype for the cusolverDnDpotrs symbol of the loaded
# cuSOLVER library.  The slots are listed in the same order in which the
# wrapper below forwards its arguments.
cusolver._libcusolver.cusolverDnDpotrs.restype = int
cusolver._libcusolver.cusolverDnDpotrs.argtypes = [
    cusolver.ctypes.c_void_p,  # handle
    cusolver.ctypes.c_int,     # uplo
    cusolver.ctypes.c_int,     # n
    cusolver.ctypes.c_int,     # nrhs
    cusolver.ctypes.c_void_p,  # A (passed as an integer address)
    cusolver.ctypes.c_int,     # lda
    cusolver.ctypes.c_void_p,  # B (passed as an integer address)
    cusolver.ctypes.c_int,     # ldb
    cusolver.ctypes.c_void_p,  # devInfo (passed as an integer address)
]


def cusolverDnDpotrs(handle, uplo, n, nrhs, A, lda,
                     B, ldb, devInfo):
    """
    Call the double-precision ``cusolverDnDpotrs`` routine of cuSOLVER.

    ``A``, ``B`` and ``devInfo`` are converted with ``int()`` before being
    handed to the library; every other argument is forwarded unchanged.
    The status code returned by the library is passed to
    ``cusolver.cusolverCheckStatus``.  Nothing is returned.

    NOTE(review): presumably ``A``, ``B`` and ``devInfo`` are device
    pointers and the check raises on a non-success status -- confirm
    against the cuSOLVER / skcuda documentation.
    """
    dpotrs = cusolver._libcusolver.cusolverDnDpotrs
    # The c_void_p slots receive plain integers, hence the int() casts.
    call_args = (handle, uplo, n, nrhs,
                 int(A), lda,
                 int(B), ldb,
                 int(devInfo))
    return_code = dpotrs(*call_args)
    cusolver.cusolverCheckStatus(return_code)
def attach_cusolver_handle_to_context(ctx):
handle = getattr(ctx, 'cusolver_handle', None)
......@@ -389,6 +409,11 @@ def gpu_solve(A, b, A_structure='general', trans='N'):
return GpuCusolverSolve(A_structure, trans)(A, b)
# Added these to make the module consistent with theano/tensor/slinalg.py
def gpu_solve_lower_triangular(A, b):
    """
    Apply ``GpuCublasTriangularSolve(True, 'N')`` to ``(A, b)``.

    Returns whatever the op application returns.

    NOTE(review): presumably the first constructor argument selects the
    lower-triangular case and ``'N'`` means "no transpose" -- confirm
    against ``GpuCublasTriangularSolve.__init__``.
    """
    lower_solve_op = GpuCublasTriangularSolve(True, 'N')
    return lower_solve_op(A, b)
def gpu_solve_upper_triangular(A, b):
    """
    Apply ``GpuCublasTriangularSolve(False, 'N')`` to ``(A, b)``.

    Returns whatever the op application returns.

    NOTE(review): presumably ``False`` as the first constructor argument
    selects the upper-triangular case and ``'N'`` means "no transpose" --
    confirm against ``GpuCublasTriangularSolve.__init__``.
    """
    upper_solve_op = GpuCublasTriangularSolve(False, 'N')
    return upper_solve_op(A, b)
class GpuCholesky(Op):
"""
......
......@@ -2583,7 +2583,7 @@ def local_gpua_images2neibs(op, context_name, inputs, outputs):
@op_lifter([slinalg.Solve])
@register_opt2([theano.tensor.slinalg.Solve], 'fast_compile')
def local_gpu_solve(op, context_name, inputs, outputs):
if inputs[0].dtype not in ['float16', 'float32']:
if inputs[0].dtype not in ['float16', 'float32','float64']:
return
if op.A_structure not in MATRIX_STRUCTURES_SOLVE:
return
......@@ -2617,7 +2617,7 @@ def local_inplace_gpu_solve(node):
def local_gpu_cholesky(op, context_name, inputs, outputs):
if not cusolver_available:
return
if inputs[0].dtype not in ['float16', 'float32']:
if inputs[0].dtype not in ['float16', 'float32', 'float64']:
return
op = GpuCholesky(lower=op.lower, inplace=op.destructive)
if inputs[0].dtype == 'float16':
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论