提交 a74c1b06 authored 作者: Nicolas Ballas's avatar Nicolas Ballas

add batch normalization high-mem implementation

上级 ff7eb105
......@@ -26,13 +26,11 @@ class BNComposite(Composite):
return [dx, dmean, dstd, dgamma, top]
def batch_normalization(inputs, gamma, beta, mean, std,
                        mode='low_mem'):
    """
    Build the symbolic graph for applying batch normalization
    to a set of activations.

    Also works on GPU.

    Parameters
    ----------
    inputs : symbolic tensor
        Mini-batch of activations to normalize.
    gamma : symbolic tensor
        Scale parameter; presumably broadcastable against `inputs` —
        confirm against callers.
    beta : symbolic tensor
        Shift parameter; presumably broadcastable against `inputs` —
        confirm against callers.
    mean : symbolic tensor
        Input means; presumably broadcastable against `inputs` —
        confirm against callers.
    std : symbolic tensor
        Input standard deviation, must be of same dimensionality as
        inputs and broadcastable against it.
    mode : {'low_mem', 'high_mem'}
        Specify which batch_normalization implementation will be used.
        As no intermediate representations are stored for the
        back-propagation, the 'low_mem' implementation lowers the memory
        usage; however, it is 5-10% slower than the 'high_mem'
        implementation, as it redoes some forward computations for the
        backprop.

    Returns
    -------
    symbolic tensor
        The normalized activations: (inputs - mean) / std * gamma + beta.

    Raises
    ------
    ValueError
        If `mode` is neither 'low_mem' nor 'high_mem'.
    """
    if mode == 'low_mem':
        # Fused elementwise op: intermediates are not stored, so the
        # backward pass recomputes parts of the forward computation.
        elm_bn = theano.tensor.elemwise.Elemwise(
            scalar_op=BNComposite(dtype=inputs.dtype))
        rval = elm_bn(inputs, mean, std, gamma, beta)
    elif mode == 'high_mem':
        # Plain graph: intermediate results are kept for the backprop,
        # trading memory for speed.
        rval = (inputs - mean) / std
        rval = rval * gamma + beta
    else:
        # Fixed message: proper "either ... or" grammar and report the
        # offending value to ease debugging.
        raise ValueError(
            'mode must be either "low_mem" or "high_mem", got %r' % (mode,))
    return rval
......@@ -24,24 +24,32 @@ def test_bn():
m = theano.tensor.vector('m')
v = theano.tensor.vector('v')
bn_op = batch_normalization(x, g, b, m, v)
bn_ref_op = bn_ref(x, g, b, m, v)
f = theano.function([x, b, g, m, v], [bn_op])
f_ref = theano.function([x, b, g, m, v], [bn_ref_op])
res = f(X, G, B, M, V)
res_ref = f_ref(X, G, B, M, V)
utt.assert_allclose(res_ref, res)
utt.verify_grad(batch_normalization, [X, G, B, M, V])
for mode in ['low_mem', 'high_mem']:
bn_op = batch_normalization(x, g, b, m, v, mode=mode)
f = theano.function([x, b, g, m, v], [bn_op])
res = f(X, G, B, M, V)
utt.assert_allclose(res_ref, res)
def bn(inputs, gamma, beta, mean, std):
return batch_normalization(inputs, gamma, beta, mean, std, mode=mode)
utt.verify_grad(bn, [X, G, B, M, V])
bn_op = batch_normalization(x, g, b, x.mean(axis=0, keepdims=True), x.std(axis=0, keepdims=True))
bn_ref_op = bn_ref(x, g, b, x.mean(axis=0, keepdims=True), x.std(axis=0, keepdims=True))
f = theano.function([x, b, g], [bn_op])
f_ref = theano.function([x, b, g], [bn_ref_op])
res = f(X, G, B)
res_ref = f_ref(X, G, B)
utt.assert_allclose(res_ref, res)
utt.verify_grad(batch_normalization, [X, G, B,
X.mean(axis=0)[numpy.newaxis], X.std(axis=0)[numpy.newaxis]])
for mode in ['low_mem', 'high_mem']:
bn_op = batch_normalization(x, g, b, x.mean(axis=0, keepdims=True), x.std(axis=0, keepdims=True))
f = theano.function([x, b, g], [bn_op])
res = f(X, G, B)
utt.assert_allclose(res_ref, res)
def bn(inputs, gamma, beta, mean, std):
return batch_normalization(inputs, gamma, beta, mean, std, mode=mode)
utt.verify_grad(batch_normalization, [X, G, B,
X.mean(axis=0)[numpy.newaxis], X.std(axis=0)[numpy.newaxis]])
def test_bn_feature_maps():
......@@ -63,26 +71,30 @@ def test_bn_feature_maps():
m = theano.tensor.vector('m')
v = theano.tensor.vector('v')
bn_op = batch_normalization(x,
g.dimshuffle('x', 0, 'x', 'x'),
b.dimshuffle('x', 0, 'x', 'x'),
m.dimshuffle('x', 0, 'x', 'x'),
v.dimshuffle('x', 0, 'x', 'x'))
bn_ref_op = bn_ref(x,
g.dimshuffle('x', 0, 'x', 'x'),
b.dimshuffle('x', 0, 'x', 'x'),
m.dimshuffle('x', 0, 'x', 'x'),
v.dimshuffle('x', 0, 'x', 'x'))
f = theano.function([x, b, g, m, v], [bn_op])
f_ref = theano.function([x, b, g, m, v], [bn_ref_op])
res = f(X, G, B, M, V)
res_ref = f_ref(X, G, B, M, V)
utt.assert_allclose(res_ref, res)
def conv_bn(inputs, gamma, beta, mean, std):
return batch_normalization(inputs,
gamma.dimshuffle('x', 0, 'x', 'x'),
beta.dimshuffle('x', 0, 'x', 'x'),
mean.dimshuffle('x', 0, 'x', 'x'),
std.dimshuffle('x', 0, 'x', 'x'))
utt.verify_grad(conv_bn, [X, G, B, M, V])
for mode in ['low_mem', 'high_mem']:
bn_op = batch_normalization(x,
g.dimshuffle('x', 0, 'x', 'x'),
b.dimshuffle('x', 0, 'x', 'x'),
m.dimshuffle('x', 0, 'x', 'x'),
v.dimshuffle('x', 0, 'x', 'x'),
mode=mode)
f = theano.function([x, b, g, m, v], [bn_op])
res = f(X, G, B, M, V)
utt.assert_allclose(res_ref, res)
def conv_bn(inputs, gamma, beta, mean, std):
return batch_normalization(inputs,
gamma.dimshuffle('x', 0, 'x', 'x'),
beta.dimshuffle('x', 0, 'x', 'x'),
mean.dimshuffle('x', 0, 'x', 'x'),
std.dimshuffle('x', 0, 'x', 'x'),
mode=mode)
utt.verify_grad(conv_bn, [X, G, B, M, V])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论