提交 1f9cc65c authored 作者: Alexander Matyasko's avatar Alexander Matyasko

Create local optimization group for matrix ops

Matrix ops optimization group is created. Magma by default take precedence during optimization. If magma is not available, cusolver cholesky decomposition is used.
上级 5a75915c
......@@ -2108,9 +2108,6 @@ def local_inplace_gpu_solve(node):
# Cholesky decomposition
@register_opt('fast_compile')
@op_lifter([slinalg.Cholesky])
@register_opt2([theano.tensor.slinalg.Cholesky], 'fast_compile')
def local_gpu_cholesky(op, context_name, inputs, outputs):
if not cusolver_available:
return
......@@ -2121,6 +2118,19 @@ def local_gpu_cholesky(op, context_name, inputs, outputs):
return op(inputs[0].astype('float32')).astype('float16')
return op
matrix_ops_db = LocalGroupDB()
matrix_ops_db2 = LocalGroupDB(local_opt=theano.gof.opt.GraphToGPULocalOptGroup)
matrix_ops_db2.__name__ = "matrix_ops_db2"
lifter = op_lifter([slinalg.Cholesky])(local_gpu_cholesky)
matrix_ops_db.register("local_gpu_cholesky", lifter,
'gpuarray', 'fast_compile', 'fast_run',
position=1)
matrix_ops_db2.register("local_gpu_cholesky",
local_optimizer([slinalg.Cholesky])(local_gpu_cholesky),
'gpuarray', 'fast_compile', 'fast_run',
position=1)
register_opt('fast_compile', name='matrix_ops_db')(matrix_ops_db)
register_opt2([slinalg.Solve], 'fast_compile', name='matrix_ops_db2')(matrix_ops_db2)
@register_inplace()
......@@ -2130,13 +2140,18 @@ def local_inplace_cholesky(node):
return [node.op.clone_inplace()(*node.inputs)]
@register_opt('magma', 'fast_compile')
@op_lifter([slinalg.cholesky, GpuCholesky])
@register_opt2([slinalg.Cholesky, GpuCholesky], 'magma', 'fast_compile')
def local_gpu_magma_cholesky(op, context_name, inputs, outputs):
if not config.magma.enabled:
return
return GpuMagmaCholesky(lower=op.lower, inplace=op.destructive)
lifter = op_lifter([slinalg.Cholesky])(local_gpu_magma_cholesky)
matrix_ops_db.register("local_gpu_magma_cholesky", lifter,
'gpuarray', 'fast_compile', 'fast_run', 'magma',
position=0)
matrix_ops_db2.register("local_gpu_magma_cholesky",
local_optimizer([slinalg.Cholesky])(local_gpu_magma_cholesky),
'gpuarray', 'fast_compile', 'fast_run', 'magma',
position=0)
@register_inplace()
......
......@@ -11,6 +11,7 @@ from theano.gpuarray.linalg import (GpuCholesky, GpuMagmaMatrixInverse,
cusolver_available, gpu_matrix_inverse,
gpu_solve, gpu_svd, GpuMagmaCholesky)
from theano.tensor.nlinalg import matrix_inverse
from theano.tensor.slinalg import cholesky
from theano.tests import unittest_tools as utt
from .. import gpuarray_shared_constructor
......@@ -132,7 +133,8 @@ class TestGpuCholesky(unittest.TestCase):
A = theano.tensor.matrix("A", dtype="float32")
cholesky_op = GpuCholesky(lower=lower, inplace=inplace)
chol_A = cholesky_op(A)
return theano.function([A], chol_A, accept_inplace=inplace, mode=mode_with_gpu)
return theano.function([A], chol_A, accept_inplace=inplace,
mode=mode_with_gpu.excluding('magma'))
def compare_gpu_cholesky_to_np(self, A_val, lower=True, inplace=False):
# Helper function to compare op output to np.cholesky output.
......@@ -144,6 +146,12 @@ class TestGpuCholesky(unittest.TestCase):
chol_A_res = np.array(res)
utt.assert_allclose(chol_A_res, chol_A_val)
def test_gpu_cholesky_opt(self):
A = theano.tensor.matrix("A", dtype="float32")
fn = theano.function([A], cholesky(A), mode=mode_with_gpu.excluding('magma'))
assert any([isinstance(node.op, GpuCholesky)
for node in fn.maker.fgraph.toposort()])
def test_invalid_input_fail_non_square(self):
# Invalid Cholesky input test with non-square matrix as input.
A_val = np.random.normal(size=(3, 2)).astype("float32")
......@@ -318,6 +326,12 @@ class TestMagma(unittest.TestCase):
L = self.run_gpu_cholesky(A, lower=False)
self.check_cholesky(A, L, lower=False, atol=1e-3)
def test_gpu_cholesky_opt(self):
A = theano.tensor.matrix("A", dtype="float32")
fn = theano.function([A], cholesky(A), mode=mode_with_gpu)
assert any([isinstance(node.op, GpuMagmaCholesky)
for node in fn.maker.fgraph.toposort()])
def test_gpu_cholesky_inplace(self):
N = 1000
A = rand(N, N).astype('float32')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论