提交 8ee47394 authored 作者: notoraptor's avatar notoraptor

Allow concatenating conv case generators.

Update cases for gradweight fft_tiling. Allow DOUBLE_CONFIG for all fft_tiling computations in cuDNN V6.
上级 ac21919b
......@@ -165,11 +165,6 @@ class CuDNNV51(object):
if algo == algorithms.CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
if ndim == 2:
return is_pseudo_half_config(dtype, precision) or is_float_config(dtype, precision)
# NB: For cuDNN V6:
# " Data Type Config Support: PSEUDO_HALF_CONFIG, FLOAT_CONFIG
# (DOUBLE_CONFIG is also supported when the task can be handled by 1D FFT,
# ie, one of the filter dimension, width or height is 1)"
# Could be checked only when being in C code.
if ndim == 3:
return not is_true_half_config(dtype, precision)
if algo == algorithms.CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
......@@ -210,9 +205,6 @@ class CuDNNV51(object):
if algo == algorithms.CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
if ndim == 2:
return is_pseudo_half_config(dtype, precision) or is_float_config(dtype, precision)
# NB: For cuDNN V6: "(DOUBLE_CONFIG is also supported when the task can be handled by 1D FFT,
# ie, one of the filter dimension, width or height is 1)"
# Could be checked only when being in C code.
if ndim == 3:
return not is_true_half_config(dtype, precision)
if algo == algorithms.CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
......@@ -265,6 +257,22 @@ class CuDNNV6(CuDNNV51):
('CUDNN_REDUCE_TENSOR_NORM2', 'norm2'),
ctype='cudnnReduceTensorOp_t')
def fwd_algo_supports_dtype_config(self, algo, dtype, precision, ndim):
    """Tell whether a forward conv algorithm supports a dtype/precision pair.

    Checks the cuDNN V5.1 rules first, then applies the V6-specific
    relaxation for the FFT-tiling algorithm.

    :param algo: forward algorithm (alias or enum value).
    :param dtype: data type of the inputs.
    :param precision: computation precision.
    :param ndim: number of convolution dimensions (2 or 3).
    :return: True if the configuration is supported, False otherwise.
    """
    if super(CuDNNV6, self).fwd_algo_supports_dtype_config(algo, dtype, precision, ndim):
        return True
    fwd_algos = self.cudnnConvolutionFwdAlgo_t
    if fwd_algos.fromalias(algo) != fwd_algos.CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
        return False
    # NB: For cuDNN V6:
    # "Data Type Config Support: PSEUDO_HALF_CONFIG, FLOAT_CONFIG
    # (DOUBLE_CONFIG is also supported when the task can be handled by 1D FFT,
    # ie, one of the filter dimension, width or height is 1)"
    # The 1D-FFT condition could be checked only in C code,
    # so by default DOUBLE_CONFIG is allowed here.
    if ndim != 2:
        return False
    return (is_pseudo_half_config(dtype, precision) or
            is_float_config(dtype, precision) or
            is_double_config(dtype, precision))
def bwd_filter_algo_supports_dtype_config(self, algo, dtype, precision, ndim):
is_supported = super(CuDNNV6, self).bwd_filter_algo_supports_dtype_config(algo, dtype, precision, ndim)
if not is_supported:
......@@ -276,6 +284,22 @@ class CuDNNV6(CuDNNV51):
is_double_config(dtype, precision))
return is_supported
def bwd_data_algo_supports_dtype_config(self, algo, dtype, precision, ndim):
    """Tell whether a backward-data conv algorithm supports a dtype/precision pair.

    Checks the cuDNN V5.1 rules first, then applies the V6-specific
    relaxation for the FFT-tiling algorithm.

    :param algo: backward-data algorithm (alias or enum value).
    :param dtype: data type of the inputs.
    :param precision: computation precision.
    :param ndim: number of convolution dimensions (2 or 3).
    :return: True if the configuration is supported, False otherwise.
    """
    if super(CuDNNV6, self).bwd_data_algo_supports_dtype_config(algo, dtype, precision, ndim):
        return True
    bwd_data_algos = self.cudnnConvolutionBwdDataAlgo_t
    if bwd_data_algos.fromalias(algo) != bwd_data_algos.CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
        return False
    # NB: For cuDNN V6:
    # "Data Type Config Support: PSEUDO_HALF_CONFIG, FLOAT_CONFIG
    # (DOUBLE_CONFIG is also supported when the task can be handled by 1D FFT,
    # ie, one of the filter dimension, width or height is 1)"
    # The 1D-FFT condition could be checked only in C code,
    # so by default DOUBLE_CONFIG is allowed here.
    if ndim != 2:
        return False
    return (is_pseudo_half_config(dtype, precision) or
            is_float_config(dtype, precision) or
            is_double_config(dtype, precision))
