提交 6e0a10e9 authored 作者: notoraptor's avatar notoraptor

Fix printed info in `run_dnn_conv.py`.

上级 7272a2e0
...@@ -141,10 +141,13 @@ print('======================') ...@@ -141,10 +141,13 @@ print('======================')
print('Running', test, algo, dtype, precision, *parameters) print('Running', test, algo, dtype, precision, *parameters)
if test == FWD: if test == FWD:
tests.run_conv_fwd(algo, dtype, precision, parameters) tests.run_conv_fwd(algo, dtype, precision, parameters)
expected_output_shape = get_conv_output_shape(args.input_shape, args.filter_shape, args.border_mode,
args.subsample, args.dilation)
if test == BWD_FILTER: if test == BWD_FILTER:
tests.run_conv_gradweight(algo, dtype, precision, parameters) tests.run_conv_gradweight(algo, dtype, precision, parameters)
expected_output_shape = args.filter_shape
if test == BWD_DATA: if test == BWD_DATA:
tests.run_conv_gradinput(algo, dtype, precision, parameters) tests.run_conv_gradinput(algo, dtype, precision, parameters)
print('Output shape:', get_conv_output_shape(args.input_shape, args.filter_shape, args.border_mode, expected_output_shape = args.input_shape
args.subsample, args.dilation)) print('Computed shape:', expected_output_shape)
print('... OK') print('... OK')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论