提交 e2ee93fc authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #6060 from notoraptor/fix-magma-buildbot

(small fix) Increase error tolerance for magma matrix_inverse tests.
...@@ -212,18 +212,21 @@ class TestMagma(unittest.TestCase): ...@@ -212,18 +212,21 @@ class TestMagma(unittest.TestCase):
fn = theano.function([A], gpu_matrix_inverse(A), mode=mode_with_gpu) fn = theano.function([A], gpu_matrix_inverse(A), mode=mode_with_gpu)
N = 1000 N = 1000
A_val = rand(N, N).astype('float32') test_rng = np.random.RandomState(seed=1)
# Copied from theano.tensor.tests.test_basic.rand.
A_val = test_rng.rand(N, N).astype('float32') * 2 - 1
A_val_inv = fn(A_val) A_val_inv = fn(A_val)
utt.assert_allclose(np.dot(A_val_inv, A_val), np.eye(N), atol=1e-3) utt.assert_allclose(np.eye(N), np.dot(A_val_inv, A_val), atol=5e-3)
def test_gpu_matrix_inverse_inplace(self): def test_gpu_matrix_inverse_inplace(self):
N = 1000 N = 1000
A_val_gpu = gpuarray_shared_constructor(rand(N, N).astype('float32')) test_rng = np.random.RandomState(seed=1)
A_val_gpu = gpuarray_shared_constructor(test_rng.rand(N, N).astype('float32') * 2 - 1)
A_val_copy = A_val_gpu.get_value() A_val_copy = A_val_gpu.get_value()
fn = theano.function([], GpuMagmaMatrixInverse(inplace=True)(A_val_gpu), fn = theano.function([], GpuMagmaMatrixInverse(inplace=True)(A_val_gpu),
mode=mode_with_gpu, accept_inplace=True) mode=mode_with_gpu, accept_inplace=True)
fn() fn()
utt.assert_allclose(np.dot(A_val_gpu.get_value(), A_val_copy), np.eye(N), atol=1e-3) utt.assert_allclose(np.eye(N), np.dot(A_val_gpu.get_value(), A_val_copy), atol=5e-3)
def test_gpu_matrix_inverse_inplace_opt(self): def test_gpu_matrix_inverse_inplace_opt(self):
A = theano.tensor.fmatrix("A") A = theano.tensor.fmatrix("A")
......
...@@ -352,9 +352,9 @@ class WrongValue(Exception): ...@@ -352,9 +352,9 @@ class WrongValue(Exception):
return s + str_diagnostic(self.val1, self.val2, self.rtol, self.atol) return s + str_diagnostic(self.val1, self.val2, self.rtol, self.atol)
def assert_allclose(val1, val2, rtol=None, atol=None): def assert_allclose(expected, value, rtol=None, atol=None):
if not T.basic._allclose(val1, val2, rtol, atol): if not T.basic._allclose(expected, value, rtol, atol):
raise WrongValue(val1, val2, rtol, atol) raise WrongValue(expected, value, rtol, atol)
class AttemptManyTimes: class AttemptManyTimes:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论