提交 a0653671 authored 作者: Arjun Jain's avatar Arjun Jain

Feature: support for all stride sizes. This makes it work for all subsample (x,…

Feature: support for all stride sizes. This makes GpuCorrMM work for all subsample (x, y) values, and it is compatible with all kernel, batch, and image sizes on the GPU. @nouiz: Can you please run the tests once again, just to be doubly sure? I think it works correctly.
上级 1b680cc3
......@@ -517,9 +517,6 @@ class GpuCorrMM(GpuOp):
if pad != 0:
raise NotImplementedError(
"GpuCorrMM don't implement the pad parameter")
if subsample != (1, 1):
raise NotImplementedError(
"GpuCorrMM we don't implement the subsample parameter")
def __eq__(self, other):
return type(self) == type(other) \
......@@ -662,7 +659,7 @@ class GpuCorrMM(GpuOp):
}
out2 = corrMM(%(img)s, %(kern)s, %(out)s, padH, padW);
out2 = corrMM(%(img)s, %(kern)s, %(out)s, dx, dy, padH, padW);
if (out2==NULL){
%(fail)s
}
......
......@@ -92,7 +92,10 @@ void im2col(const float* data_im, const int channels,
CudaNdarray* corrMM(const CudaNdarray *input,
CudaNdarray *weight,
CudaNdarray *output,
int padH, int padW = 0)
int dH = 1,
int dW = 1,
int padH = 0,
int padW = 0)
{
cublasStatus_t status;
......@@ -107,9 +110,6 @@ CudaNdarray* corrMM(const CudaNdarray *input,
PyErr_SetString(PyExc_ValueError, "required weight of 4D");
}
// TODO: stride(dW, dH) and padding as function parameter
int dH = 1;
int dW = 1;
int kH = CudaNdarray_HOST_DIMS(weight)[2];
int kW = CudaNdarray_HOST_DIMS(weight)[3];
int nInputPlane = CudaNdarray_HOST_DIMS(input)[1];
......
......@@ -644,10 +644,6 @@ def test_valid():
# Add tests with strided inputs by still square images and filters.
shapes += get_shapes2(scales_img=(2, 2), img_stride=(2, 2))
shapes += get_shapes2(scales_kern=(2, 2), kern_stride=(2, 2))
# Keep only tests with square images and filters even with inputs strides
shapes = [shp for shp in shapes if (
shp[0][2]/shp[3][0] == shp[0][3]/shp[3][1] and
shp[1][2]/shp[4][0] == shp[1][3]/shp[4][1])]
exec_conv(version, shapes, verbose, random, 'valid',
print_=print_, ones=ones, rtol=1.1e-5,
theano_mode=mode, cls=cuda.blas.GpuCorrMM)
......@@ -718,8 +714,6 @@ def test_full():
# Test the GpuCorrMM version
mode = theano_mode.including("conv_gemm")
shapes = [shp for shp in shapes if shp[1][2] == shp[1][3]]
shapes = [shp for shp in shapes if shp[0][2] == shp[0][3]]
shapes = shapes[0:10]
exec_conv(version, shapes, verbose, random, 'full',
theano_mode=mode, cls=cuda.blas.GpuCorrMM)
......@@ -841,38 +835,42 @@ def test_gemm():
for rImg2 in range(5, 9):
for rFlt1 in range(2, 4):
for rFlt2 in range(2, 4):
ishape = (bs, ch, rImg1, rImg2)
kshape = (nf, ch, rFlt1, rFlt2)
print "ishape: ", ishape
print "kshape: ", kshape
subsample = (1, 1)
npy_img = theano._asarray(numpy.random.rand(*ishape), dtype='float32')
npy_kern = theano._asarray(numpy.random.rand(*kshape), dtype='float32')
i = cuda_tensor4()
k = cuda_tensor4()
t2 = None
t0 = time.time()
cpuval = py_conv(npy_img, npy_kern, mode, subsample)
t1 = time.time()
op = theano.sandbox.cuda.blas.GpuCorrMM(border_mode=mode)(i, k)
f = theano.function([i, k], op, mode=theano_mode)
npy_kern = npy_kern[:,:,::-1,::-1]
gpuval = f(npy_img, npy_kern)
t2 = time.time()
gpuval = numpy.asarray(gpuval)
rval = numpy.allclose(cpuval, gpuval, rtol=1e-4)
assert (rval == True)
print 'Test Passed'
for subsx in range(1, 3):
for subsy in range(1, 3):
ishape = (bs, ch, rImg1, rImg2)
kshape = (nf, ch, rFlt1, rFlt2)
print "ishape: ", ishape
print "kshape: ", kshape
subsample = (subsx, subsy)
print "subsample: ", subsample
npy_img = theano._asarray(numpy.random.rand(*ishape), dtype='float32')
npy_kern = theano._asarray(numpy.random.rand(*kshape), dtype='float32')
i = cuda_tensor4()
k = cuda_tensor4()
t2 = None
t0 = time.time()
cpuval = py_conv(npy_img, npy_kern, mode, subsample)
t1 = time.time()
op = theano.sandbox.cuda.blas.GpuCorrMM(border_mode=mode, \
subsample=subsample)(i, k)
f = theano.function([i, k], op, mode=theano_mode)
npy_kern = npy_kern[:,:,::-1,::-1]
gpuval = f(npy_img, npy_kern)
t2 = time.time()
gpuval = numpy.asarray(gpuval)
rval = numpy.allclose(cpuval, gpuval, rtol=1e-4)
assert (rval == True)
print 'Test Passed'
def benchmark():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论