提交 4b38388c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Update tests, gpu_gemv is used now, not gemm/dot22

上级 c954a65f
......@@ -9,7 +9,6 @@ if cuda_ndarray.cuda_available == False:
raise SkipTest('Optional package cuda disabled')
import theano.sandbox.cuda as cuda
import theano.sandbox.cuda.blas as blasop
### Tolerance factor used in this tests !!!
atol = 1e-6
......@@ -38,9 +37,9 @@ def test_dot_vm():
assert numpy.allclose(no_gpu_f(), gpu_f(), atol = atol)
assert numpy.allclose(no_gpu_f(), gpu_f2(), atol = atol)
# Assert that the gpu version actually uses gpu
assert sum([isinstance(node.op, blasop.GpuDot22) for node in
assert sum([node.op is cuda.blas.gpu_gemv_inplace for node in
gpu_f.maker.env.toposort() ]) == 1
assert sum([isinstance(node.op, blasop.GpuDot22) for node in
assert sum([node.op is cuda.blas.gpu_gemv_inplace for node in
gpu_f2.maker.env.toposort() ]) == 1
def test_dot_mv():
......@@ -58,9 +57,9 @@ def test_dot_mv():
assert numpy.allclose(no_gpu_f(), gpu_f(), atol = atol)
assert numpy.allclose(no_gpu_f(), gpu_f2(), atol = atol)
# Assert that the gpu version actually uses gpu
assert sum([isinstance(node.op, blasop.GpuDot22) for node in
assert sum([node.op is cuda.blas.gpu_gemv_inplace for node in
gpu_f.maker.env.toposort() ]) == 1
assert sum([isinstance(node.op, blasop.GpuDot22) for node in
assert sum([node.op is cuda.blas.gpu_gemv_inplace for node in
gpu_f2.maker.env.toposort() ]) == 1
def test_gemv1():
......@@ -79,8 +78,10 @@ def test_gemv1():
assert numpy.allclose(no_gpu_f(), gpu_f(), atol = atol)
assert numpy.allclose(no_gpu_f(), gpu_f2(), atol = atol)
# Assert that the gpu version actually uses gpu
assert sum([node.op is cuda.blas.gpu_gemm_inplace for node in gpu_f2.maker.env.toposort()]) == 1
assert sum([node.op is cuda.blas.gpu_gemm_inplace for node in gpu_f.maker.env.toposort()]) == 1
assert sum([node.op is cuda.blas.gpu_gemv_inplace for node in
gpu_f2.maker.env.toposort()]) == 1
assert sum([node.op is cuda.blas.gpu_gemv_inplace for node in
gpu_f.maker.env.toposort()]) == 1
def test_gemv2():
......@@ -99,8 +100,10 @@ def test_gemv2():
assert numpy.allclose(no_gpu_f(), gpu_f(), atol = atol)
assert numpy.allclose(no_gpu_f(), gpu_f2(), atol = atol)
# Assert that the gpu version actually uses gpu
assert sum([node.op is cuda.blas.gpu_gemm_inplace for node in gpu_f2.maker.env.toposort()]) == 1
assert sum([node.op is cuda.blas.gpu_gemm_inplace for node in gpu_f.maker.env.toposort()]) == 1
assert sum([node.op is cuda.blas.gpu_gemv_inplace for node in
gpu_f2.maker.env.toposort()]) == 1
assert sum([node.op is cuda.blas.gpu_gemv_inplace for node in
gpu_f.maker.env.toposort()]) == 1
if __name__=='__main__':
test_dot_vm()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论