提交 a74c1b06 authored 作者: Nicolas Ballas's avatar Nicolas Ballas

add batch normalization high-mem implementation

上级 ff7eb105
...@@ -26,13 +26,11 @@ class BNComposite(Composite): ...@@ -26,13 +26,11 @@ class BNComposite(Composite):
return [dx, dmean, dstd, dgamma, top] return [dx, dmean, dstd, dgamma, top]
def batch_normalization(inputs, gamma, beta, mean, std, mode='low_mem'):
    """
    Build the symbolic graph for applying batch normalization
    to a set of activations.

    Also works on GPU.

    Parameters
    ----------
    inputs : symbolic tensor
        Mini-batch of activations to normalize.
    gamma : symbolic tensor
        BN scale parameter; must be broadcastable against `inputs`.
    beta : symbolic tensor
        BN shift parameter; must be broadcastable against `inputs`.
    mean : symbolic tensor
        inputs mean, must be of same dimensionality as inputs and
        broadcastable against it.
    std : symbolic tensor
        inputs standard deviation, must be of same dimensionality as
        inputs and broadcastable against it.
    mode : {'low_mem', 'high_mem'}
        Specify which batch_normalization implementation will be used.
        As no intermediate representations are stored for the
        back-propagation, the 'low_mem' implementation lowers the memory
        usage; however, it is 5-10% slower than the 'high_mem'
        implementation, as it redoes some forward computations for the
        backprop.

    Returns
    -------
    symbolic tensor
        The batch-normalized activations
        ``(inputs - mean) / std * gamma + beta``.

    Raises
    ------
    ValueError
        If `mode` is neither ``'low_mem'`` nor ``'high_mem'``.
    """
    if mode == 'low_mem':
        # Fused scalar composite: the whole normalization is one elemwise
        # op, so Theano stores no intermediates and recomputes the forward
        # pass during backprop (less memory, slightly slower).
        elm_bn = theano.tensor.elemwise.Elemwise(
            scalar_op=BNComposite(dtype=inputs.dtype))
        rval = elm_bn(inputs, mean, std, gamma, beta)
    elif mode == 'high_mem':
        # Plain symbolic expression: intermediates are kept for the
        # gradient, trading memory for backprop speed.
        rval = (inputs - mean) / std
        rval = rval * gamma + beta
    else:
        raise ValueError('mode must be either "low_mem" or "high_mem"')
    return rval
