提交 4e114d97 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Use _allclose for comparison, since it is more lax on float32 data.

上级 05a596ca
...@@ -230,7 +230,7 @@ def test_dot(): ...@@ -230,7 +230,7 @@ def test_dot():
b0 = cuda_ndarray.CudaNdarray(a0) b0 = cuda_ndarray.CudaNdarray(a0)
b1 = cuda_ndarray.CudaNdarray(a1) b1 = cuda_ndarray.CudaNdarray(a1)
assert numpy.allclose(numpy.dot(a0, a1), cuda_ndarray.dot(b0, b1)) assert _allclose(numpy.dot(a0, a1), cuda_ndarray.dot(b0, b1))
a1 = theano._asarray(rng.randn(6, 7), dtype='float32') a1 = theano._asarray(rng.randn(6, 7), dtype='float32')
...@@ -240,7 +240,7 @@ def test_dot(): ...@@ -240,7 +240,7 @@ def test_dot():
transposed = cuda_ndarray.dimshuffle(b1,(1,0)) transposed = cuda_ndarray.dimshuffle(b1,(1,0))
cuda_version = cuda_ndarray.dot(b0, transposed) cuda_version = cuda_ndarray.dot(b0, transposed)
assert numpy.allclose( numpy_version, cuda_version) assert _allclose(numpy_version, cuda_version)
a1 = theano._asarray(rng.randn(7, 6), dtype='float32') a1 = theano._asarray(rng.randn(7, 6), dtype='float32')
b1 = cuda_ndarray.CudaNdarray(a1) b1 = cuda_ndarray.CudaNdarray(a1)
...@@ -249,12 +249,15 @@ def test_dot(): ...@@ -249,12 +249,15 @@ def test_dot():
a0 = theano._asarray(rng.randn(7, 4), dtype='float32') a0 = theano._asarray(rng.randn(7, 4), dtype='float32')
b0 = cuda_ndarray.CudaNdarray(a0) b0 = cuda_ndarray.CudaNdarray(a0)
assert numpy.allclose(numpy.dot(a0.T, a1), cuda_ndarray.dot( cuda_ndarray.dimshuffle(b0,(1,0)), b1)) assert _allclose(numpy.dot(a0.T, a1),
cuda_ndarray.dot(cuda_ndarray.dimshuffle(b0,(1,0)), b1))
a1 = theano._asarray(rng.randn(6, 7), dtype='float32') a1 = theano._asarray(rng.randn(6, 7), dtype='float32')
b1 = cuda_ndarray.CudaNdarray(a1) b1 = cuda_ndarray.CudaNdarray(a1)
assert numpy.allclose(numpy.dot(a0.T, a1.T), cuda_ndarray.dot( cuda_ndarray.dimshuffle(b0,(1,0)), cuda_ndarray.dimshuffle(b1,(1,0) ) ) ) assert _allclose(numpy.dot(a0.T, a1.T),
cuda_ndarray.dot(cuda_ndarray.dimshuffle(b0,(1,0)),
cuda_ndarray.dimshuffle(b1,(1,0))))
def test_sum(): def test_sum():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论