提交 a66b878f authored 作者: notoraptor's avatar notoraptor

Add cuDNN conv test to reproduce a problem related to

runtime algorithms and different data type configurations.
上级 078bdfb1
...@@ -2666,3 +2666,31 @@ class TestDnnConv3DRuntimeAlgorithms(TestDnnConv2DRuntimeAlgorithms): ...@@ -2666,3 +2666,31 @@ class TestDnnConv3DRuntimeAlgorithms(TestDnnConv2DRuntimeAlgorithms):
(1, [(4, 2, 20, 20, 20), (2, 2, 20, 19, 18)]), # cache should be used (1, [(4, 2, 20, 20, 20), (2, 2, 20, 19, 18)]), # cache should be used
(1, [(1, 2, 3, 4, 5), (6, 2, 3, 2, 1)]) (1, [(1, 2, 3, 4, 5), (6, 2, 3, 2, 1)])
] ]
def test_conv_guess_once_with_dtypes():
utt.seed_rng()
inputs_shape = (2, 3, 5, 5)
filters_shape = (2, 3, 40, 4)
border_mode = 'full'
def get_function(dtype, precision):
inputs_val = np.random.random(inputs_shape).astype(dtype)
filters_val = np.random.random(filters_shape).astype(dtype)
inputs_val /= 10
filters_val /= 10
inputs = theano.shared(inputs_val)
filters = theano.shared(filters_val)
conv = dnn.dnn_conv(img=inputs, kerns=filters, border_mode=border_mode, precision=precision,
algo='guess_once', direction_hint='forward!')
return theano.function([], conv)
f_true_half_config = get_function('float16', 'float16')
f_pseudo_half_config = get_function('float16', 'float32')
f_float_config = get_function('float32', 'float32')
f_double_config = get_function('float64', 'float64')
# Let's just see if everything runs without raising any exception.
f_true_half_config()
f_pseudo_half_config()
f_float_config()
f_double_config()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论