提交 c46a0243 authored 作者: f0k's avatar f0k

Added test_gemm_grads() test to test whether the gradients of GpuCorrMM are correct

上级 e456ce12
...@@ -842,7 +842,6 @@ class TestConv2DGPU(unittest.TestCase): ...@@ -842,7 +842,6 @@ class TestConv2DGPU(unittest.TestCase):
theano_mode = theano_mode_orig theano_mode = theano_mode_orig
def test_gemm_directly(): def test_gemm_directly():
for direction in ['fprop', 'bprop img', 'bprop kern']: for direction in ['fprop', 'bprop img', 'bprop kern']:
print 'Testing direction: ' + direction print 'Testing direction: ' + direction
...@@ -893,6 +892,56 @@ def test_gemm_directly(): ...@@ -893,6 +892,56 @@ def test_gemm_directly():
print "subsample: ", subsample print "subsample: ", subsample
assert False assert False
def test_gemm_grads():
    """Compare GpuCorrMM's output and gradients against conv2d.

    For every combination of border mode ('valid', 'full'), batch size,
    channel count, filter count, image size, filter size and subsampling
    factor, one Theano function is compiled that returns:

    * the GpuCorrMM output and the gradients of its sum with respect to
      the image and the kernel, and
    * the same three quantities for ``tensor.nnet.conv2d`` applied to the
      kernel reversed along its last two axes.

    The two triples are compared pairwise for identical shape and for
    closeness (``numpy.allclose`` with ``rtol=1e-4``).

    Raises:
        AssertionError: on the first mismatch; the message names the
            failing direction and the configuration that produced it.
    """
    # Imported locally because the file's top-level import block is not
    # part of this hunk.
    import itertools

    directions = ('fprop', 'bprop img', 'bprop kern')

    # itertools.product varies the rightmost iterable fastest, which is the
    # same visiting order as nesting one for-loop per axis in this order.
    configs = itertools.product(
        ('valid', 'full'),  # mode
        range(1, 5),        # bs
        range(1, 4),        # ch
        range(1, 4),        # nf
        range(5, 9),        # rImg1
        range(5, 9),        # rImg2
        range(2, 4),        # rFlt1
        range(2, 4),        # rFlt2
        range(1, 3),        # subsx
        range(1, 3))        # subsy

    for (mode, bs, ch, nf, rImg1, rImg2,
         rFlt1, rFlt2, subsx, subsy) in configs:
        ishape = (bs, ch, rImg1, rImg2)
        kshape = (nf, ch, rFlt1, rFlt2)
        subsample = (subsx, subsy)
        npy_img = theano._asarray(numpy.random.rand(*ishape),
                                  dtype='float32')
        npy_kern = theano._asarray(numpy.random.rand(*kshape),
                                   dtype='float32')
        i = cuda_tensor4()
        k = cuda_tensor4()
        # GpuCorrMM is always built in 'valid' mode here; the 'full' case
        # is requested through the pad argument instead.
        # NOTE(review): presumably pad='auto' yields full-mode output
        # shapes -- confirm against GpuCorrMM's documentation.
        pad = 'auto' if mode == 'full' else (0, 0)
        # TODO: also test custom pad values
        corr_op = theano.sandbox.cuda.blas.GpuCorrMM(
            'valid', subsample, pad)(i, k)
        # The reference op receives the kernel reversed along both filter
        # axes (presumably because GpuCorrMM correlates while conv2d
        # convolves -- confirm).
        conv_op = tensor.nnet.conv2d(i, k[:, :, ::-1, ::-1],
                                     ishape, kshape, mode, subsample)
        f = theano.function([i, k],
                            [corr_op,
                             theano.grad(corr_op.sum(), i),
                             theano.grad(corr_op.sum(), k),
                             conv_op,
                             theano.grad(conv_op.sum(), i),
                             theano.grad(conv_op.sum(), k)],
                            mode=theano_mode)
        allvals = f(npy_img, npy_kern)
        # First three outputs belong to GpuCorrMM, last three to conv2d.
        for a, b, p in zip(allvals[:3], allvals[3:], directions):
            if (a.shape != b.shape) or not numpy.allclose(a, b, rtol=1e-4):
                # Raise explicitly instead of `assert False` so the failure
                # survives `python -O` and carries its own diagnostics.
                raise AssertionError(
                    "Test failed for %s (mode: %s, ishape: %r, "
                    "kshape: %r, subsample: %r)"
                    % (p, mode, ishape, kshape, subsample))
def benchmark(): def benchmark():
shapes_valid = [ shapes_valid = [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论