提交 8e72e6c0 authored 作者: Frederic Bastien's avatar Frederic Bastien

fix test with gpu conv op.

上级 b8665e9f
import sys, time
import numpy
import theano
# Skip test if cuda_ndarray is not available.
from nose.plugins.skip import SkipTest
import theano.sandbox.cuda as cuda_ndarray
if cuda_ndarray.enable_cuda == False:
raise SkipTest('Optional package cuda disabled')
cuda_tensor4 = cuda_ndarray.CudaNdarrayType([False]*4)
def py_conv_valid_numpy(img, kern):
assert img.shape[1] == kern.shape[1]
outshp = (img.shape[0], kern.shape[0],
......@@ -76,15 +80,19 @@ def _params_allgood(ishape, kshape, mode, subsample=(1,1), img_stride=(1,1), ker
t2 = None
rval = True
try:
if True:
t0 = time.time()
cpuval = py_conv_scipy(npy_img, npy_kern, mode, subsample)
t1 = time.time()
gpuval = cuda_ndarray.conv(img, kern, mode, subsample=subsample,
version=version, verbose=verbose)
i = cuda_tensor4()
k = cuda_tensor4()
op = theano.sandbox.cuda.blas.GpuConv(border_mode=mode,subsample=subsample, version=version, verbose=verbose)(i,k)
# imshp=ishape[1:],kshp=kshape[2:],nkern=kshape[0],bsize=ishape[0])
f=theano.function([i,k],op)
gpuval = f(img,kern)
t2 = time.time()
for i in range(nb_iter):
gpuval2 = cuda_ndarray.conv(img, kern, mode, subsample=subsample,
version=version, verbose=0)
gpuval2 = f(img,kern)
assert numpy.allclose(numpy.asarray(gpuval),numpy.asarray(gpuval2))
assert (numpy.asarray(gpuval)==numpy.asarray(gpuval2)).all()
gpuval = numpy.asarray(gpuval)
......@@ -105,6 +113,7 @@ def _params_allgood(ishape, kshape, mode, subsample=(1,1), img_stride=(1,1), ker
approx_fp /= 1e6
cpu_mflops = approx_fp / (t1-t0)
gpu_mflops = approx_fp / (t2-t1)
if verbose>0:
print >> sys.stdout, '%15s'% str(ishape), '%15s'% str(kshape),
print >> sys.stdout, '%12.5f %7.2f %7.2f %7.1f' % (approx_fp,
cpu_mflops, gpu_mflops,(t1-t0)/(t2-t1))
......@@ -136,6 +145,7 @@ def _params_allgood(ishape, kshape, mode, subsample=(1,1), img_stride=(1,1), ker
return rval
def exec_conv(version, shapes, verbose, random, mode, print_=None, rtol=1e-5, ones=False):
if verbose>0:
_params_allgood_header()
nb_failed = 0
nb_tests = 0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论