提交 5a75915c authored 作者: Alexander Matyasko's avatar Alexander Matyasko

Add tests for magma cholesky

上级 8a005114
......@@ -9,7 +9,7 @@ import theano
from theano import config
from theano.gpuarray.linalg import (GpuCholesky, GpuMagmaMatrixInverse,
cusolver_available, gpu_matrix_inverse,
gpu_solve, gpu_svd)
gpu_solve, gpu_svd, GpuMagmaCholesky)
from theano.tensor.nlinalg import matrix_inverse
from theano.tests import unittest_tools as utt
......@@ -216,7 +216,7 @@ class TestMagma(unittest.TestCase):
# Copied from theano.tensor.tests.test_basic.rand.
A_val = test_rng.rand(N, N).astype('float32') * 2 - 1
A_val_inv = fn(A_val)
utt.assert_allclose(np.eye(N), np.dot(A_val_inv, A_val), atol=5e-3)
utt.assert_allclose(np.dot(A_val_inv, A_val), np.eye(N), atol=1e-2)
def test_gpu_matrix_inverse_inplace(self):
N = 1000
......@@ -296,3 +296,45 @@ class TestMagma(unittest.TestCase):
A_val = rand(100, 50).astype('float32')
utt.assert_allclose(f_cpu(A_val), f_gpu(A_val))
def run_gpu_cholesky(self, A_val, lower=True):
A = theano.tensor.fmatrix("A")
f = theano.function([A], GpuMagmaCholesky(lower=lower)(A),
mode=mode_with_gpu)
return f(A_val)
def check_cholesky(self, A, L, lower=True, rtol=None, atol=None):
if not lower:
L = L.T
utt.assert_allclose(np.dot(L, L.T), A, rtol=rtol, atol=atol)
def test_gpu_cholesky(self):
N = 1000
A = rand(N, N).astype('float32')
A = np.dot(A.T, A)
L = self.run_gpu_cholesky(A)
self.check_cholesky(A, L, atol=1e-3)
L = self.run_gpu_cholesky(A, lower=False)
self.check_cholesky(A, L, lower=False, atol=1e-3)
def test_gpu_cholesky_inplace(self):
N = 1000
A = rand(N, N).astype('float32')
A = np.dot(A.T, A)
A_gpu = gpuarray_shared_constructor(A)
A_copy = A_gpu.get_value()
fn = theano.function([], GpuMagmaCholesky(inplace=True)(A_gpu),
mode=mode_with_gpu, accept_inplace=True)
fn()
L = A_gpu.get_value()
self.check_cholesky(A_copy, L, atol=1e-3)
def test_gpu_cholesky_inplace_opt(self):
A = theano.tensor.fmatrix("A")
fn = theano.function([A], GpuMagmaCholesky()(A), mode=mode_with_gpu)
assert any([
node.op.inplace
for node in fn.maker.fgraph.toposort() if
isinstance(node.op, GpuMagmaCholesky)
])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论