提交 13f5c2b0 authored 作者: Thomas George's avatar Thomas George

- Added cholesky solve for symmetric matrices + tests

- Added a test in solve to check whether the LU factorization or the cholesky decomposition succeeded, otherwise a LinAlgError is raised (similar to scipy's solve) + tests
上级 d9d2d0b5
...@@ -6,6 +6,9 @@ import theano ...@@ -6,6 +6,9 @@ import theano
from theano import Op from theano import Op
from theano.gpuarray import basic_ops, GpuArrayType from theano.gpuarray import basic_ops, GpuArrayType
import numpy
from numpy.linalg.linalg import LinAlgError
try: try:
import pygpu import pygpu
except ImportError: except ImportError:
...@@ -18,6 +21,34 @@ try: ...@@ -18,6 +21,34 @@ try:
except (ImportError, OSError, RuntimeError, pkg_resources.DistributionNotFound): except (ImportError, OSError, RuntimeError, pkg_resources.DistributionNotFound):
pass pass
if cusolver_available:
# Add cusolver call as it is missing in skcuda
# SPOTRS
cusolver._libcusolver.cusolverDnSpotrs.restype = int
cusolver._libcusolver.cusolverDnSpotrs.argtypes = [cusolver.ctypes.c_void_p,
cusolver.ctypes.c_int,
cusolver.ctypes.c_int,
cusolver.ctypes.c_int,
cusolver.ctypes.c_void_p,
cusolver.ctypes.c_int,
cusolver.ctypes.c_void_p,
cusolver.ctypes.c_int,
cusolver.ctypes.c_void_p]
def cusolverDnSpotrs(handle, uplo, n, nrhs, A, lda,
B, ldb, devInfo):
"""
Solve real single precision linear system for hermitian matrices.
References
----------
`cusolverDn<t>potrs <http://docs.nvidia.com/cuda/cusolver/index.html#cuds-lt-t-gt-potrs>`_
"""
status = cusolver._libcusolver.cusolverDnSpotrs(handle, uplo, n, nrhs,
int(A), lda, int(B),
ldb, int(devInfo))
cusolver.cusolverCheckStatus(status)
class GpuCusolverSolve(Op): class GpuCusolverSolve(Op):
""" """
...@@ -30,11 +61,12 @@ class GpuCusolverSolve(Op): ...@@ -30,11 +61,12 @@ class GpuCusolverSolve(Op):
""" """
__props__ = ('trans', 'inplace') __props__ = ('A_structure', 'trans', 'inplace')
def __init__(self, trans='N', inplace=False): def __init__(self, A_structure='general', trans='N', inplace=False):
self.trans = trans self.trans = trans
self.inplace = inplace self.inplace = inplace
self.A_structure = A_structure
if self.inplace: if self.inplace:
self.destroy_map = {0: [0, 1]} self.destroy_map = {0: [0, 1]}
super(GpuCusolverSolve, self).__init__() super(GpuCusolverSolve, self).__init__()
...@@ -70,6 +102,11 @@ class GpuCusolverSolve(Op): ...@@ -70,6 +102,11 @@ class GpuCusolverSolve(Op):
with ctx: with ctx:
ctx.cusolver_handle = cusolver.cusolverDnCreate() ctx.cusolver_handle = cusolver.cusolverDnCreate()
def check_dev_info(self, dev_info):
val = numpy.asarray(dev_info)[0]
if val > 0:
raise LinAlgError('A is singular')
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
context = inputs[0][0].context context = inputs[0][0].context
...@@ -116,32 +153,58 @@ class GpuCusolverSolve(Op): ...@@ -116,32 +153,58 @@ class GpuCusolverSolve(Op):
if A.flags['C_CONTIGUOUS']: if A.flags['C_CONTIGUOUS']:
trans = 1 - trans trans = 1 - trans
with context: if self.A_structure == 'symmetric':
workspace_size = cusolver.cusolverDnSgetrf_bufferSize( with context:
context.cusolver_handle, n, n, A_ptr, lda) workspace_size = cusolver.cusolverDnSpotrf_bufferSize(
context.cusolver_handle, 0, n, A_ptr, lda)
workspace = pygpu.zeros(workspace_size, dtype='float32',
context=context)
dev_info = pygpu.zeros((1,), dtype='int32', context=context)
workspace_ptr = workspace.gpudata
dev_info_ptr = dev_info.gpudata
with context:
cusolver.cusolverDnSpotrf(
context.cusolver_handle, 0, n, A_ptr, lda, workspace_ptr,
workspace_size, dev_info_ptr)
self.check_dev_info(dev_info)
cusolverDnSpotrs(
context.cusolver_handle, 0, n, m, A_ptr, lda,
b_ptr, ldb, dev_info_ptr)
else:
# general case for A
with context:
workspace_size = cusolver.cusolverDnSgetrf_bufferSize(
context.cusolver_handle, n, n, A_ptr, lda)
workspace = pygpu.zeros(workspace_size, dtype='float32', workspace = pygpu.zeros(workspace_size, dtype='float32',
context=context) context=context)
pivots = pygpu.zeros(n, dtype='int32', context=context) pivots = pygpu.zeros(n, dtype='int32', context=context)
dev_info = pygpu.zeros((1,), dtype='int32', context=context) dev_info = pygpu.zeros((1,), dtype='int32', context=context)
workspace_ptr = workspace.gpudata workspace_ptr = workspace.gpudata
pivots_ptr = pivots.gpudata pivots_ptr = pivots.gpudata
dev_info_ptr = dev_info.gpudata dev_info_ptr = dev_info.gpudata
with context: with context:
cusolver.cusolverDnSgetrf( cusolver.cusolverDnSgetrf(
context.cusolver_handle, n, n, A_ptr, lda, workspace_ptr, context.cusolver_handle, n, n, A_ptr, lda, workspace_ptr,
pivots_ptr, dev_info_ptr) pivots_ptr, dev_info_ptr)
self.check_dev_info(dev_info)
cusolver.cusolverDnSgetrs( cusolver.cusolverDnSgetrs(
context.cusolver_handle, trans, n, m, A_ptr, lda, context.cusolver_handle, trans, n, m, A_ptr, lda,
pivots_ptr, b_ptr, ldb, dev_info_ptr) pivots_ptr, b_ptr, ldb, dev_info_ptr)
z[0] = b z[0] = b
def gpu_solve(A, b, trans='N'): def gpu_solve(A, b, A_structure='general', trans='N'):
return GpuCusolverSolve(trans)(A, b) return GpuCusolverSolve(A_structure, trans)(A, b)
...@@ -7,6 +7,8 @@ import theano ...@@ -7,6 +7,8 @@ import theano
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from .config import mode_with_gpu from .config import mode_with_gpu
from numpy.linalg.linalg import LinAlgError
# Skip tests if cuda_ndarray is not available. # Skip tests if cuda_ndarray is not available.
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from theano.gpuarray.linalg import (cusolver_available, gpu_solve) from theano.gpuarray.linalg import (cusolver_available, gpu_solve)
...@@ -16,7 +18,7 @@ if not cusolver_available: ...@@ -16,7 +18,7 @@ if not cusolver_available:
class TestCusolver(unittest.TestCase): class TestCusolver(unittest.TestCase):
def run_gpu_solve(self, A_val, x_val): def run_gpu_solve(self, A_val, x_val, A_struct=None):
b_val = numpy.dot(A_val, x_val) b_val = numpy.dot(A_val, x_val)
b_val_trans = numpy.dot(A_val.T, x_val) b_val_trans = numpy.dot(A_val.T, x_val)
...@@ -24,14 +26,19 @@ class TestCusolver(unittest.TestCase): ...@@ -24,14 +26,19 @@ class TestCusolver(unittest.TestCase):
b = theano.tensor.matrix("b", dtype="float32") b = theano.tensor.matrix("b", dtype="float32")
b_trans = theano.tensor.matrix("b", dtype="float32") b_trans = theano.tensor.matrix("b", dtype="float32")
solver = gpu_solve(A, b) if A_struct is None:
solver_trans = gpu_solve(A, b_trans, trans='T') solver = gpu_solve(A, b)
solver_trans = gpu_solve(A, b_trans, trans='T')
else:
solver = gpu_solve(A, b, A_struct)
solver_trans = gpu_solve(A, b_trans, A_struct, trans='T')
fn = theano.function([A, b, b_trans], [solver, solver_trans], mode=mode_with_gpu) fn = theano.function([A, b, b_trans], [solver, solver_trans], mode=mode_with_gpu)
res = fn(A_val, b_val, b_val_trans) res = fn(A_val, b_val, b_val_trans)
x_res = numpy.array(res[0]) x_res = numpy.array(res[0])
x_res_trans = numpy.array(res[1]) x_res_trans = numpy.array(res[1])
utt.assert_allclose(x_res, x_val) utt.assert_allclose(x_val, x_res)
utt.assert_allclose(x_res_trans, x_val) utt.assert_allclose(x_val, x_res_trans)
def test_diag_solve(self): def test_diag_solve(self):
numpy.random.seed(1) numpy.random.seed(1)
...@@ -55,10 +62,10 @@ class TestCusolver(unittest.TestCase): ...@@ -55,10 +62,10 @@ class TestCusolver(unittest.TestCase):
def test_sym_solve(self): def test_sym_solve(self):
numpy.random.seed(1) numpy.random.seed(1)
A_val = numpy.random.uniform(-0.4, 0.4, (5, 5)).astype("float32") A_val = numpy.random.uniform(-0.4, 0.4, (5, 5)).astype("float32")
A_sym = (A_val + A_val.T) / 2.0 A_sym = numpy.dot(A_val, A_val.T)
x_val = numpy.random.uniform(-0.4, 0.4, (A_val.shape[1], x_val = numpy.random.uniform(-0.4, 0.4, (A_val.shape[1],
1)).astype("float32") 1)).astype("float32")
self.run_gpu_solve(A_sym, x_val) self.run_gpu_solve(A_sym, x_val, 'symmetric')
def test_orth_solve(self): def test_orth_solve(self):
numpy.random.seed(1) numpy.random.seed(1)
...@@ -74,3 +81,34 @@ class TestCusolver(unittest.TestCase): ...@@ -74,3 +81,34 @@ class TestCusolver(unittest.TestCase):
x_val = numpy.random.uniform(-0.4, 0.4, x_val = numpy.random.uniform(-0.4, 0.4,
(A_val.shape[1], 4)).astype("float32") (A_val.shape[1], 4)).astype("float32")
self.run_gpu_solve(A_val, x_val) self.run_gpu_solve(A_val, x_val)
def test_linalgerrsym_solve(self):
numpy.random.seed(1)
A_val = numpy.random.uniform(-0.4, 0.4, (5, 5)).astype("float32")
x_val = numpy.random.uniform(-0.4, 0.4,
(A_val.shape[1], 4)).astype("float32")
A_val = numpy.dot(A_val.T, A_val)
# make A singular
A_val[:, 2] = A_val[:, 1] + A_val[:, 3]
A = theano.tensor.matrix("A", dtype="float32")
b = theano.tensor.matrix("b", dtype="float32")
solver = gpu_solve(A, b, 'symmetric')
fn = theano.function([A, b], [solver], mode=mode_with_gpu)
self.assertRaises(LinAlgError, fn, A_val, x_val)
def test_linalgerr_solve(self):
numpy.random.seed(1)
A_val = numpy.random.uniform(-0.4, 0.4, (5, 5)).astype("float32")
x_val = numpy.random.uniform(-0.4, 0.4,
(A_val.shape[1], 4)).astype("float32")
# make A singular
A_val[:, 2] = 0
A = theano.tensor.matrix("A", dtype="float32")
b = theano.tensor.matrix("b", dtype="float32")
solver = gpu_solve(A, b, trans='T')
fn = theano.function([A, b], [solver], mode=mode_with_gpu)
self.assertRaises(LinAlgError, fn, A_val, x_val)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论