提交 ace491ba authored 作者: Thomas George

op lifter for cholesky

上级 bc6ceb89
...@@ -337,3 +337,7 @@ class GpuCholesky(Op): ...@@ -337,3 +337,7 @@ class GpuCholesky(Op):
triu(L) triu(L)
outputs[0][0] = L outputs[0][0] = L
def gpu_cholesky(A, lower=True):
    """Compute the Cholesky decomposition of `A` on the GPU.

    Thin convenience wrapper: builds a `GpuCholesky` op for the requested
    triangle (`lower=True` selects the lower-triangular factor) and applies
    it to `A`, returning the resulting symbolic variable.
    """
    op = GpuCholesky(lower)
    return op(A)
...@@ -70,7 +70,7 @@ from .subtensor import (GpuIncSubtensor, GpuSubtensor, ...@@ -70,7 +70,7 @@ from .subtensor import (GpuIncSubtensor, GpuSubtensor,
GpuAdvancedIncSubtensor1_dev20) GpuAdvancedIncSubtensor1_dev20)
from .opt_util import alpha_merge, output_merge, pad_dims, unpad_dims from .opt_util import alpha_merge, output_merge, pad_dims, unpad_dims
from .reduction import GpuMaxAndArgmax from .reduction import GpuMaxAndArgmax
from .linalg import (GpuCusolverSolve, GpuCholesky, cusolver_available)
_logger = logging.getLogger("theano.gpuarray.opt") _logger = logging.getLogger("theano.gpuarray.opt")
...@@ -1967,6 +1967,16 @@ def local_gpu_solve(op, context_name, inputs, outputs): ...@@ -1967,6 +1967,16 @@ def local_gpu_solve(op, context_name, inputs, outputs):
return return
return GpuCusolverSolve() return GpuCusolverSolve()
# Cholesky decomposition
@register_opt('fast_compile')
@op_lifter([slinalg.Cholesky])
@register_opt2([theano.tensor.slinalg.Cholesky], 'fast_compile')
def local_gpu_cholesky(op, context_name, inputs, outputs):
    """Lift a CPU `slinalg.Cholesky` node to `GpuCholesky`.

    Returns None (no replacement) when cuSolver is unavailable, so the
    graph keeps the CPU op.
    """
    if not cusolver_available:
        return
    # Preserve the CPU op's `lower` flag: returning GpuCholesky() with its
    # default would silently flip lower=False graphs to the lower factor
    # (GpuCholesky takes `lower` as its constructor argument, cf.
    # gpu_cholesky above).
    return GpuCholesky(lower=op.lower)
# Do not register in fast_run or fast_compile. # Do not register in fast_run or fast_compile.
# It will be added to fast_run if the GPU is enabled. # It will be added to fast_run if the GPU is enabled.
optdb.register('gpua_scanOp_make_inplace', optdb.register('gpua_scanOp_make_inplace',
......
...@@ -158,9 +158,9 @@ class TestGpuCholesky(unittest.TestCase): ...@@ -158,9 +158,9 @@ class TestGpuCholesky(unittest.TestCase):
def test_diag_chol(self): def test_diag_chol(self):
# Diagonal matrix input Cholesky test. # Diagonal matrix input Cholesky test.
# make sure all diagonal elements are positive so positive-definite
for lower in [True, False]: for lower in [True, False]:
for inplace in [True, False]: for inplace in [True, False]:
# make sure all diagonal elements are positive so positive-definite
A_val = np.diag(np.random.uniform(size=5).astype("float32") + 1) A_val = np.diag(np.random.uniform(size=5).astype("float32") + 1)
self.compare_gpu_cholesky_to_np(A_val, lower=lower, inplace=inplace) self.compare_gpu_cholesky_to_np(A_val, lower=lower, inplace=inplace)
......
...@@ -17,7 +17,7 @@ from ..basic_ops import ( ...@@ -17,7 +17,7 @@ from ..basic_ops import (
from ..blas import GpuGemm from ..blas import GpuGemm
from ..elemwise import GpuCAReduceCuda, GpuCAReduceCPY, GpuElemwise from ..elemwise import GpuCAReduceCuda, GpuCAReduceCPY, GpuElemwise
from ..subtensor import GpuSubtensor from ..subtensor import GpuSubtensor
from ..linalg import GpuCusolverSolve, cusolver_available, GpuCholesky
from .config import mode_with_gpu, mode_without_gpu, test_ctx_name, SkipTest from .config import mode_with_gpu, mode_without_gpu, test_ctx_name, SkipTest
...@@ -584,6 +584,23 @@ def test_local_lift_solve(): ...@@ -584,6 +584,23 @@ def test_local_lift_solve():
utt.assert_allclose(f_cpu(A_val, b_val), f_gpu(A_val, b_val)) utt.assert_allclose(f_cpu(A_val, b_val), f_gpu(A_val, b_val))
def test_local_lift_cholesky():
    # The lifter only fires when cuSolver is present.
    if not cusolver_available:
        raise SkipTest('No cuSolver')
    A = tensor.fmatrix()
    o = slinalg.cholesky(A)
    f_cpu = theano.function([A], o)
    f_gpu = theano.function([A], o, mode=mode_with_gpu)
    # After lifting, the GPU graph must contain GpuCholesky and no CPU
    # Cholesky node.
    gpu_nodes = f_gpu.maker.fgraph.apply_nodes
    assert not any(isinstance(node.op, slinalg.Cholesky)
                   for node in gpu_nodes)
    assert any(isinstance(node.op, GpuCholesky)
               for node in gpu_nodes)
    M_val = np.random.normal(size=(3, 3)).astype("float32")
    # M.dot(M.T) is positive definite for any non-singular M.
    A_val = M_val.dot(M_val.T)
    utt.assert_allclose(f_cpu(A_val), f_gpu(A_val))
def test_local_gpua_advanced_incsubtensor(): def test_local_gpua_advanced_incsubtensor():
# test a corner case reported at gh-5589 # test a corner case reported at gh-5589
target = tensor.ftensor4() target = tensor.ftensor4()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论