class CuDNNV7(CuDNNV6):
version = 7
......
......@@ -16,7 +16,7 @@
from __future__ import absolute_import, print_function, division
import sys
from itertools import product
from itertools import product, chain
import nose
import numpy as np
......@@ -316,6 +316,19 @@ class ConvCaseGenerator:
all_border_modes, all_conv_modes, all_alphas, all_betas))
class ConvCaseGeneratorChain:
    """
    Helper class to concatenate many conv case generators.

    Exposes the same ``get_cases`` interface as ``ConvCaseGenerator``,
    yielding the cases of every wrapped generator in order.
    """

    def __init__(self, *conv_case_generators):
        # Fail fast if anything that is not a ConvCaseGenerator gets chained.
        assert all(isinstance(g, ConvCaseGenerator) for g in conv_case_generators)
        self.generators = conv_case_generators

    def get_cases(self, filter=None):
        """Return an iterator over the cases of all wrapped generators.

        ``filter`` is forwarded unchanged to every generator's ``get_cases``.
        (The parameter name shadows the builtin, but is kept for
        interface compatibility with ``ConvCaseGenerator.get_cases``.)
        """
        # chain.from_iterable is lazy: it does not materialize the list of
        # per-generator iterables up front, unlike chain(*[...]).
        return chain.from_iterable(generator.get_cases(filter)
                                   for generator in self.generators)
class CuDNNV51ConvCaseGenerator(object):
"""
Helper class to generate specific test cases for every algorithm supported by cuDNN V5.1.
......@@ -430,14 +443,18 @@ class CuDNNV6ConvCaseGenerator(CuDNNV51ConvCaseGenerator):
def _fwd_fft_tiling(self, ndim):
if ndim == 2:
filters_sizes = [(32, 5), (256, 1), (10, 10), (5, 1)]
subsamples = [(1, 1)]
borders = [(1, 1), (2, 1)]
return ConvCaseGenerator(ndim=ndim,
filters_sizes=filters_sizes,
subsamples=subsamples,
borders=borders,
dilations=self._dilations(ndim))
# wDesc's filter height must be greater than convDesc's zero-padding height
# wDesc's filter width must be greater than convDesc's zero-padding width
filters_sizes = [(32, 5), (10, 10)]
borders = [(1, 1), (6, 4)]
generator1 = ConvCaseGenerator(ndim=ndim, dilations=self._dilations(ndim), subsamples=subsamples,
filters_sizes=filters_sizes, borders=borders)
filters_sizes = [(256, 1), (5, 1)]
borders = [(1, 0), (2, 0)]
generator2 = ConvCaseGenerator(ndim=ndim, dilations=self._dilations(ndim), subsamples=subsamples,
filters_sizes=filters_sizes, borders=borders)
return ConvCaseGeneratorChain(generator1, generator2)
if ndim == 3:
return super(CuDNNV6ConvCaseGenerator, self)._fwd_fft_tiling(ndim)
......@@ -445,10 +462,10 @@ class CuDNNV6ConvCaseGenerator(CuDNNV51ConvCaseGenerator):
return self._fwd_none(ndim)
def _gw_fft_tiling(self, ndim):
inputs_sizes = [(256, 1), (20, 1)]
inputs_sizes = [(247, 1), (20, 1)]
filters_sizes = [(3, 1), (10, 1)]
subsamples = [(1,) * ndim]
borders = [(1, 1), (2, 1)]
borders = [(1, 0), (2, 0)]
return ConvCaseGenerator(ndim=ndim,
inputs_sizes=inputs_sizes,
filters_sizes=filters_sizes,
......@@ -467,6 +484,8 @@ class CuDNNV6ConvCaseGenerator(CuDNNV51ConvCaseGenerator):
def gw(self, algo, ndim):
    """Return the conv case generator for gradweight tests with ``algo``.

    Handles the algorithms customised for cuDNN V6 locally and delegates
    every other algorithm to the cuDNN V5.1 generator.
    """
    if algo == self.NONE:
        generator = self._gw_none(ndim)
    elif algo == self.FFT_TILING:
        generator = self._gw_fft_tiling(ndim)
    else:
        generator = super(CuDNNV6ConvCaseGenerator, self).gw(algo, ndim)
    return generator
def gi(self, algo, ndim):
......
......@@ -126,8 +126,8 @@ if args.algo not in SUPPORTED_DNN_CONV_ALGO_RUNTIME:
if test == BWD_DATA:
check_config = cudnn.bwd_data_algo_supports_dtype_config(args.algo, args.dtype, args.precision, ndim)
if not check_config:
raise ValueError('%s computation does not support configuration (%s, %s) for algo %s.' % (
test, args.dtype, args.precision, args.algo))
print('Warning: %s computation does not normally support configuration (%s, %s) for algo %s.' % (
test, args.dtype, args.precision, args.algo), file=sys.stderr)
algo = args.algo
dtype = args.dtype
precision = args.precision
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论