提交 8a9caf2e authored 作者: Alexander Matyasko's avatar Alexander Matyasko

Add gpu magma qr optimizations

上级 e8b2f369
......@@ -74,7 +74,8 @@ from .subtensor import (GpuIncSubtensor, GpuSubtensor,
from .opt_util import alpha_merge, output_merge, pad_dims, unpad_dims
from .reduction import GpuMaxAndArgmax
from .linalg import (GpuCusolverSolve, MATRIX_STRUCTURES_SOLVE, GpuCholesky,
cusolver_available, GpuMagmaMatrixInverse, gpu_svd)
cusolver_available, GpuMagmaMatrixInverse, GpuMagmaSVD,
GpuMagmaCholesky, GpuMagmaQR)
_logger = logging.getLogger("theano.gpuarray.opt")
......@@ -2135,7 +2136,7 @@ register_opt2([slinalg.Solve], 'fast_compile', name='matrix_ops_db2')(matrix_ops
@register_inplace()
@local_optimizer([GpuCholesky], inplace=True)
def local_inplace_cholesky(node):
def local_inplace_gpu_cholesky(node):
if isinstance(node.op, GpuCholesky) and not node.op.inplace:
return [node.op.clone_inplace()(*node.inputs)]
......@@ -2161,11 +2162,30 @@ def local_inplace_gpu_magma_cholesky(node):
return [node.op.clone_inplace()(*node.inputs)]
# QR decomposition
@register_opt('magma', 'fast_compile')
@op_lifter([nlinalg.QRFull])
@register_opt2([theano.tensor.nlinalg.QRFull], 'magma', 'fast_compile')
def local_gpu_magma_qr(op, context_name, inputs, outputs):
if not config.magma.enabled or op.mode != 'reduced':
return
return GpuMagmaQR()
@register_opt('magma', 'fast_compile')
@op_lifter([nlinalg.QRIncomplete])
@register_opt2([theano.tensor.nlinalg.QRIncomplete], 'magma', 'fast_compile')
def local_gpu_magma_qr_incomplete(op, context_name, inputs, outputs):
if not config.magma.enabled:
return
return GpuMagmaQR(complete=False)
# Matrix inverse
@register_opt('magma', 'fast_compile')
@op_lifter([nlinalg.MatrixInverse])
@register_opt2([theano.tensor.nlinalg.MatrixInverse], 'magma', 'fast_compile')
def local_gpu_matrix_inverse(op, context_name, inputs, outputs):
def local_gpu_magma_matrix_inverse(op, context_name, inputs, outputs):
if not config.magma.enabled:
return
if inputs[0].dtype not in ['float16', 'float32']:
......@@ -2178,7 +2198,7 @@ def local_gpu_matrix_inverse(op, context_name, inputs, outputs):
@register_inplace()
@local_optimizer([GpuMagmaMatrixInverse])
def local_inplace_gpu_matrix_inverse(node):
def local_inplace_gpu_magma_matrix_inverse(node):
if isinstance(node.op, GpuMagmaMatrixInverse) and not node.op.inplace:
return [node.op.clone_inplace()(*node.inputs)]
......@@ -2187,7 +2207,7 @@ def local_inplace_gpu_matrix_inverse(node):
@register_opt('magma', 'fast_compile')
@op_lifter([nlinalg.SVD])
@register_opt2([theano.tensor.nlinalg.SVD], 'magma', 'fast_compile')
def local_gpu_svd(op, context_name, inputs, outputs):
def local_gpu_magma_svd(op, context_name, inputs, outputs):
if not config.magma.enabled:
return
if inputs[0].dtype not in ['float16', 'float32']:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论