提交 2aff6d77 authored 作者: notoraptor's avatar notoraptor

Simplify script `run_dnn_conv.py`.

上级 8ee47394
......@@ -35,7 +35,6 @@ if __name__ != '__main__':
args = sys.argv[1:]
computations = FWD, BWD_FILTER, BWD_DATA = ('fwd', 'gradweight', 'gradinput')
dimensions = ('2D', '2d', '3D', '3d')
algorithms = (tuple(sorted(list(set(cudnn.cudnnConvolutionFwdAlgo_t.get_aliases() +
cudnn.cudnnConvolutionBwdFilterAlgo_t.get_aliases() +
cudnn.cudnnConvolutionBwdDataAlgo_t.get_aliases())))) +
......@@ -48,8 +47,6 @@ parser = argparse.ArgumentParser()
parser.add_argument('computation', choices=computations,
help='Computation to run.')
parser.add_argument('ndim', choices=dimensions,
help='Number of dimensions ("2D" or "3D", case insensitive).')
parser.add_argument('-a', '--algo', choices=algorithms, required=True,
help='Algorithm to use for computation.')
......@@ -88,7 +85,11 @@ parser.add_argument('-I', '--print-infos', action='store_true', default=False,
args = parser.parse_args(args)
test = args.computation
ndim = int(args.ndim[0])
if len(args.input_shape) != len(args.filter_shape):
raise ValueError('Expected same length for input shape and filter shape')
if len(args.input_shape) not in (4, 5):
raise ValueError('Expected length 4 or 5 for input shape')
ndim = len(args.input_shape) - 2
if ndim == 2:
tests = TestDnnConv2D()
if ndim == 3:
......@@ -97,8 +98,7 @@ if args.subsample is None:
args.subsample = (1,) * ndim
if args.dilation is None:
args.dilation = (1,) * ndim
if not (ndim == len(args.input_shape[2:]) == len(args.filter_shape[2:]) == len(args.subsample) == len(
args.dilation)):
if not (ndim == len(args.subsample) == len(args.dilation)):
raise ValueError('Expected parameters sized for %d dimensions.' % ndim)
if isinstance(args.border_mode, tuple) and ndim != len(args.border_mode):
raise ValueError('Expected borders sized for %d dimensions.' % ndim)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论