提交 b6747150 authored 作者: slefrancois's avatar slefrancois

use float32 reference and raise rtol for test_conv3d in float16

上级 d5a45459
......@@ -33,6 +33,13 @@ def set_precision(floatX):
return precision
# If using float16, cast reference input to float32
def ref_cast(x):
if theano.config.floatX == 'float16':
x = T.cast(x, 'float32')
return x
def test_dnn_conv_desc_merge():
if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg)
......@@ -1044,13 +1051,18 @@ def test_conv3d_fwd():
# Compile a theano function for the reference implementation
conv_ref = theano.tensor.nnet.corr3d.Corr3dMM(border_mode=border_mode,
subsample=subsample
)(inputs, flipped_filters)
)(ref_cast(inputs), flipped_filters)
f_ref = theano.function([], conv_ref, mode="FAST_RUN")
# Compare the results of the two implementations
res_ref = f_ref()
res = f()
utt.assert_allclose(res_ref, res)
# raise rtol to make the test pass with more seed.
rtol = None
# Raise tolerance for float16
if theano.config.floatX == 'float16':
rtol = 6e-2
utt.assert_allclose(res_ref, res, rtol=rtol)
test_cases = get_conv3d_test_cases()
for (i_shape, f_shape, subsample), border_mode, conv_mode in test_cases:
......@@ -1091,7 +1103,7 @@ def test_conv3d_bwd():
# Compile a theano function for the reference implementation
conv_ref = theano.tensor.nnet.corr3d.Corr3dMM(border_mode=border_mode,
subsample=subsample
)(inputs, flipped_filters)
)(ref_cast(inputs), flipped_filters)
(grad_i_ref,
grad_w_ref) = theano.tensor.grad(conv_ref.sum(),
[inputs, filters])
......@@ -1102,8 +1114,12 @@ def test_conv3d_bwd():
res = f()
# Needed for big size for some seed
# raise rtol to make the test pass with more seed.
utt.assert_allclose(res_ref[0], res[0], rtol=2e-5)
utt.assert_allclose(res_ref[1], res[1], rtol=2e-5)
rtol = None
# Raise tolerance for float16
if theano.config.floatX == 'float16':
rtol = 5e-2
utt.assert_allclose(res_ref[0], res[0], rtol=rtol)
utt.assert_allclose(res_ref[1], res[1], rtol=rtol)
test_cases = get_conv3d_test_cases()
for (i_shape, f_shape, subsample), border_mode, conv_mode in test_cases:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论