提交 b6747150 authored 作者: slefrancois's avatar slefrancois

use float32 reference and raise rtol for test_conv3d in float16

上级 d5a45459
...@@ -33,6 +33,13 @@ def set_precision(floatX): ...@@ -33,6 +33,13 @@ def set_precision(floatX):
return precision return precision
# If using float16, cast reference input to float32
def ref_cast(x):
if theano.config.floatX == 'float16':
x = T.cast(x, 'float32')
return x
def test_dnn_conv_desc_merge(): def test_dnn_conv_desc_merge():
if not dnn.dnn_available(test_ctx_name): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
...@@ -1044,13 +1051,18 @@ def test_conv3d_fwd(): ...@@ -1044,13 +1051,18 @@ def test_conv3d_fwd():
# Compile a theano function for the reference implementation # Compile a theano function for the reference implementation
conv_ref = theano.tensor.nnet.corr3d.Corr3dMM(border_mode=border_mode, conv_ref = theano.tensor.nnet.corr3d.Corr3dMM(border_mode=border_mode,
subsample=subsample subsample=subsample
)(inputs, flipped_filters) )(ref_cast(inputs), flipped_filters)
f_ref = theano.function([], conv_ref, mode="FAST_RUN") f_ref = theano.function([], conv_ref, mode="FAST_RUN")
# Compare the results of the two implementations # Compare the results of the two implementations
res_ref = f_ref() res_ref = f_ref()
res = f() res = f()
utt.assert_allclose(res_ref, res) # raise rtol to make the test pass with more seed.
rtol = None
# Raise tolerance for float16
if theano.config.floatX == 'float16':
rtol = 6e-2
utt.assert_allclose(res_ref, res, rtol=rtol)
test_cases = get_conv3d_test_cases() test_cases = get_conv3d_test_cases()
for (i_shape, f_shape, subsample), border_mode, conv_mode in test_cases: for (i_shape, f_shape, subsample), border_mode, conv_mode in test_cases:
...@@ -1091,7 +1103,7 @@ def test_conv3d_bwd(): ...@@ -1091,7 +1103,7 @@ def test_conv3d_bwd():
# Compile a theano function for the reference implementation # Compile a theano function for the reference implementation
conv_ref = theano.tensor.nnet.corr3d.Corr3dMM(border_mode=border_mode, conv_ref = theano.tensor.nnet.corr3d.Corr3dMM(border_mode=border_mode,
subsample=subsample subsample=subsample
)(inputs, flipped_filters) )(ref_cast(inputs), flipped_filters)
(grad_i_ref, (grad_i_ref,
grad_w_ref) = theano.tensor.grad(conv_ref.sum(), grad_w_ref) = theano.tensor.grad(conv_ref.sum(),
[inputs, filters]) [inputs, filters])
...@@ -1102,8 +1114,12 @@ def test_conv3d_bwd(): ...@@ -1102,8 +1114,12 @@ def test_conv3d_bwd():
res = f() res = f()
# Needed for big size for some seed # Needed for big size for some seed
# raise rtol to make the test pass with more seed. # raise rtol to make the test pass with more seed.
utt.assert_allclose(res_ref[0], res[0], rtol=2e-5) rtol = None
utt.assert_allclose(res_ref[1], res[1], rtol=2e-5) # Raise tolerance for float16
if theano.config.floatX == 'float16':
rtol = 5e-2
utt.assert_allclose(res_ref[0], res[0], rtol=rtol)
utt.assert_allclose(res_ref[1], res[1], rtol=rtol)
test_cases = get_conv3d_test_cases() test_cases = get_conv3d_test_cases()
for (i_shape, f_shape, subsample), border_mode, conv_mode in test_cases: for (i_shape, f_shape, subsample), border_mode, conv_mode in test_cases:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论