提交 3a0dcdbc authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix last few comments about the tests.

上级 b90abd33
......@@ -195,7 +195,7 @@ def _params_allgood(ishape, kshape, mode, subsample=(1, 1), img_stride=(1, 1),
assert (numpy.asarray(gpuval) == numpy.asarray(gpuval2)).all()
gpuval = numpy.asarray(gpuval)
assert gpuval.shape == cpuval.shape, ("shape mismatch", gpuval.shape, cpuval.shape)
assert_allclose(cpuval, gpuval, rtol=rtol)
assert_allclose(cpuval, gpuval, rtol=rtol, atol=atol)
assert numpy.all(numpy.isfinite(gpuval)), gpuval
if (t2 is not None):
......@@ -852,10 +852,9 @@ def dnn_op(mode, subsample):
return f
def conv_grad(mode, bs, ch, nf, rImg1, rImg2, rFlt1, rFlt2, subsx, subsy, op):
def conv_grad(mode, bs, ch, nf, rImg1, rImg2, rFlt1, rFlt2, subsample, op):
ishape = (bs, ch, rImg1, rImg2)
kshape = (nf, ch, rFlt1, rFlt2)
subsample = (subsx, subsy)
npy_img = theano._asarray(numpy.random.rand(*ishape), dtype='float32')
npy_kern = theano._asarray(numpy.random.rand(*kshape), dtype='float32')
......@@ -913,16 +912,11 @@ def test_conv_grads():
for rImg2 in [2, 8]:
for rFlt1 in [1, 2]:
for rFlt2 in [1, 2]:
for op in [gemm_op, dnn_op]:
yield (conv_grad, mode, bs, ch, nf,
rImg1, rImg2, rFlt1, rFlt2,
1, 1, op)
yield (conv_grad, mode, bs, ch, nf,
rImg1, rImg2, rFlt1, rFlt2,
1, 2, op)
yield (conv_grad, mode, bs, ch, nf,
rImg1, rImg2, rFlt1, rFlt2,
2, 2, op)
for subsample in (1, 1), (1, 2), (2, 2):
for op in [gemm_op, dnn_op]:
yield (conv_grad, mode, bs, ch, nf,
rImg1, rImg2, rFlt1, rFlt2,
subsample, op)
def benchmark():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论