提交 ef2dfc15 authored 作者: Frederic's avatar Frederic

pep8

上级 60c64280
......@@ -635,7 +635,8 @@ def test_valid(conv_gemm=False):
# Test the GpuCorrMM version
mode = theano_mode.including("conv_gemm")
cls = cuda.blas.GpuCorrMM
version = [-1] # dummy version; not used by GpuCorrMM so one version is enough
# dummy version; not used by GpuCorrMM so one version is enough
version = [-1]
# Add tests with strided inputs by still square images and filters.
shapes += get_shapes2(scales_img=(2, 2), img_stride=(2, 2))
shapes += get_shapes2(scales_kern=(2, 2), kern_stride=(2, 2))
......@@ -645,6 +646,7 @@ def test_valid(conv_gemm=False):
print_=print_, ones=ones, rtol=1.1e-5,
theano_mode=mode, cls=cls)
def test_gemm_valid():
test_valid(conv_gemm=True)
......@@ -712,12 +714,14 @@ def test_full(conv_gemm=False):
# Test the GpuCorrMM version
mode = theano_mode.including("conv_gemm")
cls = cuda.blas.GpuCorrMM
version = [-1] # dummy version; not used by GpuCorrMM so one version is enough
# dummy version; not used by GpuCorrMM so one version is enough
version = [-1]
else:
mode = cls = None
exec_conv(version, shapes, verbose, random, 'full',
theano_mode=mode, cls=cls)
def test_gemm_full():
test_full(conv_gemm=True)
......@@ -735,7 +739,8 @@ def test_subsample(conv_gemm=False):
shapes += get_shapes2(scales_img=(2, 2), subsample=(2, 1))
shapes += get_shapes2(scales_img=(2, 2), subsample=(2, 2))
#We put only the version that implement the subsample to make the test faster.
# We put only the version that implement the subsample to make the
# test faster.
version_valid = [-2, -1, 1, 3, 11, 12]
version_full = [-2, -1]
verbose = 0
......@@ -749,7 +754,8 @@ def test_subsample(conv_gemm=False):
# Test the GpuCorrMM version
mode = theano_mode.including("conv_gemm")
cls = cuda.blas.GpuCorrMM
version_valid = version_full = [-1] # dummy version; not used by GpuCorrMM so one version is enough
# dummy version; not used by GpuCorrMM so one version is enough
version_valid = version_full = [-1]
else:
mode = cls = None
......@@ -760,6 +766,7 @@ def test_subsample(conv_gemm=False):
print_=print_, ones=ones,
theano_mode=mode, cls=cls)
def test_gemm_subsample():
test_subsample(conv_gemm=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论