提交 60fa63e1 authored 作者: notoraptor's avatar notoraptor

Better handling for tests that should fail.

Fix test cases generators to avoid cases that are intended to fail. Fix messages and properly exit from CUDA in cuDNN conv ops.
上级 2aff6d77
......@@ -348,7 +348,7 @@ class CuDNNV51ConvCaseGenerator(object):
def _fwd_fft(self, ndim):
inputs_sizes = [(10,) * ndim,
(248, 5) + (2,) * (ndim - 2)]
(240, 5) + (2,) * (ndim - 2)]
filters_sizes = [tuple(range(9, 9 - ndim, -1))]
subsamples = [(1,) * ndim]
return ConvCaseGenerator(ndim=ndim,
......@@ -357,7 +357,7 @@ class CuDNNV51ConvCaseGenerator(object):
subsamples=subsamples,
dilations=self._dilations(ndim))
def _fwd_fft_tiling(self, ndim):
def _fwd_fft_tiling(self, ndim, dtype, precision):
if ndim == 2:
filters_sizes = [(32, 5)]
if ndim == 3:
......@@ -395,8 +395,8 @@ class CuDNNV51ConvCaseGenerator(object):
def _gi_fft(self, ndim):
return self._fwd_fft(ndim)
def _gi_fft_tiling(self, ndim):
return self._fwd_fft_tiling(ndim)
def _gi_fft_tiling(self, ndim, dtype, precision):
return self._fwd_fft_tiling(ndim, dtype, precision)
def _gi_winograd(self, ndim):
return self._fwd_winograd(ndim)
......@@ -406,29 +406,29 @@ class CuDNNV51ConvCaseGenerator(object):
# Public interface.
def fwd(self, algo, ndim):
def fwd(self, algo, ndim, dtype, precision):
if algo == self.FFT:
return self._fwd_fft(ndim)
if algo == self.FFT_TILING:
return self._fwd_fft_tiling(ndim)
return self._fwd_fft_tiling(ndim, dtype, precision)
if algo == self.WINOGRAD:
return self._fwd_winograd(ndim)
if algo == self.WINOGRAD_NON_FUSED:
return self._fwd_winograd_non_fused(ndim)
return ConvCaseGenerator(ndim=ndim, dilations=self._dilations(ndim))
def gw(self, algo, ndim):
def gw(self, algo, ndim, dtype, precision):
if algo == self.FFT:
return self._gw_fft(ndim)
if algo == self.WINOGRAD_NON_FUSED:
return self._gw_winograd_non_fused(ndim)
return ConvCaseGenerator(ndim=ndim, dilations=self._dilations(ndim))
def gi(self, algo, ndim):
def gi(self, algo, ndim, dtype, precision):
if algo == self.FFT:
return self._gi_fft(ndim)
if algo == self.FFT_TILING:
return self._gi_fft_tiling(ndim)
return self._gi_fft_tiling(ndim, dtype, precision)
if algo == self.WINOGRAD:
return self._gi_winograd(ndim)
if algo == self.WINOGRAD_NON_FUSED:
......@@ -441,22 +441,25 @@ class CuDNNV6ConvCaseGenerator(CuDNNV51ConvCaseGenerator):
# All dilations allowed.
return ConvCaseGenerator(ndim=ndim)
def _fwd_fft_tiling(self, ndim):
def _fwd_fft_tiling(self, ndim, dtype, precision):
if ndim == 2:
subsamples = [(1, 1)]
# wDesc's filter height must be greater than convDesc's zero-padding height
# wDesc's filter width must be greater than convDesc's zero-padding width
filters_sizes = [(32, 5), (10, 10)]
borders = [(1, 1), (6, 4)]
generator1 = ConvCaseGenerator(ndim=ndim, dilations=self._dilations(ndim), subsamples=subsamples,
filters_sizes=filters_sizes, borders=borders)
generators = []
if (dtype, precision) != ('float64', 'float64'):
# Filter sizes with every dimension != 1 is not supported for DOUBLE_CONFIG.
filters_sizes = [(32, 5), (10, 10)]
borders = [(1, 1), (6, 4)]
generators += [ConvCaseGenerator(ndim=ndim, dilations=self._dilations(ndim), subsamples=subsamples,
filters_sizes=filters_sizes, borders=borders)]
filters_sizes = [(256, 1), (5, 1)]
borders = [(1, 0), (2, 0)]
generator2 = ConvCaseGenerator(ndim=ndim, dilations=self._dilations(ndim), subsamples=subsamples,
filters_sizes=filters_sizes, borders=borders)
return ConvCaseGeneratorChain(generator1, generator2)
generators += [ConvCaseGenerator(ndim=ndim, dilations=self._dilations(ndim), subsamples=subsamples,
filters_sizes=filters_sizes, borders=borders)]
return ConvCaseGeneratorChain(*generators)
if ndim == 3:
return super(CuDNNV6ConvCaseGenerator, self)._fwd_fft_tiling(ndim)
return super(CuDNNV6ConvCaseGenerator, self)._fwd_fft_tiling(ndim, dtype, precision)
def _gw_none(self, ndim):
return self._fwd_none(ndim)
......@@ -476,22 +479,22 @@ class CuDNNV6ConvCaseGenerator(CuDNNV51ConvCaseGenerator):
def _gi_none(self, ndim):
return self._fwd_none(ndim)
def fwd(self, algo, ndim):
def fwd(self, algo, ndim, dtype, precision):
if algo == self.NONE:
return self._fwd_none(ndim)
return super(CuDNNV6ConvCaseGenerator, self).fwd(algo, ndim)
return super(CuDNNV6ConvCaseGenerator, self).fwd(algo, ndim, dtype, precision)
def gw(self, algo, ndim):
def gw(self, algo, ndim, dtype, precision):
if algo == self.NONE:
return self._gw_none(ndim)
if algo == self.FFT_TILING:
return self._gw_fft_tiling(ndim)
return super(CuDNNV6ConvCaseGenerator, self).gw(algo, ndim)
return super(CuDNNV6ConvCaseGenerator, self).gw(algo, ndim, dtype, precision)
def gi(self, algo, ndim):
def gi(self, algo, ndim, dtype, precision):
if algo == self.NONE:
return self._gi_none(ndim)
return super(CuDNNV6ConvCaseGenerator, self).gi(algo, ndim)
return super(CuDNNV6ConvCaseGenerator, self).gi(algo, ndim, dtype, precision)
cudnn_conv_case_generator = CuDNNV51ConvCaseGenerator() if cudnn.version < 6 else CuDNNV6ConvCaseGenerator()
......@@ -692,15 +695,24 @@ class BaseTestDnnConv(object):
else:
utt.assert_allclose(alpha * res_ref + beta * filters_val, res, rtol=rtol)
def should_fail(self, callable, *args):
def should_fail(self, function, *args):
try:
print('(should fail)', file=sys.stderr, end=' ')
callable(*args)
function(*args)
except Exception:
pass
else:
raise AssertionError('Should fail', callable.__name__, *args)
def should_fail_fwd(self, *args):
self.should_fail(self.run_conv_fwd, *args)
def should_fail_gradinput(self, *args):
self.should_fail(self.run_conv_gradinput, *args)
def should_fail_gradweight(self, *args):
self.should_fail(self.run_conv_gradweight, *args)
def get_expected_tcount(self):
"""
Utility function to get expected test count
......@@ -717,7 +729,7 @@ class BaseTestDnnConv(object):
algos = (algo for algo in self.fwd_algorithms
if cudnn.fwd_algo_supports_dtype_config(algo, dtype, precision, self.ndim))
for algo in algos:
for parameters in cudnn_conv_case_generator.fwd(algo, self.ndim).get_cases():
for parameters in cudnn_conv_case_generator.fwd(algo, self.ndim, dtype, precision).get_cases():
yield (self.run_conv_fwd, algo, dtype, precision, parameters)
for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME:
for parameters in self.get_cases():
......@@ -725,7 +737,7 @@ class BaseTestDnnConv(object):
for dnn_case in self.special_cases:
if dnn_case.is_fwd():
if dnn_case.should_fail:
yield (self.should_fail, self.run_conv_fwd,) + dnn_case.get_case()
yield (self.should_fail_fwd,) + dnn_case.get_case()
else:
yield (self.run_conv_fwd,) + dnn_case.get_case()
......@@ -734,7 +746,7 @@ class BaseTestDnnConv(object):
algos = (algo for algo in self.bwd_data_algorithms
if cudnn.bwd_data_algo_supports_dtype_config(algo, dtype, precision, self.ndim))
for algo in algos:
for parameters in cudnn_conv_case_generator.gi(algo, self.ndim).get_cases():
for parameters in cudnn_conv_case_generator.gi(algo, self.ndim, dtype, precision).get_cases():
yield (self.run_conv_gradinput, algo, dtype, precision, parameters)
for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME:
for parameters in self.get_cases():
......@@ -742,7 +754,7 @@ class BaseTestDnnConv(object):
for dnn_case in self.special_cases:
if dnn_case.is_bwd_data():
if dnn_case.should_fail:
yield (self.should_fail, self.run_conv_gradinput,) + dnn_case.get_case()
yield (self.should_fail_gradinput,) + dnn_case.get_case()
else:
yield (self.run_conv_gradinput,) + dnn_case.get_case()
......@@ -751,7 +763,7 @@ class BaseTestDnnConv(object):
algos = (algo for algo in self.bwd_filter_algorithms
if cudnn.bwd_filter_algo_supports_dtype_config(algo, dtype, precision, self.ndim))
for algo in algos:
for parameters in cudnn_conv_case_generator.gw(algo, self.ndim).get_cases():
for parameters in cudnn_conv_case_generator.gw(algo, self.ndim, dtype, precision).get_cases():
yield (self.run_conv_gradweight, algo, dtype, precision, parameters)
for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME:
for parameters in self.get_cases():
......@@ -759,7 +771,7 @@ class BaseTestDnnConv(object):
for dnn_case in self.special_cases:
if dnn_case.is_bwd_filter():
if dnn_case.should_fail:
yield (self.should_fail, self.run_conv_gradweight,) + dnn_case.get_case()
yield (self.should_fail_gradweight,) + dnn_case.get_case()
else:
yield (self.run_conv_gradweight,) + dnn_case.get_case()
......
......@@ -9,6 +9,7 @@ import sys
import theano
from theano.configdefaults import SUPPORTED_DNN_CONV_ALGO_RUNTIME
from theano.tensor.nnet.abstract_conv import get_conv_output_shape
from theano.gpuarray.cudnn_defs import (HALF, FLOAT, DOUBLE,
TRUE_HALF_CONFIG, PSEUDO_HALF_CONFIG, FLOAT_CONFIG, DOUBLE_CONFIG)
from theano.gpuarray.tests.check_dnn_conv import (cudnn, TestDnnConv2D, TestDnnConv3D, CheckDnn)
......@@ -144,4 +145,6 @@ if test == BWD_FILTER:
tests.run_conv_gradweight(algo, dtype, precision, parameters)
if test == BWD_DATA:
tests.run_conv_gradinput(algo, dtype, precision, parameters)
print('Output shape:', get_conv_output_shape(args.input_shape, args.filter_shape, args.border_mode,
args.subsample, args.dilation))
print('... OK')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论