提交 77a62e79 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add tests for blas and fix an missing import in opt.py.

上级 dbeedc15
...@@ -15,7 +15,7 @@ from theano.sandbox.gpuarray.basic_ops import (host_from_gpu, gpu_from_host, ...@@ -15,7 +15,7 @@ from theano.sandbox.gpuarray.basic_ops import (host_from_gpu, gpu_from_host,
from theano.sandbox.gpuarray.elemwise import (GpuElemwise, _is_scalar, from theano.sandbox.gpuarray.elemwise import (GpuElemwise, _is_scalar,
GpuDimShuffle, GpuCAReduce) GpuDimShuffle, GpuCAReduce)
from theano.sandbox.gpuarray.subtensor import GpuSubtensor from theano.sandbox.gpuarray.subtensor import GpuSubtensor
from theano.sandbox.gpuarray.blas import GpuGemv from theano.sandbox.gpuarray.blas import GpuGemv, GpuGemm
gpu_optimizer = EquilibriumDB() gpu_optimizer = EquilibriumDB()
gpu_cut_copies = EquilibriumDB() gpu_cut_copies = EquilibriumDB()
......
from unittest import TestCase
from theano.tensor.blas import gemv_inplace, gemm_inplace
from theano.sandbox.gpuarray.tests.test_basic_ops import makeTester, rand
from theano.sandbox.gpuarray.blas import (gpugemv_inplace,
gpugemm_inplace)
GpuGemvTester = makeTester('GpuGemvTester',
op=gemv_inplace, gpu_op=gpugemv_inplace,
cases=dict(
dot_vv=[rand(1), 1, rand(1, 2), rand(2), 0],
dot_vm=[rand(3), 1, rand(3, 2), rand(2), 0],
# test_02=[rand(0), 1, rand(0, 2), rand(2), 0],
# test_30=[rand(3), 1, rand(3, 0), rand(0), 0],
# test_00=[rand(0), 1, rand(0, 0), rand(0), 0],
test_stride=[rand(3)[::-1], 1, rand(3, 2)[::-1], rand(2)[::-1], 0],
)
)
GpuGemmTester = makeTester('GpuGemmTester',
op=gemm_inplace, gpu_op=gpugemm_inplace,
cases=dict(
test1=[rand(3, 4), 1.0, rand(3, 5), rand(5, 4), 0.0],
test2=[rand(3, 4), 1.0, rand(3, 5), rand(5, 4), 1.0],
test3=[rand(3, 4), 1.0, rand(3, 5), rand(5, 4), -1.0],
test4=[rand(3, 4), 0.0, rand(3, 5), rand(5, 4), 0.0],
test5=[rand(3, 4), 0.0, rand(3, 5), rand(5, 4), 0.6],
test6=[rand(3, 4), 0.0, rand(3, 5), rand(5, 4), -1.0],
test7=[rand(3, 4), -1.0, rand(3, 5), rand(5, 4), 0.0],
test8=[rand(3, 4), -1.0, rand(3, 5), rand(5, 4), 1.0],
test9=[rand(3, 4), -1.0, rand(3, 5), rand(5, 4), -1.0],
)
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论