提交 1494d16f 作者: Thomas George

updated tests for cusolver solve

上级 90cd2169
......@@ -17,6 +17,7 @@ except (ImportError, OSError, RuntimeError, pkg_resources.DistributionNotFound):
cusolver_handle = None
class GpuCusolverSolve(Op):
"""
CUSOLVER GPU solver OP.
......@@ -108,9 +109,9 @@ class GpuCusolverSolve(Op):
ldb = max(1, k, m)
# We copy A and b as cusolver operates inplace
b = gpuarray.array(b, copy=True, order='F')
if not self.inplace:
A = gpuarray.array(A, copy=True)
b = gpuarray.array(b, copy=True, order='F')
A_ptr = A.gpudata
b_ptr = b.gpudata
......
......@@ -16,19 +16,22 @@ if not cusolver_available:
class TestCusolver(unittest.TestCase):
def run_gpu_solve(self, A_val, x_val, trans='N'):
if trans == 'N':
b_val = numpy.dot(A_val, x_val)
else:
b_val = numpy.dot(A_val.T, x_val)
def run_gpu_solve(self, A_val, x_val):
b_val = numpy.dot(A_val, x_val)
b_val_trans = numpy.dot(A_val.T, x_val)
A = theano.tensor.matrix("A", dtype="float32")
b = theano.tensor.matrix("b", dtype="float32")
b_trans = theano.tensor.matrix("b", dtype="float32")
solver = gpu_solve(A, b, trans)
fn = theano.function([A, b], [solver], mode=mode_with_gpu)
res = fn(A_val, b_val)
solver = gpu_solve(A, b)
solver_trans = gpu_solve(A, b_trans, trans='T')
fn = theano.function([A, b, b_trans], [solver, solver_trans], mode=mode_with_gpu)
res = fn(A_val, b_val, b_val_trans)
x_res = numpy.array(res[0])
x_res_trans = numpy.array(res[1])
utt.assert_allclose(x_res, x_val)
utt.assert_allclose(x_res_trans, x_val)
def test_diag_solve(self):
numpy.random.seed(1)
......@@ -60,10 +63,3 @@ class TestCusolver(unittest.TestCase):
x_val = numpy.random.uniform(-0.4, 0.4,
(A_val.shape[1], 4)).astype("float32")
self.run_gpu_solve(A_val, x_val)
def test_uni_rand_solve_transpose(self):
numpy.random.seed(1)
A_val = numpy.random.uniform(-0.4, 0.4, (5, 5)).astype("float32")
x_val = numpy.random.uniform(-0.4, 0.4,
(A_val.shape[1], 4)).astype("float32")
self.run_gpu_solve(A_val, x_val, trans='T')
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论