提交 b1175d1e authored 作者: Alexander Matyasko's avatar Alexander Matyasko

Add float16 support for cholesky, qr and eigh

上级 2906e215
...@@ -2149,7 +2149,12 @@ def local_inplace_gpu_cholesky(node): ...@@ -2149,7 +2149,12 @@ def local_inplace_gpu_cholesky(node):
def local_gpu_magma_cholesky(op, context_name, inputs, outputs): def local_gpu_magma_cholesky(op, context_name, inputs, outputs):
if not config.magma.enabled: if not config.magma.enabled:
return return
return GpuMagmaCholesky(lower=op.lower, inplace=op.destructive) if inputs[0].dtype not in ['float16', 'float32']:
return
op = GpuMagmaCholesky(lower=op.lower, inplace=op.destructive)
if inputs[0].dtype == 'float16':
return op(inputs[0].astype('float32')).astype('float16')
return op
lifter = op_lifter([slinalg.Cholesky])(local_gpu_magma_cholesky) lifter = op_lifter([slinalg.Cholesky])(local_gpu_magma_cholesky)
matrix_ops_db.register("local_gpu_magma_cholesky", lifter, matrix_ops_db.register("local_gpu_magma_cholesky", lifter,
'gpuarray', 'fast_compile', 'fast_run', 'magma', 'gpuarray', 'fast_compile', 'fast_run', 'magma',
...@@ -2174,7 +2179,12 @@ def local_inplace_gpu_magma_cholesky(node): ...@@ -2174,7 +2179,12 @@ def local_inplace_gpu_magma_cholesky(node):
def local_gpu_magma_qr(op, context_name, inputs, outputs): def local_gpu_magma_qr(op, context_name, inputs, outputs):
if not config.magma.enabled or op.mode != 'reduced': if not config.magma.enabled or op.mode != 'reduced':
return return
return GpuMagmaQR() if inputs[0].dtype not in ['float16', 'float32']:
return
op = GpuMagmaQR(complete=True)
if inputs[0].dtype == 'float16':
return op(inputs[0].astype('float32')).astype('float16')
return op
@register_opt('magma', 'fast_compile') @register_opt('magma', 'fast_compile')
...@@ -2183,7 +2193,12 @@ def local_gpu_magma_qr(op, context_name, inputs, outputs): ...@@ -2183,7 +2193,12 @@ def local_gpu_magma_qr(op, context_name, inputs, outputs):
def local_gpu_magma_qr_incomplete(op, context_name, inputs, outputs): def local_gpu_magma_qr_incomplete(op, context_name, inputs, outputs):
if not config.magma.enabled: if not config.magma.enabled:
return return
return GpuMagmaQR(complete=False) if inputs[0].dtype not in ['float16', 'float32']:
return
op = GpuMagmaQR(complete=False)
if inputs[0].dtype == 'float16':
return op(inputs[0].astype('float32')).astype('float16')
return op
# Matrix inverse # Matrix inverse
...@@ -2215,7 +2230,12 @@ def local_inplace_gpu_magma_matrix_inverse(node): ...@@ -2215,7 +2230,12 @@ def local_inplace_gpu_magma_matrix_inverse(node):
def local_gpu_magma_eigh(op, context_name, inputs, outputs): def local_gpu_magma_eigh(op, context_name, inputs, outputs):
if not config.magma.enabled: if not config.magma.enabled:
return return
return GpuMagmaEigh(UPLO=op.UPLO, compute_v=True) if inputs[0].dtype not in ['float16', 'float32']:
return
op = GpuMagmaEigh(UPLO=op.UPLO, compute_v=True)
if inputs[0].dtype == 'float16':
return op(inputs[0].astype('float32')).astype('float16')
return op
# Singular Value Decomposition # Singular Value Decomposition
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论