提交 3b57cbf3 authored 作者: Frederic Bastien's avatar Frederic Bastien

many modif to make test_conv.py:TestConvOp.test_ConvOpGrad run in 20m instead of 4h in debug mode.

上级 f9a67241
......@@ -434,13 +434,13 @@ class TestConvOp(unittest.TestCase):
print ' TEST ConvOp.grad'
print '*************************************************'
nkern = 4
bsize = 3
nkern = 3
bsize = 2
types = ["float32", "float64"]
kshps = [(3,4)]
imshps = [(2,8,7)]
kshps = [(2,3)]
imshps = [(2,3,4)]
modes = ['valid', 'full']
unroll = [(0,0),(1,1),(1,4),(3,1),(3,4)]
unroll = [(0,0),(1,1),(2,3)]
ssizes = [(1,1),(2,2)]
for typ in types:
......@@ -449,18 +449,16 @@ class TestConvOp(unittest.TestCase):
for mode in modes:
for imshp in imshps:
visdim = 1 if len(imshp)!=3 else imshp[0]
imgvals = N.array(N.random.random(N.hstack((bsize,imshp))),dtype=imgs.dtype)
for kshp in kshps:
t=numpy.array([imshp[1]-kshp[0],imshp[2]-kshp[1]])
kernvals = N.array(N.random.rand(nkern,visdim,kshp[0],
kshp[1]),dtype=kerns.dtype)
# 'full' mode should support kernels bigger than the input
if mode == 'valid' and (t<0).any():
continue
for un_b,un_k in unroll:
for ss in ssizes:
imgvals = N.array(N.random.random(N.hstack((bsize,imshp))),dtype=imgs.dtype)
kernvals = N.array(N.random.rand(nkern,visdim,kshp[0],
kshp[1]),dtype=kerns.dtype)
print 'test_ConvOpGrad'
print 'mode type:', mode, typ
print 'imshp:', imshp
......@@ -472,19 +470,15 @@ class TestConvOp(unittest.TestCase):
print 'nkern:', 4
def test_i(imgs):
out, outshp = convolve2(kernvals, kshp, nkern,
imgs, imshp, bsize,
mode=mode, step=ss,
unroll_batch=un_b,
unroll_kern=un_k)
return out
convop = ConvOp(imshp, kshp, nkern, bsize, ss[0], ss[1],
output_mode=mode, unroll_batch=un_b, unroll_kern=un_k)
return convop(imgs, kernvals)
def test_k(kerns):
out, outshp = convolve2(kerns, kshp, nkern,
imgvals, imshp, bsize,
mode=mode, step=ss,
unroll_batch=un_b,
unroll_kern=un_k)
return out
convop = ConvOp(imshp, kshp, nkern, bsize, ss[0], ss[1],
output_mode=mode, unroll_batch=un_b, unroll_kern=un_k)
return convop(imgvals, kerns)
#TODO the tolerance needed to pass is very high for float32(0.17). Is this acceptable? Expected?
utt.verify_grad(test_i, [imgvals],
cast_to_output_type=True,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论