提交 0db7bc73 作者: Gijs van Tulder

Force batchnorm grad grad experiments to run in float64.

上级 ecd9be60
...@@ -612,7 +612,7 @@ class AbstractBatchNormTrainGrad(Op): ...@@ -612,7 +612,7 @@ class AbstractBatchNormTrainGrad(Op):
g_wrt_x_invstd = 0 g_wrt_x_invstd = 0
if not isinstance(ddinputs.type, theano.gradient.DisconnectedType): if not isinstance(ddinputs.type, theano.gradient.DisconnectedType):
ccc = (ddinputs * scale) - T.mean(ddinputs * scale, axis=self.axes, keepdims=True) ccc = scale * (ddinputs - T.mean(ddinputs, axis=self.axes, keepdims=True))
ddd = (x_invstd ** 3) * (ccc * T.mean(dy * x_diff, axis=self.axes, keepdims=True) + ddd = (x_invstd ** 3) * (ccc * T.mean(dy * x_diff, axis=self.axes, keepdims=True) +
dy * T.mean(ccc * x_diff, axis=self.axes, keepdims=True)) dy * T.mean(ccc * x_diff, axis=self.axes, keepdims=True))
......
...@@ -242,9 +242,9 @@ def test_batch_normalization_train(): ...@@ -242,9 +242,9 @@ def test_batch_normalization_train():
utt.assert_allclose(outputs[11], outputs[11 + 3], rtol=2e-4, atol=1e-4) # dscale utt.assert_allclose(outputs[11], outputs[11 + 3], rtol=2e-4, atol=1e-4) # dscale
utt.assert_allclose(outputs[12], outputs[12 + 3]) # dbias utt.assert_allclose(outputs[12], outputs[12 + 3]) # dbias
# compare second-order gradients # compare second-order gradients
utt.assert_allclose(outputs[16], outputs[16 + 3]) # ddx utt.assert_allclose(outputs[16], outputs[16 + 3], atol=1e-4) # ddx
utt.assert_allclose(outputs[17], outputs[17 + 3]) # ddy utt.assert_allclose(outputs[17], outputs[17 + 3]) # ddy
utt.assert_allclose(outputs[18], outputs[18 + 3]) # ddscale utt.assert_allclose(outputs[18], outputs[18 + 3], rtol=3e-4, atol=1e-4) # ddscale
def test_batch_normalization_train_grad_grad(): def test_batch_normalization_train_grad_grad():
...@@ -252,7 +252,8 @@ def test_batch_normalization_train_grad_grad(): ...@@ -252,7 +252,8 @@ def test_batch_normalization_train_grad_grad():
for axes in ('per-activation', 'spatial', (1, 2, 3, 4)): for axes in ('per-activation', 'spatial', (1, 2, 3, 4)):
for vartype in (T.tensor5, T.tensor4, T.tensor3, T.matrix, T.vector): for vartype in (T.tensor5, T.tensor4, T.tensor3, T.matrix, T.vector):
x, dy, scale, x_mean, x_invstd = (vartype(n) # run these experiments with float64 for sufficient numerical stability
x, dy, scale, x_mean, x_invstd = (vartype(n, dtype='float64')
for n in ('x', 'dy', 'scale', for n in ('x', 'dy', 'scale',
'x_mean', 'x_invstd')) 'x_mean', 'x_invstd'))
ndim = x.ndim ndim = x.ndim
...@@ -281,17 +282,18 @@ def test_batch_normalization_train_grad_grad(): ...@@ -281,17 +282,18 @@ def test_batch_normalization_train_grad_grad():
return g_bias return g_bias
# run # run
for data_shape in ((7, 9, 3, 4, 5), (4, 3, 1, 1, 1), (2, 3, 5, 5, 5)): for data_shape in ((4, 3, 3, 3, 3), (4, 3, 1, 1, 1), (2, 3, 5, 3, 2)):
data_shape = data_shape[:ndim] data_shape = data_shape[:ndim]
param_shape = tuple(1 if d in axes else s param_shape = tuple(1 if d in axes else s
for d, s in enumerate(data_shape)) for d, s in enumerate(data_shape))
x_val = 4 + 4 * np.random.randn(*data_shape).astype(theano.config.floatX) # force float64 for sufficient numerical stability
dy_val = -1 + 3 * np.random.randn(*data_shape).astype(theano.config.floatX) x_val = 4 + 3 * np.random.randn(*data_shape).astype('float64')
scale_val = np.random.randn(*param_shape).astype(theano.config.floatX) dy_val = -1 + 2 * np.random.randn(*data_shape).astype('float64')
x_mean_val = np.random.randn(*param_shape).astype(theano.config.floatX) scale_val = np.random.randn(*param_shape).astype('float64')
x_invstd_val = np.random.randn(*param_shape).astype(theano.config.floatX) x_mean_val = np.random.randn(*param_shape).astype('float64')
x_invstd_val = np.random.randn(*param_shape).astype('float64')
utt.verify_grad(bn_grad_wrt_inputs_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val], eps=1e-6, abs_tol=2e-4)
utt.verify_grad(bn_grad_wrt_inputs_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val])
utt.verify_grad(bn_grad_wrt_scale_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val]) utt.verify_grad(bn_grad_wrt_scale_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val])
utt.verify_grad(bn_grad_wrt_bias_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val]) utt.verify_grad(bn_grad_wrt_bias_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val])
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论