提交 c1cabda5 authored 作者: notoraptor's avatar notoraptor

Add support for special test cases.

Add a first special test case for conv 2d fwd.
上级 490907f0
......@@ -40,6 +40,79 @@ def ifilter(function, sequence):
return (element for element in sequence if function(element))
class DnnCase:
"""
Help class to generate special test cases quickly.
"""
def __init__(self,
type, inputs_shape, filters_shape,
algo=None, dtype=None, precision=None,
subsample=None, dilation=None, border_mode='valid',
conv_mode='conv', alpha=1, beta=0,
should_fail=False):
assert type in ('fwd', 'bwd-filter', 'bwd-data')
assert len(inputs_shape) == len(filters_shape) > 2
ndim = len(inputs_shape) - 2
if dtype is None:
dtype = theano.config.floatX
if precision is None:
precision = theano.config.floatX
if subsample is None:
subsample = (1,) * ndim
if dilation is None:
dilation = (1,) * ndim
assert dtype in ('float16', 'float32', 'float64')
assert precision in ('float16', 'float32', 'float64')
assert len(subsample) == len(dilation) == ndim
assert border_mode in ('valid', 'full', 'half') or len(border_mode) == ndim
assert conv_mode in ('conv', 'cross')
assert alpha != 0
self.type = type
self.ndim = ndim
self.algo = algo
self.inputs_shape = inputs_shape
self.filters_shape = filters_shape
self.dtype = dtype
self.precision = precision
self.subsample = subsample
self.dilation = dilation
self.border_mode = border_mode
self.conv_mode = conv_mode
self.alpha = alpha
self.beta = beta
self.should_fail = bool(should_fail)
def is_fwd(self):
return self.type == 'fwd'
def is_bwd_filter(self):
return self.type == 'bwd-filter'
def is_bwd_data(self):
return self.type == 'bwd-data'
def get_case(self):
return (self.algo, self.dtype, self.precision,
(self.inputs_shape, self.filters_shape,
self.subsample, self.dilation, self.border_mode,
self.conv_mode, self.alpha, self.beta))
@staticmethod
def fwd(*args, **kwargs):
return DnnCase('fwd', *args, **kwargs)
@staticmethod
def bwd_filter(*args, **kwargs):
return DnnCase('bwd-filter', *args, **kwargs)
@staticmethod
def bwd_data(*args, **kwargs):
return DnnCase('bwd-data', *args, **kwargs)
class DnnCaseGenerator:
"""
Main class used to generate test cases.
......@@ -261,6 +334,8 @@ class BaseTestDnnConv(object):
cpu_gradinput_class = None
cpu_gradweight_class = None
special_cases = [] # List of DnnCases.
# Utility methods.
def get_cases(self):
......@@ -455,7 +530,16 @@ class BaseTestDnnConv(object):
algos = (algo for algo in self.bwd_filter_algorithms
if cudnn.bwd_filter_algo_supports_dtype_config(algo, dtype, precision, self.ndim))
count_contexts += sum(1 for algo in algos) + len(SUPPORTED_DNN_CONV_ALGO_RUNTIME)
return len_cases * count_contexts
return len(self.special_cases) + len_cases * count_contexts
def should_fail(self, callable, *args):
try:
print('(should fail)', file=sys.stderr, end=' ')
callable(*args)
except Exception:
pass
else:
raise AssertionError('Should fail', callable.__name__, *args)
# Iterable test methods.
......@@ -466,6 +550,12 @@ class BaseTestDnnConv(object):
for algo in chain(algos, SUPPORTED_DNN_CONV_ALGO_RUNTIME):
for parameters in self.get_cases():
yield (self.run_conv_fwd, algo, dtype, precision, parameters)
for dnn_case in self.special_cases:
if dnn_case.is_fwd():
if dnn_case.should_fail:
yield (self.should_fail, self.run_conv_fwd,) + dnn_case.get_case()
else:
yield (self.run_conv_fwd,) + dnn_case.get_case()
def test_gradinput(self):
for dtype, precision in cudnn.get_bwd_data_dtype_configs():
......@@ -474,6 +564,12 @@ class BaseTestDnnConv(object):
for algo in chain(algos, SUPPORTED_DNN_CONV_ALGO_RUNTIME):
for parameters in self.get_cases():
yield (self.run_conv_gradinput, algo, dtype, precision, parameters)
for dnn_case in self.special_cases:
if dnn_case.is_bwd_data():
if dnn_case.should_fail:
yield (self.should_fail, self.run_conv_gradinput,) + dnn_case.get_case()
else:
yield (self.run_conv_gradinput,) + dnn_case.get_case()
def test_gradweight(self):
for dtype, precision in cudnn.get_bwd_filter_dtype_configs():
......@@ -482,6 +578,12 @@ class BaseTestDnnConv(object):
for algo in chain(algos, SUPPORTED_DNN_CONV_ALGO_RUNTIME):
for parameters in self.get_cases():
yield (self.run_conv_gradweight, algo, dtype, precision, parameters)
for dnn_case in self.special_cases:
if dnn_case.is_bwd_filter():
if dnn_case.should_fail:
yield (self.should_fail, self.run_conv_gradweight,) + dnn_case.get_case()
else:
yield (self.run_conv_gradweight,) + dnn_case.get_case()
class TestDnnConv2D(BaseTestDnnConv):
......@@ -495,6 +597,10 @@ class TestDnnConv2D(BaseTestDnnConv):
cpu_gradinput_class = theano.tensor.nnet.corr.CorrMM_gradInputs
cpu_gradweight_class = theano.tensor.nnet.corr.CorrMM_gradWeights
special_cases = [DnnCase.bwd_filter(algo='deterministic', dtype='float32', precision='float32',
inputs_shape=(1, 1, 541211, 10), filters_shape=(50, 1, 3, 10),
border_mode=(1, 0), should_fail=True)]
class TestDnnConv3D(BaseTestDnnConv):
ndim = 3
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论