提交 1494d16f 作者: Thomas George

updated tests for cusolver solve

上级 90cd2169
...@@ -17,6 +17,7 @@ except (ImportError, OSError, RuntimeError, pkg_resources.DistributionNotFound): ...@@ -17,6 +17,7 @@ except (ImportError, OSError, RuntimeError, pkg_resources.DistributionNotFound):
cusolver_handle = None cusolver_handle = None
class GpuCusolverSolve(Op): class GpuCusolverSolve(Op):
""" """
CUSOLVER GPU solver OP. CUSOLVER GPU solver OP.
...@@ -108,9 +109,9 @@ class GpuCusolverSolve(Op): ...@@ -108,9 +109,9 @@ class GpuCusolverSolve(Op):
ldb = max(1, k, m) ldb = max(1, k, m)
# We copy A and b as cusolver operates inplace # We copy A and b as cusolver operates inplace
b = gpuarray.array(b, copy=True, order='F')
if not self.inplace: if not self.inplace:
A = gpuarray.array(A, copy=True) A = gpuarray.array(A, copy=True)
b = gpuarray.array(b, copy=True, order='F')
A_ptr = A.gpudata A_ptr = A.gpudata
b_ptr = b.gpudata b_ptr = b.gpudata
......
...@@ -16,19 +16,22 @@ if not cusolver_available: ...@@ -16,19 +16,22 @@ if not cusolver_available:
class TestCusolver(unittest.TestCase): class TestCusolver(unittest.TestCase):
def run_gpu_solve(self, A_val, x_val, trans='N'): def run_gpu_solve(self, A_val, x_val):
if trans == 'N': b_val = numpy.dot(A_val, x_val)
b_val = numpy.dot(A_val, x_val) b_val_trans = numpy.dot(A_val.T, x_val)
else:
b_val = numpy.dot(A_val.T, x_val)
A = theano.tensor.matrix("A", dtype="float32") A = theano.tensor.matrix("A", dtype="float32")
b = theano.tensor.matrix("b", dtype="float32") b = theano.tensor.matrix("b", dtype="float32")
b_trans = theano.tensor.matrix("b", dtype="float32")
solver = gpu_solve(A, b, trans) solver = gpu_solve(A, b)
fn = theano.function([A, b], [solver], mode=mode_with_gpu) solver_trans = gpu_solve(A, b_trans, trans='T')
res = fn(A_val, b_val) fn = theano.function([A, b, b_trans], [solver, solver_trans], mode=mode_with_gpu)
res = fn(A_val, b_val, b_val_trans)
x_res = numpy.array(res[0]) x_res = numpy.array(res[0])
x_res_trans = numpy.array(res[1])
utt.assert_allclose(x_res, x_val) utt.assert_allclose(x_res, x_val)
utt.assert_allclose(x_res_trans, x_val)
def test_diag_solve(self): def test_diag_solve(self):
numpy.random.seed(1) numpy.random.seed(1)
...@@ -60,10 +63,3 @@ class TestCusolver(unittest.TestCase): ...@@ -60,10 +63,3 @@ class TestCusolver(unittest.TestCase):
x_val = numpy.random.uniform(-0.4, 0.4, x_val = numpy.random.uniform(-0.4, 0.4,
(A_val.shape[1], 4)).astype("float32") (A_val.shape[1], 4)).astype("float32")
self.run_gpu_solve(A_val, x_val) self.run_gpu_solve(A_val, x_val)
def test_uni_rand_solve_transpose(self):
numpy.random.seed(1)
A_val = numpy.random.uniform(-0.4, 0.4, (5, 5)).astype("float32")
x_val = numpy.random.uniform(-0.4, 0.4,
(A_val.shape[1], 4)).astype("float32")
self.run_gpu_solve(A_val, x_val, trans='T')
Markdown 格式
0%
您即将向此讨论添加 0 人。请谨慎操作。
请先完成此评论的编辑!
注册 或者 登录 后发表评论