提交 2aff6d77 authored 作者: notoraptor's avatar notoraptor

Simplify script `run_dnn_conv.py`.

上级 8ee47394
...@@ -35,7 +35,6 @@ if __name__ != '__main__': ...@@ -35,7 +35,6 @@ if __name__ != '__main__':
args = sys.argv[1:] args = sys.argv[1:]
computations = FWD, BWD_FILTER, BWD_DATA = ('fwd', 'gradweight', 'gradinput') computations = FWD, BWD_FILTER, BWD_DATA = ('fwd', 'gradweight', 'gradinput')
dimensions = ('2D', '2d', '3D', '3d')
algorithms = (tuple(sorted(list(set(cudnn.cudnnConvolutionFwdAlgo_t.get_aliases() + algorithms = (tuple(sorted(list(set(cudnn.cudnnConvolutionFwdAlgo_t.get_aliases() +
cudnn.cudnnConvolutionBwdFilterAlgo_t.get_aliases() + cudnn.cudnnConvolutionBwdFilterAlgo_t.get_aliases() +
cudnn.cudnnConvolutionBwdDataAlgo_t.get_aliases())))) + cudnn.cudnnConvolutionBwdDataAlgo_t.get_aliases())))) +
...@@ -48,8 +47,6 @@ parser = argparse.ArgumentParser() ...@@ -48,8 +47,6 @@ parser = argparse.ArgumentParser()
parser.add_argument('computation', choices=computations, parser.add_argument('computation', choices=computations,
help='Computation to run.') help='Computation to run.')
parser.add_argument('ndim', choices=dimensions,
help='Number of dimensions ("2D" or "3D", case insensitive).')
parser.add_argument('-a', '--algo', choices=algorithms, required=True, parser.add_argument('-a', '--algo', choices=algorithms, required=True,
help='Algorithm to use for computation.') help='Algorithm to use for computation.')
...@@ -88,7 +85,11 @@ parser.add_argument('-I', '--print-infos', action='store_true', default=False, ...@@ -88,7 +85,11 @@ parser.add_argument('-I', '--print-infos', action='store_true', default=False,
args = parser.parse_args(args) args = parser.parse_args(args)
test = args.computation test = args.computation
ndim = int(args.ndim[0]) if len(args.input_shape) != len(args.filter_shape):
raise ValueError('Expected same length for input shape and filter shape')
if len(args.input_shape) not in (4, 5):
raise ValueError('Expected length 4 or 5 for input shape')
ndim = len(args.input_shape) - 2
if ndim == 2: if ndim == 2:
tests = TestDnnConv2D() tests = TestDnnConv2D()
if ndim == 3: if ndim == 3:
...@@ -97,8 +98,7 @@ if args.subsample is None: ...@@ -97,8 +98,7 @@ if args.subsample is None:
args.subsample = (1,) * ndim args.subsample = (1,) * ndim
if args.dilation is None: if args.dilation is None:
args.dilation = (1,) * ndim args.dilation = (1,) * ndim
if not (ndim == len(args.input_shape[2:]) == len(args.filter_shape[2:]) == len(args.subsample) == len( if not (ndim == len(args.subsample) == len(args.dilation)):
args.dilation)):
raise ValueError('Expected parameters sized for %d dimensions.' % ndim) raise ValueError('Expected parameters sized for %d dimensions.' % ndim)
if isinstance(args.border_mode, tuple) and ndim != len(args.border_mode): if isinstance(args.border_mode, tuple) and ndim != len(args.border_mode):
raise ValueError('Expected borders sized for %d dimensions.' % ndim) raise ValueError('Expected borders sized for %d dimensions.' % ndim)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论