...@@ -24,24 +24,32 @@ def test_bn(): ...@@ -24,24 +24,32 @@ def test_bn():
m = theano.tensor.vector('m') m = theano.tensor.vector('m')
v = theano.tensor.vector('v') v = theano.tensor.vector('v')
bn_op = batch_normalization(x, g, b, m, v)
bn_ref_op = bn_ref(x, g, b, m, v) bn_ref_op = bn_ref(x, g, b, m, v)
f = theano.function([x, b, g, m, v], [bn_op])
f_ref = theano.function([x, b, g, m, v], [bn_ref_op]) f_ref = theano.function([x, b, g, m, v], [bn_ref_op])
res = f(X, G, B, M, V)
res_ref = f_ref(X, G, B, M, V) res_ref = f_ref(X, G, B, M, V)
utt.assert_allclose(res_ref, res) for mode in ['low_mem', 'high_mem']:
utt.verify_grad(batch_normalization, [X, G, B, M, V]) bn_op = batch_normalization(x, g, b, m, v, mode=mode)
f = theano.function([x, b, g, m, v], [bn_op])
res = f(X, G, B, M, V)
utt.assert_allclose(res_ref, res)
def bn(inputs, gamma, beta, mean, std):
return batch_normalization(inputs, gamma, beta, mean, std, mode=mode)
utt.verify_grad(bn, [X, G, B, M, V])
bn_op = batch_normalization(x, g, b, x.mean(axis=0, keepdims=True), x.std(axis=0, keepdims=True))
bn_ref_op = bn_ref(x, g, b, x.mean(axis=0, keepdims=True), x.std(axis=0, keepdims=True)) bn_ref_op = bn_ref(x, g, b, x.mean(axis=0, keepdims=True), x.std(axis=0, keepdims=True))
f = theano.function([x, b, g], [bn_op])
f_ref = theano.function([x, b, g], [bn_ref_op]) f_ref = theano.function([x, b, g], [bn_ref_op])
res = f(X, G, B)
res_ref = f_ref(X, G, B) res_ref = f_ref(X, G, B)
utt.assert_allclose(res_ref, res) for mode in ['low_mem', 'high_mem']:
utt.verify_grad(batch_normalization, [X, G, B, bn_op = batch_normalization(x, g, b, x.mean(axis=0, keepdims=True), x.std(axis=0, keepdims=True))
X.mean(axis=0)[numpy.newaxis], X.std(axis=0)[numpy.newaxis]]) f = theano.function([x, b, g], [bn_op])
res = f(X, G, B)
utt.assert_allclose(res_ref, res)
def bn(inputs, gamma, beta, mean, std):
return batch_normalization(inputs, gamma, beta, mean, std, mode=mode)
utt.verify_grad(batch_normalization, [X, G, B,
X.mean(axis=0)[numpy.newaxis], X.std(axis=0)[numpy.newaxis]])
def test_bn_feature_maps(): def test_bn_feature_maps():
...@@ -63,26 +71,30 @@ def test_bn_feature_maps(): ...@@ -63,26 +71,30 @@ def test_bn_feature_maps():
m = theano.tensor.vector('m') m = theano.tensor.vector('m')
v = theano.tensor.vector('v') v = theano.tensor.vector('v')
bn_op = batch_normalization(x,
g.dimshuffle('x', 0, 'x', 'x'),
b.dimshuffle('x', 0, 'x', 'x'),
m.dimshuffle('x', 0, 'x', 'x'),
v.dimshuffle('x', 0, 'x', 'x'))
bn_ref_op = bn_ref(x, bn_ref_op = bn_ref(x,
g.dimshuffle('x', 0, 'x', 'x'), g.dimshuffle('x', 0, 'x', 'x'),
b.dimshuffle('x', 0, 'x', 'x'), b.dimshuffle('x', 0, 'x', 'x'),
m.dimshuffle('x', 0, 'x', 'x'), m.dimshuffle('x', 0, 'x', 'x'),
v.dimshuffle('x', 0, 'x', 'x')) v.dimshuffle('x', 0, 'x', 'x'))
f = theano.function([x, b, g, m, v], [bn_op])
f_ref = theano.function([x, b, g, m, v], [bn_ref_op]) f_ref = theano.function([x, b, g, m, v], [bn_ref_op])
res = f(X, G, B, M, V)
res_ref = f_ref(X, G, B, M, V) res_ref = f_ref(X, G, B, M, V)
utt.assert_allclose(res_ref, res)
for mode in ['low_mem', 'high_mem']:
def conv_bn(inputs, gamma, beta, mean, std): bn_op = batch_normalization(x,
return batch_normalization(inputs, g.dimshuffle('x', 0, 'x', 'x'),
gamma.dimshuffle('x', 0, 'x', 'x'), b.dimshuffle('x', 0, 'x', 'x'),
beta.dimshuffle('x', 0, 'x', 'x'), m.dimshuffle('x', 0, 'x', 'x'),
mean.dimshuffle('x', 0, 'x', 'x'), v.dimshuffle('x', 0, 'x', 'x'),
std.dimshuffle('x', 0, 'x', 'x')) mode=mode)
utt.verify_grad(conv_bn, [X, G, B, M, V]) f = theano.function([x, b, g, m, v], [bn_op])
res = f(X, G, B, M, V)
utt.assert_allclose(res_ref, res)
def conv_bn(inputs, gamma, beta, mean, std):
return batch_normalization(inputs,
gamma.dimshuffle('x', 0, 'x', 'x'),
beta.dimshuffle('x', 0, 'x', 'x'),
mean.dimshuffle('x', 0, 'x', 'x'),
std.dimshuffle('x', 0, 'x', 'x'),
mode=mode)
utt.verify_grad(conv_bn, [X, G, B, M, V])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论