提交 0cb8fbe7 authored 作者: notoraptor's avatar notoraptor

Fix typos and simplify tests.

上级 0cbfeae8
......@@ -1067,15 +1067,19 @@ def get_conv3d_test_cases():
return itt
def run_conv_small_batched_vs_multicall(inputs_shape, filters_shape, batch_sub, subsample):
def run_conv_small_batched_vs_multicall(inputs_shape, filters_shape, batch_sub):
# Function to check issue #5985 (see tests below): https://github.com/Theano/Theano/issues/5985
# Error occurs with algorithm `small` (CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM)
algo = 'small'
batch_size = inputs_shape[0]
utt.seed_rng()
inputs_val = np.random.random(inputs_shape).astype('float32')
filters_val = np.random.random(filters_shape).astype('float32')
# Scale down the input values to prevent very large absolute errors
# due to float rounding
inputs_val /= 10
filters_val /= 10
inputs = theano.shared(inputs_val)
......@@ -1085,19 +1089,18 @@ def run_conv_small_batched_vs_multicall(inputs_shape, filters_shape, batch_sub,
dnn_func = dnn.dnn_conv3d
else:
dnn_func = dnn.dnn_conv
conv = dnn_func(img=inputs, kerns=filters, algo=algo, subsample=subsample)
# Just compute firt and last outputs to reduce execution time.
sub_conv_top = dnn_func(img=inputs[:batch_sub],
kerns=filters, algo=algo, subsample=subsample)
sub_conv_bottom = dnn_func(img=inputs[(batch_size - batch_sub):],
kerns=filters, algo=algo, subsample=subsample)
conv = dnn_func(img=inputs, kerns=filters, algo=algo)
# Just compute first and last outputs, to reduce execution time.
sub_conv_top = dnn_func(img=inputs[:batch_sub], kerns=filters, algo=algo)
sub_conv_bottom = dnn_func(img=inputs[(batch_size - batch_sub):], kerns=filters, algo=algo)
f = theano.function([], [conv, sub_conv_top, sub_conv_bottom], mode=mode_with_gpu)
res_all, res_batch_top, res_batch_bottom = f()
for i in range(0, batch_sub):
for i in range(batch_sub):
# Check first ouputs.
utt.assert_allclose(res_batch_top[i], res_all[i])
# Then check last outputs.
p = batch_size - batch_sub + i
# It seems there is a liimit batch size of 65536 for a good computation
# with algorithm `small`.
# It seems there is a limit batch size of 65536 with algorithm `small`.
checked_limit = 2**16
if p >= checked_limit:
# It seems results are repeated in the entire conv.
......@@ -1108,17 +1111,17 @@ def run_conv_small_batched_vs_multicall(inputs_shape, filters_shape, batch_sub,
def test_batched_conv_small():
yield (run_conv_small_batched_vs_multicall, (65534, 2, 2, 2), (1, 2, 2, 2), 5, (1, 1)) # OK
yield (run_conv_small_batched_vs_multicall, (65535, 2, 2, 2), (1, 2, 2, 2), 5, (1, 1)) # OK
yield (run_conv_small_batched_vs_multicall, (65536, 2, 2, 2), (1, 2, 2, 2), 5, (1, 1)) # OK
yield (run_conv_small_batched_vs_multicall, (65537, 2, 2, 2), (1, 2, 2, 2), 5, (1, 1)) # ERROR
yield (run_conv_small_batched_vs_multicall, (65534, 2, 2, 2), (1, 2, 2, 2), 5) # OK
yield (run_conv_small_batched_vs_multicall, (65535, 2, 2, 2), (1, 2, 2, 2), 5) # OK
yield (run_conv_small_batched_vs_multicall, (65536, 2, 2, 2), (1, 2, 2, 2), 5) # OK
yield (run_conv_small_batched_vs_multicall, (65537, 2, 2, 2), (1, 2, 2, 2), 5) # ERROR
def test_batched_conv3d_small():
yield (run_conv_small_batched_vs_multicall, (65534, 2, 2, 2, 2), (1, 2, 2, 2, 2), 5, (1, 1, 1)) # OK
yield (run_conv_small_batched_vs_multicall, (65535, 2, 2, 2, 2), (1, 2, 2, 2, 2), 5, (1, 1, 1)) # OK
yield (run_conv_small_batched_vs_multicall, (65536, 2, 2, 2, 2), (1, 2, 2, 2, 2), 5, (1, 1, 1)) # OK
yield (run_conv_small_batched_vs_multicall, (65537, 2, 2, 2, 2), (1, 2, 2, 2, 2), 5, (1, 1, 1)) # ERROR ALSO.
yield (run_conv_small_batched_vs_multicall, (65534, 2, 2, 2, 2), (1, 2, 2, 2, 2), 5) # OK
yield (run_conv_small_batched_vs_multicall, (65535, 2, 2, 2, 2), (1, 2, 2, 2, 2), 5) # OK
yield (run_conv_small_batched_vs_multicall, (65536, 2, 2, 2, 2), (1, 2, 2, 2, 2), 5) # OK
yield (run_conv_small_batched_vs_multicall, (65537, 2, 2, 2, 2), (1, 2, 2, 2, 2), 5) # ERROR ALSO.
def test_conv3d_fwd():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论