Commit 1ac12452 authored by abergeron, committed by GitHub

Merge pull request #6070 from gvtulder/f-batchnorm-gradgrad

Implement grad for AbstractBatchNormTrainGrad
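
With this, `AbstractBatchNormTrainGrad` is itself differentiable, so double backpropagation through batch normalization (e.g. for Hessian-vector products or gradient penalties) becomes possible. A minimal sketch of the pattern, not taken from this commit (the variables and shapes are illustrative):

```python
import theano
import theano.tensor as T
from theano.tensor.nnet import bn

x = T.matrix('x')
scale = T.row('scale')
bias = T.row('bias')
# forward pass in training mode; also returns the batch mean and inverse std
out, x_mean, x_invstd = bn.batch_normalization_train(x, scale, bias, 'per-activation')
# first-order gradient of a scalar loss w.r.t. the inputs
g_x = T.grad(out.sum(), x)
# second-order gradient: differentiating g_x again goes through
# AbstractBatchNormTrainGrad.grad, which this pull request implements
gg_x = T.grad(g_x.sum(), x)
f = theano.function([x, scale, bias], gg_x)
```

The updated tests below exercise the same mechanism explicitly, feeding known second-order gradients via `T.grad(None, ..., known_grads=OrderedDict(...))`.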
@@ -597,6 +597,62 @@ class AbstractBatchNormTrainGrad(Op):
         return Apply(self, [x, dy, scale, x_mean, x_invstd, epsilon],
                      [x.type(), scale.type(), scale.type()])
+    def grad(self, inp, grads):
+        x, dy, scale, x_mean, x_invstd, epsilon = inp
+        ddinputs, ddscale, ddbias = grads
+
+        x_diff = x - x_mean
+        mean_dy_x_diff = T.mean(dy * x_diff, axis=self.axes, keepdims=True)
+
+        # compute gradients given each of the output gradients
+        g_wrt_x = 0
+        g_wrt_dy = 0
+        g_wrt_scale = 0
+        g_wrt_x_mean = 0
+        g_wrt_x_invstd = 0
+        if not isinstance(ddinputs.type, theano.gradient.DisconnectedType):
+            # shared subexpressions of the second-order terms
+            ccc = scale * (ddinputs - T.mean(ddinputs, axis=self.axes, keepdims=True))
+            ddd = (x_invstd ** 3) * (ccc * T.mean(dy * x_diff, axis=self.axes, keepdims=True) +
+                                     dy * T.mean(ccc * x_diff, axis=self.axes, keepdims=True))
+            g_wrt_x = g_wrt_x - ddd
+            g_wrt_dy = g_wrt_dy + ((ccc * x_invstd) -
+                                   ((x_invstd ** 3) * x_diff *
+                                    T.mean(ccc * x_diff, axis=self.axes, keepdims=True)))
+            eee = (dy * x_invstd) - ((x_invstd ** 3) * x_diff * mean_dy_x_diff)
+            g_wrt_scale = g_wrt_scale + T.sum(ddinputs * (eee - T.mean(eee, axis=self.axes, keepdims=True)),
+                                              axis=self.axes, keepdims=True)
+            g_wrt_x_mean = g_wrt_x_mean + T.sum(ddd, axis=self.axes, keepdims=True)
+            g_wrt_x_invstd = g_wrt_x_invstd + T.sum(ccc * (dy - 3 * (x_invstd ** 2) * x_diff * mean_dy_x_diff),
+                                                    axis=self.axes, keepdims=True)
+        if not isinstance(ddscale.type, theano.gradient.DisconnectedType):
+            g_wrt_x = g_wrt_x + (x_invstd * ddscale * dy)
+            g_wrt_dy = g_wrt_dy + (x_invstd * ddscale * x_diff)
+            g_wrt_x_mean = g_wrt_x_mean - (x_invstd * ddscale * T.sum(dy, axis=self.axes, keepdims=True))
+            g_wrt_x_invstd = g_wrt_x_invstd + (ddscale * T.sum(dy * x_diff, axis=self.axes, keepdims=True))
+        if not isinstance(ddbias.type, theano.gradient.DisconnectedType):
+            g_wrt_dy = g_wrt_dy + T.fill(dy, ddbias)
+
+        # depending on which output gradients are given,
+        # some inputs should be disconnected
+        results = [g_wrt_x, g_wrt_dy, g_wrt_scale, g_wrt_x_mean, g_wrt_x_invstd,
+                   theano.gradient.DisconnectedType()()]
+        return [theano.gradient.DisconnectedType()() if isinstance(r, int) and r == 0 else r
+                for r in results]
+
+    def connection_pattern(self, node):
+        return [[True, True, False],     # x
+                [True, True, True],      # dy
+                [True, False, False],    # scale
+                [True, True, False],     # x_mean
+                [True, True, False],     # x_invstd
+                [False, False, False]]   # epsilon
+
+    def infer_shape(self, node, shape):
+        return [shape[0], shape[2], shape[2]]
......
@@ -3,6 +3,7 @@ import theano
 import theano.tensor as T
 from theano.tests import unittest_tools as utt
 import numpy as np
+from collections import OrderedDict
 from theano.tensor.nnet import bn
......
@@ -190,11 +191,24 @@ def test_batch_normalization_train():
             grads = T.grad(None, wrt=[x, scale, bias], known_grads={out: dy})
             # reference backward pass
             grads2 = T.grad(None, wrt=[x, scale, bias], known_grads={out2: dy})
+            # second-order backward pass
+            dx = vartype('dinputs')
+            dscale = vartype('dscale')
+            dbias = vartype('dbias')
+            grad_grads = T.grad(None, wrt=[x, dy, scale], known_grads=OrderedDict(
+                {grads[0]: dx, grads[1]: dscale, grads[2]: dbias}),
+                consider_constant=[x, dy, scale, bias, x_mean, x_invstd, running_mean, running_var],
+                return_disconnected='zero')
+            # reference second-order backward pass
+            grad_grads2 = T.grad(None, wrt=[x, dy, scale], known_grads=OrderedDict(
+                {grads2[0]: dx, grads2[1]: dscale, grads2[2]: dbias}),
+                consider_constant=[x, dy, scale, bias, x_mean2, x_var2, running_mean, running_var],
+                return_disconnected='zero')
             # compile
-            f = theano.function([x, scale, bias, running_mean, running_var, dy],
+            f = theano.function([x, scale, bias, running_mean, running_var, dy, dx, dscale, dbias],
                                 [out, x_mean, x_invstd, out_running_mean, out_running_var,
                                  out2, x_mean2, x_invstd2, out_running_mean2, out_running_var2] +
-                                grads + grads2)
+                                grads + grads2 + grad_grads + grad_grads2)
             # check if the abstract Ops have been replaced
             assert not any([isinstance(n.op, (bn.AbstractBatchNormTrain,
                                               bn.AbstractBatchNormInference,
......
@@ -211,7 +225,11 @@ def test_batch_normalization_train():
                 Bias = np.random.randn(*param_shape).astype(theano.config.floatX)
                 Running_mean = np.random.randn(*param_shape).astype(theano.config.floatX)
                 Running_var = np.random.randn(*param_shape).astype(theano.config.floatX)
-                outputs = f(X, Scale, Bias, Running_mean, Running_var, Dy)
+                Dx = 4 + 3 * np.random.randn(*data_shape).astype(theano.config.floatX)
+                Dscale = -1 + 2 * np.random.randn(*param_shape).astype(theano.config.floatX)
+                Dbias = np.random.randn(*param_shape).astype(theano.config.floatX)
+                outputs = f(X, Scale, Bias, Running_mean, Running_var, Dy, Dx, Dscale, Dbias)
                 # compare outputs
                 utt.assert_allclose(outputs[0], outputs[0 + 5])  # out
                 utt.assert_allclose(outputs[1], outputs[1 + 5])  # mean
......
@@ -223,6 +241,61 @@ def test_batch_normalization_train():
                 utt.assert_allclose(outputs[10], outputs[10 + 3], atol=1e-4)  # dx
                 utt.assert_allclose(outputs[11], outputs[11 + 3], rtol=2e-4, atol=1e-4)  # dscale
                 utt.assert_allclose(outputs[12], outputs[12 + 3])  # dbias
+                # compare second-order gradients
+                utt.assert_allclose(outputs[16], outputs[16 + 3], atol=1e-4)  # ddx
+                utt.assert_allclose(outputs[17], outputs[17 + 3])  # ddy
+                utt.assert_allclose(outputs[18], outputs[18 + 3], rtol=3e-4, atol=1e-4)  # ddscale
+
+
+def test_batch_normalization_train_grad_grad():
+    utt.seed_rng()
+
+    for axes in ('per-activation', 'spatial', (1, 2, 3, 4)):
+        for vartype in (T.tensor5, T.tensor4, T.tensor3, T.matrix, T.vector):
+            # run these experiments with float64 for sufficient numerical stability
+            x, dy, scale, x_mean, x_invstd = (vartype(n, dtype='float64')
+                                              for n in ('x', 'dy', 'scale',
+                                                        'x_mean', 'x_invstd'))
+            ndim = x.ndim
+
+            # resolve the axes specification, removing non-existing axes
+            if axes == 'per-activation':
+                axes = (0,)
+            elif axes == 'spatial':
+                axes = (0,) + tuple(range(2, ndim))
+            else:
+                axes = tuple(i for i in axes if i < ndim)
+                if len(axes) == 0:
+                    continue
+
+            def bn_grad_wrt_inputs_f(x, dy, scale, x_mean, x_invstd):
+                g_inputs, g_scale, g_bias = bn.AbstractBatchNormTrainGrad(axes)(x, dy, scale, x_mean, x_invstd)
+                return g_inputs
+
+            def bn_grad_wrt_scale_f(x, dy, scale, x_mean, x_invstd):
+                g_inputs, g_scale, g_bias = bn.AbstractBatchNormTrainGrad(axes)(x, dy, scale, x_mean, x_invstd)
+                return g_scale
+
+            def bn_grad_wrt_bias_f(x, dy, scale, x_mean, x_invstd):
+                g_inputs, g_scale, g_bias = bn.AbstractBatchNormTrainGrad(axes)(x, dy, scale, x_mean, x_invstd)
+                return g_bias
+
+            # run
+            for data_shape in ((4, 3, 3, 3, 3), (4, 3, 1, 1, 1), (2, 3, 5, 3, 2)):
+                data_shape = data_shape[:ndim]
+                param_shape = tuple(1 if d in axes else s
+                                    for d, s in enumerate(data_shape))
+
+                # force float64 for sufficient numerical stability
+                x_val = 4 + 3 * np.random.randn(*data_shape).astype('float64')
+                dy_val = -1 + 2 * np.random.randn(*data_shape).astype('float64')
+                scale_val = np.random.randn(*param_shape).astype('float64')
+                x_mean_val = np.random.randn(*param_shape).astype('float64')
+                x_invstd_val = np.random.randn(*param_shape).astype('float64')
+
+                utt.verify_grad(bn_grad_wrt_inputs_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val])
+                utt.verify_grad(bn_grad_wrt_scale_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val])
+                utt.verify_grad(bn_grad_wrt_bias_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val])
+
+
 def test_batch_normalization_train_without_running_averages():
......
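
For reference, the quantities being differentiated here are the first-order batch-norm training gradients. A plain NumPy sketch of what `AbstractBatchNormTrainGrad` computes, assuming `x_invstd` is 1/sqrt(var + epsilon) and `axes` are the normalization axes (the function name is illustrative, not part of the commit):

```python
import numpy as np

def bn_train_grad_reference(x, dy, scale, x_mean, x_invstd, axes):
    # Mirrors the three outputs of AbstractBatchNormTrainGrad: the
    # gradients of the loss w.r.t. the inputs, scale and bias, given
    # dy, the gradient w.r.t. the batch-normalized output.
    x_diff = x - x_mean
    mean_dy = np.mean(dy, axis=axes, keepdims=True)
    mean_dy_x_diff = np.mean(dy * x_diff, axis=axes, keepdims=True)
    # gradient w.r.t. the inputs x
    g_inputs = scale * x_invstd * (
        dy - mean_dy - x_diff * (x_invstd ** 2) * mean_dy_x_diff)
    # gradients w.r.t. scale (gamma) and bias (beta)
    g_scale = np.sum(dy * x_diff * x_invstd, axis=axes, keepdims=True)
    g_bias = np.sum(dy, axis=axes, keepdims=True)
    return g_inputs, g_scale, g_bias
```

The `grad` method added above differentiates exactly these three expressions with respect to `x`, `dy`, `scale`, `x_mean` and `x_invstd`.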