Commit 3f57f90b authored by Gijs van Tulder

Accept non-broadcasted inputs for batch normalization.

Parent 9c90f281
...@@ -83,6 +83,24 @@ def batch_normalization(inputs, gamma, beta, mean, std, ...@@ -83,6 +83,24 @@ def batch_normalization(inputs, gamma, beta, mean, std,
return rval return rval
def _prepare_batch_normalization_axes(axes, ndim):
if axes == 'per-activation':
axes = (0,)
elif axes == 'spatial':
axes = (0,) + tuple(range(2, ndim))
elif isinstance(axes, (tuple, list, numpy.ndarray)):
axes = tuple(int(a) for a in axes)
else:
raise ValueError('invalid axes: %s', str(axes))
axes = tuple(sorted(axes))
if len(axes) == 0:
raise ValueError('there should be at least one normalization axis')
if min(axes) < 0 or max(axes) >= ndim:
raise ValueError('axes should be less than ndim (<%d), but %s given' % (ndim, str(axes)))
non_bc_axes = tuple(i for i in range(ndim) if i not in axes)
return axes, non_bc_axes
def batch_normalization_train(inputs, gamma, beta, axes='per-activation', def batch_normalization_train(inputs, gamma, beta, axes='per-activation',
epsilon=1e-4, running_average_factor=0.1, epsilon=1e-4, running_average_factor=0.1,
running_mean=None, running_var=None): running_mean=None, running_var=None):
...@@ -99,10 +117,9 @@ def batch_normalization_train(inputs, gamma, beta, axes='per-activation', ...@@ -99,10 +117,9 @@ def batch_normalization_train(inputs, gamma, beta, axes='per-activation',
(i.e., all dimensions past the second), which for 4D inputs would be (i.e., all dimensions past the second), which for 4D inputs would be
equal to ``axes=(0, 2, 3)``. equal to ``axes=(0, 2, 3)``.
gamma : tensor gamma : tensor
Learnable scale factors. Must match the dimensionality of `inputs`, Learnable scale factors. The shape must match the shape of `inputs`,
but have sizes of `1` for all axes normalized over (i.e., in the first except for the axes in `axes`. These axes should be set to 1 or be
dimension for ``axes='per-activation'``, and additionally in all skipped altogether (such that `gamma.ndim == inputs.ndim - len(axes)`).
dimensions past the second for ``axes='spatial'``).
beta : tensor beta : tensor
Learnable biases. Must match the tensor layout of `gamma`. Learnable biases. Must match the tensor layout of `gamma`.
epsilon : float epsilon : float
...@@ -117,14 +134,14 @@ def batch_normalization_train(inputs, gamma, beta, axes='per-activation', ...@@ -117,14 +134,14 @@ def batch_normalization_train(inputs, gamma, beta, axes='per-activation',
``running_mean * (1 - r_a_factor) + batch mean * r_a_factor`` ``running_mean * (1 - r_a_factor) + batch mean * r_a_factor``
will be returned as one of the outputs of this function. will be returned as one of the outputs of this function.
`running_mean` and `running_var` should either both be given or `running_mean` and `running_var` should either both be given or
both be None. both be None. The shape should match that of `gamma` and `beta`.
running_var : tensor or None running_var : tensor or None
Previous value of the running variance. If this is given, the new value Previous value of the running variance. If this is given, the new value
``running_var * (1 - r_a_factor) + (m / (m - 1)) * batch var * r_a_factor`` ``running_var * (1 - r_a_factor) + (m / (m - 1)) * batch var * r_a_factor``
will be returned as one of the outputs of this function, will be returned as one of the outputs of this function,
where `m` is the product of lengths of the averaged-over dimensions. where `m` is the product of lengths of the averaged-over dimensions.
`running_mean` and `running_var` should either both be given or `running_mean` and `running_var` should either both be given or
both be None. both be None. The shape should match that of `gamma` and `beta`.
Returns Returns
------- -------
...@@ -166,62 +183,76 @@ def batch_normalization_train(inputs, gamma, beta, axes='per-activation', ...@@ -166,62 +183,76 @@ def batch_normalization_train(inputs, gamma, beta, axes='per-activation',
(m / (m - 1)) * var * running_average_factor (m / (m - 1)) * var * running_average_factor
""" """
ndim = inputs.ndim ndim = inputs.ndim
if gamma.ndim != ndim or beta.ndim != ndim: axes, non_bc_axes = _prepare_batch_normalization_axes(axes, ndim)
raise ValueError("gamma and beta must be of the same dimensionality "
"as inputs; got %d and %d instead of %d" % # have the parameter tensors been broadcasted yet?
(gamma.ndim, beta.ndim, ndim)) if gamma.ndim == ndim:
params_ndim = ndim
else:
params_ndim = len(non_bc_axes)
params_dimshuffle_pattern = ['x'] * ndim
for i, axis in enumerate(non_bc_axes):
params_dimshuffle_pattern[axis] = i
if gamma.ndim != params_ndim or beta.ndim != params_ndim:
raise ValueError("gamma and beta dimensionality must match the "
"number of non-normalized axes, or have the "
"same number of dimensions as the inputs; "
"got %d and %d instead of %d" %
(gamma.ndim, beta.ndim, params_ndim))
if (running_mean is None) != (running_var is None): if (running_mean is None) != (running_var is None):
raise ValueError("running_mean and running_var must either both be " raise ValueError("running_mean and running_var must either both be "
"given or both be None") "given or both be None")
if running_mean is not None and running_mean.ndim != ndim: if running_mean is not None and running_mean.ndim != params_ndim:
raise ValueError("running_mean must be of the same dimensionality " raise ValueError("running_mean must be of the same dimensionality "
"as inputs; got %d instead of %d" % "as gamma and beta; got %d instead of %d" %
(running_mean.ndim, ndim)) (running_mean.ndim, params_ndim))
if running_var is not None and running_var.ndim != ndim: if running_var is not None and running_var.ndim != params_ndim:
raise ValueError("running_var must be of the same dimensionality " raise ValueError("running_var must be of the same dimensionality "
"as inputs; got %d instead of %d" % "as gamma and beta; got %d instead of %d" %
(running_var.ndim, ndim)) (running_var.ndim, params_ndim))
if epsilon < 1e-5: if epsilon < 1e-5:
raise ValueError("epsilon must be at least 1e-5, got %f" % epsilon) raise ValueError("epsilon must be at least 1e-5, got %f" % epsilon)
if axes == 'per-activation':
axes = (0,)
elif axes == 'spatial':
axes = (0,) + tuple(range(2, inputs.ndim))
elif isinstance(axes, (tuple, list, numpy.ndarray)):
axes = tuple(int(a) for a in axes)
else:
raise ValueError('invalid axes: %s', str(axes))
if len(axes) == 0:
raise ValueError('there should be at least one normalization axis')
if min(axes) < 0 or max(axes) >= ndim:
raise ValueError('axes should be less than ndim (<%d), but %s given' % (ndim, str(axes)))
inputs = as_tensor_variable(inputs) inputs = as_tensor_variable(inputs)
gamma = as_tensor_variable(gamma) gamma = as_tensor_variable(gamma)
beta = as_tensor_variable(beta) beta = as_tensor_variable(beta)
gamma = T.addbroadcast(gamma, *axes) if params_ndim != ndim:
beta = T.addbroadcast(beta, *axes) gamma = gamma.dimshuffle(params_dimshuffle_pattern)
beta = beta.dimshuffle(params_dimshuffle_pattern)
else:
gamma = T.addbroadcast(gamma, *axes)
beta = T.addbroadcast(beta, *axes)
batchnorm_op = AbstractBatchNormTrain(axes=axes) batchnorm_op = AbstractBatchNormTrain(axes=axes)
if running_mean is not None and running_var is not None: if running_mean is not None and running_var is not None:
running_mean = as_tensor_variable(running_mean) running_mean = as_tensor_variable(running_mean)
running_var = as_tensor_variable(running_var) running_var = as_tensor_variable(running_var)
running_mean_bc = T.addbroadcast(running_mean, *axes) if params_ndim != ndim:
running_var_bc = T.addbroadcast(running_var, *axes) running_mean = running_mean.dimshuffle(params_dimshuffle_pattern)
running_var = running_var.dimshuffle(params_dimshuffle_pattern)
else:
running_mean = T.addbroadcast(running_mean, *axes)
running_var = T.addbroadcast(running_var, *axes)
out, mean, invstd, new_running_mean, new_running_var = batchnorm_op( out, mean, invstd, new_running_mean, new_running_var = batchnorm_op(
inputs, gamma, beta, epsilon=epsilon, inputs, gamma, beta, epsilon=epsilon,
running_average_factor=running_average_factor, running_average_factor=running_average_factor,
running_mean=running_mean_bc, running_var=running_var_bc) running_mean=running_mean, running_var=running_var)
if new_running_mean.broadcastable != running_mean.broadcastable: if new_running_mean.broadcastable != running_mean.broadcastable:
new_running_mean = T.patternbroadcast(new_running_mean, running_mean.broadcastable) new_running_mean = T.patternbroadcast(new_running_mean, running_mean.broadcastable)
if new_running_var.broadcastable != running_var.broadcastable: if new_running_var.broadcastable != running_var.broadcastable:
new_running_var = T.patternbroadcast(new_running_var, running_var.broadcastable) new_running_var = T.patternbroadcast(new_running_var, running_var.broadcastable)
return out, mean, invstd, new_running_mean, new_running_var results = (out, mean, invstd, new_running_mean, new_running_var)
else: else:
return tuple(batchnorm_op(inputs, gamma, beta, epsilon=epsilon)) results = batchnorm_op(inputs, gamma, beta, epsilon=epsilon)
if params_ndim != ndim:
# remove the broadcasted dimensions (except from the output)
results = ([results[0]] +
[r.dimshuffle(non_bc_axes) for r in results[1:]])
return tuple(results)
def batch_normalization_test(inputs, gamma, beta, mean, var, def batch_normalization_test(inputs, gamma, beta, mean, var,
...@@ -239,10 +270,9 @@ def batch_normalization_test(inputs, gamma, beta, mean, var, ...@@ -239,10 +270,9 @@ def batch_normalization_test(inputs, gamma, beta, mean, var,
(i.e., all dimensions past the second), which for 4D inputs would be (i.e., all dimensions past the second), which for 4D inputs would be
equal to ``axes=(0, 2, 3)``. equal to ``axes=(0, 2, 3)``.
gamma : tensor gamma : tensor
Scale factors. Must match the dimensionality of `inputs`, but have Scale factors. The shape must match the shape of `inputs`,
sizes of `1` for all axes normalized over (i.e., in the first dimension except for the axes in `axes`. These axes should be set to 1 or be
for ``axes='per-activation'``, and additionally in all dimensions past skipped altogether (such that `gamma.ndim == inputs.ndim - len(axes)`).
the second for ``axes='spatial'``).
beta : tensor beta : tensor
Biases. Must match the tensor layout of `gamma`. Biases. Must match the tensor layout of `gamma`.
mean : tensor mean : tensor
...@@ -278,39 +308,45 @@ def batch_normalization_test(inputs, gamma, beta, mean, var, ...@@ -278,39 +308,45 @@ def batch_normalization_test(inputs, gamma, beta, mean, var,
out = (inputs - mean) * gamma / T.sqrt(var + epsilon) + beta out = (inputs - mean) * gamma / T.sqrt(var + epsilon) + beta
""" """
ndim = inputs.ndim ndim = inputs.ndim
if gamma.ndim != ndim or beta.ndim != ndim: axes, non_bc_axes = _prepare_batch_normalization_axes(axes, ndim)
raise ValueError("gamma and beta must be of the same dimensionality "
"as inputs; got %d and %d instead of %d" % # have the parameter tensors been broadcasted yet?
(gamma.ndim, beta.ndim, ndim)) if gamma.ndim == ndim:
if mean.ndim != ndim or var.ndim != ndim: params_ndim = ndim
else:
params_ndim = len(non_bc_axes)
params_dimshuffle_pattern = ['x'] * ndim
for i, axis in enumerate(non_bc_axes):
params_dimshuffle_pattern[axis] = i
if gamma.ndim != params_ndim or beta.ndim != params_ndim:
raise ValueError("gamma and beta dimensionality must match the "
"number of non-normalized axes, or have the "
"same number of dimensions as the inputs; "
"got %d and %d instead of %d" %
(gamma.ndim, beta.ndim, params_ndim))
if mean.ndim != params_ndim or var.ndim != params_ndim:
raise ValueError("mean and var must be of the same dimensionality " raise ValueError("mean and var must be of the same dimensionality "
"as inputs; got %d and %d instead of %d" % "as gamma and beta; got %d and %d instead of %d" %
(mean.ndim, var.ndim, ndim)) (mean.ndim, var.ndim, params_ndim))
if epsilon < 1e-5: if epsilon < 1e-5:
raise ValueError("epsilon must be at least 1e-5, got %f" % epsilon) raise ValueError("epsilon must be at least 1e-5, got %f" % epsilon)
if axes == 'per-activation':
axes = (0,)
elif axes == 'spatial':
axes = (0,) + tuple(range(2, inputs.ndim))
elif isinstance(axes, (tuple, list, numpy.ndarray)):
axes = tuple(int(a) for a in axes)
else:
raise ValueError('invalid axes: %s', str(axes))
if len(axes) == 0:
raise ValueError('there should be at least one normalization axis')
if min(axes) < 0 or max(axes) >= ndim:
raise ValueError('axes should be less than ndim (<%d), but %s given' % (ndim, str(axes)))
gamma = as_tensor_variable(gamma) gamma = as_tensor_variable(gamma)
beta = as_tensor_variable(beta) beta = as_tensor_variable(beta)
mean = as_tensor_variable(mean) mean = as_tensor_variable(mean)
var = as_tensor_variable(var) var = as_tensor_variable(var)
gamma = T.addbroadcast(gamma, *axes) if params_ndim != ndim:
beta = T.addbroadcast(beta, *axes) gamma = gamma.dimshuffle(params_dimshuffle_pattern)
mean = T.addbroadcast(mean, *axes) beta = beta.dimshuffle(params_dimshuffle_pattern)
var = T.addbroadcast(var, *axes) mean = mean.dimshuffle(params_dimshuffle_pattern)
var = var.dimshuffle(params_dimshuffle_pattern)
else:
gamma = T.addbroadcast(gamma, *axes)
beta = T.addbroadcast(beta, *axes)
mean = T.addbroadcast(mean, *axes)
var = T.addbroadcast(var, *axes)
batchnorm_op = AbstractBatchNormInference(axes=axes) batchnorm_op = AbstractBatchNormInference(axes=axes)
return batchnorm_op(inputs, gamma, beta, mean, var, epsilon=epsilon) return batchnorm_op(inputs, gamma, beta, mean, var, epsilon=epsilon)
......
...@@ -252,6 +252,86 @@ def test_batch_normalization_train_without_running_averages(): ...@@ -252,6 +252,86 @@ def test_batch_normalization_train_without_running_averages():
f(X, Scale, Bias, Dy) f(X, Scale, Bias, Dy)
def test_batch_normalization_train_broadcast():
    """Check that batch normalization accepts non-broadcasted parameters.

    For each axes spec and input rank, builds two identical computation
    graphs -- one passing the parameter tensors (scale, bias, running
    mean/var) in compact form (ndim == inputs.ndim - len(axes)), one
    passing them pre-broadcasted to inputs.ndim -- and asserts that the
    optimizer reduces the sum of their absolute differences to a constant
    zero, i.e. both call styles produce equivalent graphs.
    """
    for axes in ('per-activation', 'spatial', (1, 2, 3, 4)):
        for vartype in (T.tensor5, T.tensor4, T.tensor3, T.matrix, T.vector):
            x = vartype('x')
            ndim = x.ndim
            eps = 5e-3  # some non-standard value to test if it's used
            running_average_factor = 0.3
            # remove non-existing axes
            if isinstance(axes, tuple):
                axes = tuple(i for i in axes if i < ndim)
            if len(axes) == 0:
                continue
            # convert axes to explicit list
            if axes == 'per-activation':
                axes2 = (0,)
            elif axes == 'spatial':
                axes2 = (0,) + tuple(range(2, ndim))
            else:
                axes2 = axes
            # compute axes for parameter tensors
            non_bc_axes = tuple(i for i in range(ndim) if i not in axes2)
            params_dimshuffle = ['x'] * ndim
            for i, axis in enumerate(non_bc_axes):
                params_dimshuffle[axis] = i
            # construct non-broadcasted parameter variables
            # (one dimension per non-normalized axis, none broadcastable)
            param_type = T.TensorType(x.dtype, (False,) * len(non_bc_axes))
            scale, bias, running_mean, running_var = (param_type(n)
                                                      for n in ('scale', 'bias',
                                                                'running_mean',
                                                                'running_var'))
            # broadcast parameter variables
            scale_bc = scale.dimshuffle(params_dimshuffle)
            bias_bc = bias.dimshuffle(params_dimshuffle)
            running_mean_bc = running_mean.dimshuffle(params_dimshuffle)
            running_var_bc = running_var.dimshuffle(params_dimshuffle)
            # batch_normalization_train with original, non-broadcasted variables
            train_non_bc = \
                bn.batch_normalization_train(
                    x, scale, bias, axes, eps,
                    running_average_factor, running_mean, running_var)
            # batch_normalization_train with broadcasted variables
            train_bc = \
                bn.batch_normalization_train(
                    x, scale_bc, bias_bc, axes, eps,
                    running_average_factor, running_mean_bc, running_var_bc)
            # collapse the broadcasted outputs back to compact form so they
            # are shape-compatible with the non-broadcasted call's outputs
            train_bc = tuple([train_bc[0]] +  # out
                             [r.dimshuffle(non_bc_axes) for r in train_bc[1:]])
            # batch_normalization_test with original, non-broadcasted variables
            test_non_bc = \
                bn.batch_normalization_test(
                    x, scale, bias, running_mean, running_var, axes, eps)
            # batch_normalization_test with broadcasted variables
            test_bc = \
                bn.batch_normalization_test(
                    x, scale_bc, bias_bc, running_mean_bc, running_var_bc, axes, eps)
            # subtract the results of the non-broadcasted and broadcasted calls
            results_non_bc = train_non_bc + (test_non_bc,)
            results_bc = train_bc + (test_bc,)
            results = [abs(r - r_bc) for (r, r_bc) in zip(results_non_bc, results_bc)]
            # compile to compute all differences
            f = theano.function([x, scale, bias, running_mean, running_var],
                                T.sum(sum(results)), mode='FAST_RUN')
            # the paired ops are exactly the same, so the optimizer should have
            # collapsed the sum of differences to a constant zero
            nodes = f.maker.fgraph.toposort()
            assert len(nodes) == 1
            assert isinstance(nodes[0].op, theano.compile.DeepCopyOp)
            assert 0.0 == theano.tensor.get_scalar_constant_value(nodes[0].inputs[0])
def test_batch_normalization_test(): def test_batch_normalization_test():
for axes in ('per-activation', 'spatial', (1, 2, 3, 4)): for axes in ('per-activation', 'spatial', (1, 2, 3, 4)):
for vartype in (T.tensor5, T.tensor4, T.tensor3, T.matrix, T.vector): for vartype in (T.tensor5, T.tensor4, T.tensor3, T.matrix, T.vector):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论