提交 3a0dcdbc authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix last few comments about the tests.

上级 b90abd33
...@@ -195,7 +195,7 @@ def _params_allgood(ishape, kshape, mode, subsample=(1, 1), img_stride=(1, 1), ...@@ -195,7 +195,7 @@ def _params_allgood(ishape, kshape, mode, subsample=(1, 1), img_stride=(1, 1),
assert (numpy.asarray(gpuval) == numpy.asarray(gpuval2)).all() assert (numpy.asarray(gpuval) == numpy.asarray(gpuval2)).all()
gpuval = numpy.asarray(gpuval) gpuval = numpy.asarray(gpuval)
assert gpuval.shape == cpuval.shape, ("shape mismatch", gpuval.shape, cpuval.shape) assert gpuval.shape == cpuval.shape, ("shape mismatch", gpuval.shape, cpuval.shape)
assert_allclose(cpuval, gpuval, rtol=rtol) assert_allclose(cpuval, gpuval, rtol=rtol, atol=atol)
assert numpy.all(numpy.isfinite(gpuval)), gpuval assert numpy.all(numpy.isfinite(gpuval)), gpuval
if (t2 is not None): if (t2 is not None):
...@@ -852,10 +852,9 @@ def dnn_op(mode, subsample): ...@@ -852,10 +852,9 @@ def dnn_op(mode, subsample):
return f return f
def conv_grad(mode, bs, ch, nf, rImg1, rImg2, rFlt1, rFlt2, subsx, subsy, op): def conv_grad(mode, bs, ch, nf, rImg1, rImg2, rFlt1, rFlt2, subsample, op):
ishape = (bs, ch, rImg1, rImg2) ishape = (bs, ch, rImg1, rImg2)
kshape = (nf, ch, rFlt1, rFlt2) kshape = (nf, ch, rFlt1, rFlt2)
subsample = (subsx, subsy)
npy_img = theano._asarray(numpy.random.rand(*ishape), dtype='float32') npy_img = theano._asarray(numpy.random.rand(*ishape), dtype='float32')
npy_kern = theano._asarray(numpy.random.rand(*kshape), dtype='float32') npy_kern = theano._asarray(numpy.random.rand(*kshape), dtype='float32')
...@@ -913,16 +912,11 @@ def test_conv_grads(): ...@@ -913,16 +912,11 @@ def test_conv_grads():
for rImg2 in [2, 8]: for rImg2 in [2, 8]:
for rFlt1 in [1, 2]: for rFlt1 in [1, 2]:
for rFlt2 in [1, 2]: for rFlt2 in [1, 2]:
for op in [gemm_op, dnn_op]: for subsample in (1, 1), (1, 2), (2, 2):
yield (conv_grad, mode, bs, ch, nf, for op in [gemm_op, dnn_op]:
rImg1, rImg2, rFlt1, rFlt2, yield (conv_grad, mode, bs, ch, nf,
1, 1, op) rImg1, rImg2, rFlt1, rFlt2,
yield (conv_grad, mode, bs, ch, nf, subsample, op)
rImg1, rImg2, rFlt1, rFlt2,
1, 2, op)
yield (conv_grad, mode, bs, ch, nf,
rImg1, rImg2, rFlt1, rFlt2,
2, 2, op)
def benchmark(): def benchmark():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论