提交 60fa63e1 authored 作者: notoraptor's avatar notoraptor

Better handling for tests that should fail.

Fix test cases generators to avoid cases that are intended to fail. Fix messages and properly exit from CUDA in cuDNN conv ops.
上级 2aff6d77
...@@ -348,7 +348,7 @@ class CuDNNV51ConvCaseGenerator(object): ...@@ -348,7 +348,7 @@ class CuDNNV51ConvCaseGenerator(object):
def _fwd_fft(self, ndim): def _fwd_fft(self, ndim):
inputs_sizes = [(10,) * ndim, inputs_sizes = [(10,) * ndim,
(248, 5) + (2,) * (ndim - 2)] (240, 5) + (2,) * (ndim - 2)]
filters_sizes = [tuple(range(9, 9 - ndim, -1))] filters_sizes = [tuple(range(9, 9 - ndim, -1))]
subsamples = [(1,) * ndim] subsamples = [(1,) * ndim]
return ConvCaseGenerator(ndim=ndim, return ConvCaseGenerator(ndim=ndim,
...@@ -357,7 +357,7 @@ class CuDNNV51ConvCaseGenerator(object): ...@@ -357,7 +357,7 @@ class CuDNNV51ConvCaseGenerator(object):
subsamples=subsamples, subsamples=subsamples,
dilations=self._dilations(ndim)) dilations=self._dilations(ndim))
def _fwd_fft_tiling(self, ndim): def _fwd_fft_tiling(self, ndim, dtype, precision):
if ndim == 2: if ndim == 2:
filters_sizes = [(32, 5)] filters_sizes = [(32, 5)]
if ndim == 3: if ndim == 3:
...@@ -395,8 +395,8 @@ class CuDNNV51ConvCaseGenerator(object): ...@@ -395,8 +395,8 @@ class CuDNNV51ConvCaseGenerator(object):
def _gi_fft(self, ndim): def _gi_fft(self, ndim):
return self._fwd_fft(ndim) return self._fwd_fft(ndim)
def _gi_fft_tiling(self, ndim): def _gi_fft_tiling(self, ndim, dtype, precision):
return self._fwd_fft_tiling(ndim) return self._fwd_fft_tiling(ndim, dtype, precision)
def _gi_winograd(self, ndim): def _gi_winograd(self, ndim):
return self._fwd_winograd(ndim) return self._fwd_winograd(ndim)
...@@ -406,29 +406,29 @@ class CuDNNV51ConvCaseGenerator(object): ...@@ -406,29 +406,29 @@ class CuDNNV51ConvCaseGenerator(object):
# Public interface. # Public interface.
def fwd(self, algo, ndim): def fwd(self, algo, ndim, dtype, precision):
if algo == self.FFT: if algo == self.FFT:
return self._fwd_fft(ndim) return self._fwd_fft(ndim)
if algo == self.FFT_TILING: if algo == self.FFT_TILING:
return self._fwd_fft_tiling(ndim) return self._fwd_fft_tiling(ndim, dtype, precision)
if algo == self.WINOGRAD: if algo == self.WINOGRAD:
return self._fwd_winograd(ndim) return self._fwd_winograd(ndim)
if algo == self.WINOGRAD_NON_FUSED: if algo == self.WINOGRAD_NON_FUSED:
return self._fwd_winograd_non_fused(ndim) return self._fwd_winograd_non_fused(ndim)
return ConvCaseGenerator(ndim=ndim, dilations=self._dilations(ndim)) return ConvCaseGenerator(ndim=ndim, dilations=self._dilations(ndim))
def gw(self, algo, ndim): def gw(self, algo, ndim, dtype, precision):
if algo == self.FFT: if algo == self.FFT:
return self._gw_fft(ndim) return self._gw_fft(ndim)
if algo == self.WINOGRAD_NON_FUSED: if algo == self.WINOGRAD_NON_FUSED:
return self._gw_winograd_non_fused(ndim) return self._gw_winograd_non_fused(ndim)
return ConvCaseGenerator(ndim=ndim, dilations=self._dilations(ndim)) return ConvCaseGenerator(ndim=ndim, dilations=self._dilations(ndim))
def gi(self, algo, ndim): def gi(self, algo, ndim, dtype, precision):
if algo == self.FFT: if algo == self.FFT:
return self._gi_fft(ndim) return self._gi_fft(ndim)
if algo == self.FFT_TILING: if algo == self.FFT_TILING:
return self._gi_fft_tiling(ndim) return self._gi_fft_tiling(ndim, dtype, precision)
if algo == self.WINOGRAD: if algo == self.WINOGRAD:
return self._gi_winograd(ndim) return self._gi_winograd(ndim)
if algo == self.WINOGRAD_NON_FUSED: if algo == self.WINOGRAD_NON_FUSED:
...@@ -441,22 +441,25 @@ class CuDNNV6ConvCaseGenerator(CuDNNV51ConvCaseGenerator): ...@@ -441,22 +441,25 @@ class CuDNNV6ConvCaseGenerator(CuDNNV51ConvCaseGenerator):
# All dilations allowed. # All dilations allowed.
return ConvCaseGenerator(ndim=ndim) return ConvCaseGenerator(ndim=ndim)
def _fwd_fft_tiling(self, ndim): def _fwd_fft_tiling(self, ndim, dtype, precision):
if ndim == 2: if ndim == 2:
subsamples = [(1, 1)] subsamples = [(1, 1)]
# wDesc's filter height must be greater than convDesc's zero-padding height # wDesc's filter height must be greater than convDesc's zero-padding height
# wDesc's filter width must be greater than convDesc's zero-padding width # wDesc's filter width must be greater than convDesc's zero-padding width
generators = []
if (dtype, precision) != ('float64', 'float64'):
# Filter sizes with every dimension != 1 is not supported for DOUBLE_CONFIG.
filters_sizes = [(32, 5), (10, 10)] filters_sizes = [(32, 5), (10, 10)]
borders = [(1, 1), (6, 4)] borders = [(1, 1), (6, 4)]
generator1 = ConvCaseGenerator(ndim=ndim, dilations=self._dilations(ndim), subsamples=subsamples, generators += [ConvCaseGenerator(ndim=ndim, dilations=self._dilations(ndim), subsamples=subsamples,
filters_sizes=filters_sizes, borders=borders) filters_sizes=filters_sizes, borders=borders)]
filters_sizes = [(256, 1), (5, 1)] filters_sizes = [(256, 1), (5, 1)]
borders = [(1, 0), (2, 0)] borders = [(1, 0), (2, 0)]
generator2 = ConvCaseGenerator(ndim=ndim, dilations=self._dilations(ndim), subsamples=subsamples, generators += [ConvCaseGenerator(ndim=ndim, dilations=self._dilations(ndim), subsamples=subsamples,
filters_sizes=filters_sizes, borders=borders) filters_sizes=filters_sizes, borders=borders)]
return ConvCaseGeneratorChain(generator1, generator2) return ConvCaseGeneratorChain(*generators)
if ndim == 3: if ndim == 3:
return super(CuDNNV6ConvCaseGenerator, self)._fwd_fft_tiling(ndim) return super(CuDNNV6ConvCaseGenerator, self)._fwd_fft_tiling(ndim, dtype, precision)
def _gw_none(self, ndim): def _gw_none(self, ndim):
return self._fwd_none(ndim) return self._fwd_none(ndim)
...@@ -476,22 +479,22 @@ class CuDNNV6ConvCaseGenerator(CuDNNV51ConvCaseGenerator): ...@@ -476,22 +479,22 @@ class CuDNNV6ConvCaseGenerator(CuDNNV51ConvCaseGenerator):
def _gi_none(self, ndim): def _gi_none(self, ndim):
return self._fwd_none(ndim) return self._fwd_none(ndim)
def fwd(self, algo, ndim): def fwd(self, algo, ndim, dtype, precision):
if algo == self.NONE: if algo == self.NONE:
return self._fwd_none(ndim) return self._fwd_none(ndim)
return super(CuDNNV6ConvCaseGenerator, self).fwd(algo, ndim) return super(CuDNNV6ConvCaseGenerator, self).fwd(algo, ndim, dtype, precision)
def gw(self, algo, ndim): def gw(self, algo, ndim, dtype, precision):
if algo == self.NONE: if algo == self.NONE:
return self._gw_none(ndim) return self._gw_none(ndim)
if algo == self.FFT_TILING: if algo == self.FFT_TILING:
return self._gw_fft_tiling(ndim) return self._gw_fft_tiling(ndim)
return super(CuDNNV6ConvCaseGenerator, self).gw(algo, ndim) return super(CuDNNV6ConvCaseGenerator, self).gw(algo, ndim, dtype, precision)
def gi(self, algo, ndim): def gi(self, algo, ndim, dtype, precision):
if algo == self.NONE: if algo == self.NONE:
return self._gi_none(ndim) return self._gi_none(ndim)
return super(CuDNNV6ConvCaseGenerator, self).gi(algo, ndim) return super(CuDNNV6ConvCaseGenerator, self).gi(algo, ndim, dtype, precision)
cudnn_conv_case_generator = CuDNNV51ConvCaseGenerator() if cudnn.version < 6 else CuDNNV6ConvCaseGenerator() cudnn_conv_case_generator = CuDNNV51ConvCaseGenerator() if cudnn.version < 6 else CuDNNV6ConvCaseGenerator()
...@@ -692,15 +695,24 @@ class BaseTestDnnConv(object): ...@@ -692,15 +695,24 @@ class BaseTestDnnConv(object):
else: else:
utt.assert_allclose(alpha * res_ref + beta * filters_val, res, rtol=rtol) utt.assert_allclose(alpha * res_ref + beta * filters_val, res, rtol=rtol)
def should_fail(self, callable, *args): def should_fail(self, function, *args):
try: try:
print('(should fail)', file=sys.stderr, end=' ') print('(should fail)', file=sys.stderr, end=' ')
callable(*args) function(*args)
except Exception: except Exception:
pass pass
else: else:
raise AssertionError('Should fail', callable.__name__, *args) raise AssertionError('Should fail', callable.__name__, *args)
def should_fail_fwd(self, *args):
self.should_fail(self.run_conv_fwd, *args)
def should_fail_gradinput(self, *args):
self.should_fail(self.run_conv_gradinput, *args)
def should_fail_gradweight(self, *args):
self.should_fail(self.run_conv_gradweight, *args)
def get_expected_tcount(self): def get_expected_tcount(self):
""" """
Utility function to get expected test count Utility function to get expected test count
...@@ -717,7 +729,7 @@ class BaseTestDnnConv(object): ...@@ -717,7 +729,7 @@ class BaseTestDnnConv(object):
algos = (algo for algo in self.fwd_algorithms algos = (algo for algo in self.fwd_algorithms
if cudnn.fwd_algo_supports_dtype_config(algo, dtype, precision, self.ndim)) if cudnn.fwd_algo_supports_dtype_config(algo, dtype, precision, self.ndim))
for algo in algos: for algo in algos:
for parameters in cudnn_conv_case_generator.fwd(algo, self.ndim).get_cases(): for parameters in cudnn_conv_case_generator.fwd(algo, self.ndim, dtype, precision).get_cases():
yield (self.run_conv_fwd, algo, dtype, precision, parameters) yield (self.run_conv_fwd, algo, dtype, precision, parameters)
for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME: for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME:
for parameters in self.get_cases(): for parameters in self.get_cases():
...@@ -725,7 +737,7 @@ class BaseTestDnnConv(object): ...@@ -725,7 +737,7 @@ class BaseTestDnnConv(object):
for dnn_case in self.special_cases: for dnn_case in self.special_cases:
if dnn_case.is_fwd(): if dnn_case.is_fwd():
if dnn_case.should_fail: if dnn_case.should_fail:
yield (self.should_fail, self.run_conv_fwd,) + dnn_case.get_case() yield (self.should_fail_fwd,) + dnn_case.get_case()
else: else:
yield (self.run_conv_fwd,) + dnn_case.get_case() yield (self.run_conv_fwd,) + dnn_case.get_case()
...@@ -734,7 +746,7 @@ class BaseTestDnnConv(object): ...@@ -734,7 +746,7 @@ class BaseTestDnnConv(object):
algos = (algo for algo in self.bwd_data_algorithms algos = (algo for algo in self.bwd_data_algorithms
if cudnn.bwd_data_algo_supports_dtype_config(algo, dtype, precision, self.ndim)) if cudnn.bwd_data_algo_supports_dtype_config(algo, dtype, precision, self.ndim))
for algo in algos: for algo in algos:
for parameters in cudnn_conv_case_generator.gi(algo, self.ndim).get_cases(): for parameters in cudnn_conv_case_generator.gi(algo, self.ndim, dtype, precision).get_cases():
yield (self.run_conv_gradinput, algo, dtype, precision, parameters) yield (self.run_conv_gradinput, algo, dtype, precision, parameters)
for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME: for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME:
for parameters in self.get_cases(): for parameters in self.get_cases():
...@@ -742,7 +754,7 @@ class BaseTestDnnConv(object): ...@@ -742,7 +754,7 @@ class BaseTestDnnConv(object):
for dnn_case in self.special_cases: for dnn_case in self.special_cases:
if dnn_case.is_bwd_data(): if dnn_case.is_bwd_data():
if dnn_case.should_fail: if dnn_case.should_fail:
yield (self.should_fail, self.run_conv_gradinput,) + dnn_case.get_case() yield (self.should_fail_gradinput,) + dnn_case.get_case()
else: else:
yield (self.run_conv_gradinput,) + dnn_case.get_case() yield (self.run_conv_gradinput,) + dnn_case.get_case()
...@@ -751,7 +763,7 @@ class BaseTestDnnConv(object): ...@@ -751,7 +763,7 @@ class BaseTestDnnConv(object):
algos = (algo for algo in self.bwd_filter_algorithms algos = (algo for algo in self.bwd_filter_algorithms
if cudnn.bwd_filter_algo_supports_dtype_config(algo, dtype, precision, self.ndim)) if cudnn.bwd_filter_algo_supports_dtype_config(algo, dtype, precision, self.ndim))
for algo in algos: for algo in algos:
for parameters in cudnn_conv_case_generator.gw(algo, self.ndim).get_cases(): for parameters in cudnn_conv_case_generator.gw(algo, self.ndim, dtype, precision).get_cases():
yield (self.run_conv_gradweight, algo, dtype, precision, parameters) yield (self.run_conv_gradweight, algo, dtype, precision, parameters)
for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME: for algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME:
for parameters in self.get_cases(): for parameters in self.get_cases():
...@@ -759,7 +771,7 @@ class BaseTestDnnConv(object): ...@@ -759,7 +771,7 @@ class BaseTestDnnConv(object):
for dnn_case in self.special_cases: for dnn_case in self.special_cases:
if dnn_case.is_bwd_filter(): if dnn_case.is_bwd_filter():
if dnn_case.should_fail: if dnn_case.should_fail:
yield (self.should_fail, self.run_conv_gradweight,) + dnn_case.get_case() yield (self.should_fail_gradweight,) + dnn_case.get_case()
else: else:
yield (self.run_conv_gradweight,) + dnn_case.get_case() yield (self.run_conv_gradweight,) + dnn_case.get_case()
......
...@@ -9,6 +9,7 @@ import sys ...@@ -9,6 +9,7 @@ import sys
import theano import theano
from theano.configdefaults import SUPPORTED_DNN_CONV_ALGO_RUNTIME from theano.configdefaults import SUPPORTED_DNN_CONV_ALGO_RUNTIME
from theano.tensor.nnet.abstract_conv import get_conv_output_shape
from theano.gpuarray.cudnn_defs import (HALF, FLOAT, DOUBLE, from theano.gpuarray.cudnn_defs import (HALF, FLOAT, DOUBLE,
TRUE_HALF_CONFIG, PSEUDO_HALF_CONFIG, FLOAT_CONFIG, DOUBLE_CONFIG) TRUE_HALF_CONFIG, PSEUDO_HALF_CONFIG, FLOAT_CONFIG, DOUBLE_CONFIG)
from theano.gpuarray.tests.check_dnn_conv import (cudnn, TestDnnConv2D, TestDnnConv3D, CheckDnn) from theano.gpuarray.tests.check_dnn_conv import (cudnn, TestDnnConv2D, TestDnnConv3D, CheckDnn)
...@@ -144,4 +145,6 @@ if test == BWD_FILTER: ...@@ -144,4 +145,6 @@ if test == BWD_FILTER:
tests.run_conv_gradweight(algo, dtype, precision, parameters) tests.run_conv_gradweight(algo, dtype, precision, parameters)
if test == BWD_DATA: if test == BWD_DATA:
tests.run_conv_gradinput(algo, dtype, precision, parameters) tests.run_conv_gradinput(algo, dtype, precision, parameters)
print('Output shape:', get_conv_output_shape(args.input_shape, args.filter_shape, args.border_mode,
args.subsample, args.dilation))
print('... OK') print('... OK')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论