提交 ef2dfc15 authored 作者: Frederic's avatar Frederic

pep8

上级 60c64280
...@@ -635,7 +635,8 @@ def test_valid(conv_gemm=False): ...@@ -635,7 +635,8 @@ def test_valid(conv_gemm=False):
# Test the GpuCorrMM version # Test the GpuCorrMM version
mode = theano_mode.including("conv_gemm") mode = theano_mode.including("conv_gemm")
cls = cuda.blas.GpuCorrMM cls = cuda.blas.GpuCorrMM
version = [-1] # dummy version; not used by GpuCorrMM so one version is enough # dummy version; not used by GpuCorrMM so one version is enough
version = [-1]
# Add tests with strided inputs by still square images and filters. # Add tests with strided inputs by still square images and filters.
shapes += get_shapes2(scales_img=(2, 2), img_stride=(2, 2)) shapes += get_shapes2(scales_img=(2, 2), img_stride=(2, 2))
shapes += get_shapes2(scales_kern=(2, 2), kern_stride=(2, 2)) shapes += get_shapes2(scales_kern=(2, 2), kern_stride=(2, 2))
...@@ -645,6 +646,7 @@ def test_valid(conv_gemm=False): ...@@ -645,6 +646,7 @@ def test_valid(conv_gemm=False):
print_=print_, ones=ones, rtol=1.1e-5, print_=print_, ones=ones, rtol=1.1e-5,
theano_mode=mode, cls=cls) theano_mode=mode, cls=cls)
def test_gemm_valid(): def test_gemm_valid():
test_valid(conv_gemm=True) test_valid(conv_gemm=True)
...@@ -712,12 +714,14 @@ def test_full(conv_gemm=False): ...@@ -712,12 +714,14 @@ def test_full(conv_gemm=False):
# Test the GpuCorrMM version # Test the GpuCorrMM version
mode = theano_mode.including("conv_gemm") mode = theano_mode.including("conv_gemm")
cls = cuda.blas.GpuCorrMM cls = cuda.blas.GpuCorrMM
version = [-1] # dummy version; not used by GpuCorrMM so one version is enough # dummy version; not used by GpuCorrMM so one version is enough
version = [-1]
else: else:
mode = cls = None mode = cls = None
exec_conv(version, shapes, verbose, random, 'full', exec_conv(version, shapes, verbose, random, 'full',
theano_mode=mode, cls=cls) theano_mode=mode, cls=cls)
def test_gemm_full(): def test_gemm_full():
test_full(conv_gemm=True) test_full(conv_gemm=True)
...@@ -735,7 +739,8 @@ def test_subsample(conv_gemm=False): ...@@ -735,7 +739,8 @@ def test_subsample(conv_gemm=False):
shapes += get_shapes2(scales_img=(2, 2), subsample=(2, 1)) shapes += get_shapes2(scales_img=(2, 2), subsample=(2, 1))
shapes += get_shapes2(scales_img=(2, 2), subsample=(2, 2)) shapes += get_shapes2(scales_img=(2, 2), subsample=(2, 2))
#We put only the version that implement the subsample to make the test faster. # We put only the version that implement the subsample to make the
# test faster.
version_valid = [-2, -1, 1, 3, 11, 12] version_valid = [-2, -1, 1, 3, 11, 12]
version_full = [-2, -1] version_full = [-2, -1]
verbose = 0 verbose = 0
...@@ -749,7 +754,8 @@ def test_subsample(conv_gemm=False): ...@@ -749,7 +754,8 @@ def test_subsample(conv_gemm=False):
# Test the GpuCorrMM version # Test the GpuCorrMM version
mode = theano_mode.including("conv_gemm") mode = theano_mode.including("conv_gemm")
cls = cuda.blas.GpuCorrMM cls = cuda.blas.GpuCorrMM
version_valid = version_full = [-1] # dummy version; not used by GpuCorrMM so one version is enough # dummy version; not used by GpuCorrMM so one version is enough
version_valid = version_full = [-1]
else: else:
mode = cls = None mode = cls = None
...@@ -760,6 +766,7 @@ def test_subsample(conv_gemm=False): ...@@ -760,6 +766,7 @@ def test_subsample(conv_gemm=False):
print_=print_, ones=ones, print_=print_, ones=ones,
theano_mode=mode, cls=cls) theano_mode=mode, cls=cls)
def test_gemm_subsample(): def test_gemm_subsample():
test_subsample(conv_gemm=True) test_subsample(conv_gemm=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论