Commit 9ad04124 authored by Gijs van Tulder

Abstract batchnorm in pure Python and pure Theano.

Parent 3343d912
from __future__ import absolute_import, print_function, division
import numpy
import theano
from theano import Apply, Op
from theano.gof import local_optimizer
from theano.tensor import as_tensor_variable, TensorType
from theano.tensor import basic as T
from theano.tensor.opt import register_specialize_device
from theano.scalar import Composite
from theano.scalar import add, sub, true_div, mul
......@@ -75,3 +81,432 @@ def batch_normalization(inputs, gamma, beta, mean, std,
raise ValueError(
'mode must be either "low_mem", "high_mem"')
return rval
def batch_normalization_train(inputs, gamma, beta, axes='per-activation',
                              epsilon=1e-4):
    """
    Performs batch normalization of the given inputs, using the mean and
    variance of the inputs.

    Parameters
    ----------
    axes : 'per-activation', 'spatial' or a tuple of ints
        The axes along which the input should be normalized. ``'per-activation'``
        normalizes per activation and is equal to ``axes=(0,)``.
        ``'spatial'`` shares normalization factors across spatial dimensions
        (i.e., all dimensions past the second), which for 4D inputs would be
        equal to ``axes=(0,2,3)``.
    gamma : tensor
        Learnable scale factors. Must match the dimensionality of `inputs`,
        but have sizes of `1` for all axes normalized over (i.e., in the first
        dimension for ``mode='per-activation'``, and additionally in all
        dimensions past the second for ``mode='spatial'``).
    beta : tensor
        Learnable biases. Must match the tensor layout of `gamma`.
    epsilon : float
        Epsilon value used in the batch normalization formula. Minimum allowed
        value is 1e-5 (imposed by cuDNN).

    Returns
    -------
    out : tensor
        Batch-normalized inputs.
    mean : tensor
        Means of `inputs` across the normalization axes.
    stdinv : tensor
        Inverse standard deviations of `inputs` across the normalization axes.

    Raises
    ------
    ValueError
        If the dimensionalities do not match, `epsilon` is too small, or
        `axes` is invalid or out of range.

    Notes
    -----
    Requires cuDNN 5 and Theano 0.9dev2 or more recent.

    For 4d tensors, returned values are equivalent to:

    .. code-block:: python

        # for 'per-activation'
        axes = (0,)
        # for 'spatial'
        axes = (0, 2, 3)
        mean = inputs.mean(axes, keepdims=True)
        stdinv = T.inv(T.sqrt(inputs.var(axes, keepdims=True) + epsilon))
        out = (inputs - mean) * gamma * stdinv + beta

    For 5d tensors, the axes are (0, 2, 3, 4).
    """
    ndim = inputs.ndim
    if gamma.ndim != ndim or beta.ndim != ndim:
        raise ValueError("gamma and beta must be of the same dimensionality "
                         "as inputs; got %d and %d instead of %d" %
                         (gamma.ndim, beta.ndim, ndim))
    if epsilon < 1e-5:
        raise ValueError("epsilon must be at least 1e-5, got %f" % epsilon)

    # Resolve the symbolic axis specifications to a concrete tuple of ints.
    if axes == 'per-activation':
        axes = (0,)
    elif axes == 'spatial':
        axes = (0,) + tuple(range(2, inputs.ndim))
    elif isinstance(axes, (tuple, list, numpy.ndarray)):
        axes = tuple(int(a) for a in axes)
    else:
        # Bug fix: the format string was previously joined to its argument
        # with a comma, raising ValueError(('invalid axes: %s', '...'))
        # instead of the interpolated message.
        raise ValueError('invalid axes: %s' % str(axes))
    if len(axes) == 0:
        raise ValueError('there should be at least one normalization axis')
    if min(axes) < 0 or max(axes) >= ndim:
        raise ValueError('axes should be less than ndim (<%d), but %s given' % (ndim, str(axes)))

    inputs = as_tensor_variable(inputs)
    gamma = as_tensor_variable(gamma)
    beta = as_tensor_variable(beta)
    # Mark the normalized axes broadcastable so the elementwise formula
    # broadcasts the per-feature parameters over them.
    gamma = T.addbroadcast(gamma, *axes)
    beta = T.addbroadcast(beta, *axes)

    batchnorm_op = AbstractBatchNormTrain(axes=axes)
    return tuple(batchnorm_op(inputs, gamma, beta, epsilon=epsilon))
def batch_normalization_test(inputs, gamma, beta, mean, var,
                             axes='per-activation', epsilon=1e-4):
    """
    Performs batch normalization of the given inputs, using the given mean and
    variance.

    Parameters
    ----------
    axes : 'per-activation', 'spatial' or a tuple of ints
        The axes along which the input should be normalized. ``'per-activation'``
        normalizes per activation and is equal to ``axes=(0,)``.
        ``'spatial'`` shares normalization factors across spatial dimensions
        (i.e., all dimensions past the second), which for 4D inputs would be
        equal to ``axes=(0,2,3)``.
    gamma : tensor
        Scale factors. Must match the dimensionality of `inputs`, but have
        sizes of `1` for all axes normalized over (i.e., in the first dimension
        for ``mode='per-activation'``, and additionally in all dimensions past
        the second for ``mode='spatial'``).
    beta : tensor
        Biases. Must match the tensor layout of `gamma`.
    mean : tensor
        Means. Usually these are running averages computed during training.
        Must match the tensor layout of `gamma`.
    var : tensor
        Variances. Usually these are running averages computed during training.
        Must match the tensor layout of `gamma`.
    epsilon : float
        Epsilon value used in the batch normalization formula. Minimum allowed
        value is 1e-5 (imposed by cuDNN).

    Returns
    -------
    out : tensor
        Batch-normalized inputs.

    Raises
    ------
    ValueError
        If the dimensionalities do not match, `epsilon` is too small, or
        `axes` is invalid or out of range.

    Notes
    -----
    This operation will use the cuDNN implementation if this is available.
    (Requires cuDNN 5 or newer.)

    For 4d tensors, the returned value is equivalent to:

    .. code-block:: python

        # for 'per-activation'
        axes = (0,)
        # for 'spatial'
        axes = (0, 2, 3)
        gamma, beta, mean, var = (T.addbroadcast(t, *axes)
                                  for t in (gamma, beta, mean, var))
        out = (inputs - mean) * gamma / T.sqrt(var + epsilon) + beta

    For 5d tensors, the axes would be (0, 2, 3, 4).
    """
    ndim = inputs.ndim
    if gamma.ndim != ndim or beta.ndim != ndim:
        raise ValueError("gamma and beta must be of the same dimensionality "
                         "as inputs; got %d and %d instead of %d" %
                         (gamma.ndim, beta.ndim, ndim))
    if mean.ndim != ndim or var.ndim != ndim:
        raise ValueError("mean and var must be of the same dimensionality "
                         "as inputs; got %d and %d instead of %d" %
                         (mean.ndim, var.ndim, ndim))
    if epsilon < 1e-5:
        raise ValueError("epsilon must be at least 1e-5, got %f" % epsilon)

    # Resolve the symbolic axis specifications to a concrete tuple of ints.
    if axes == 'per-activation':
        axes = (0,)
    elif axes == 'spatial':
        axes = (0,) + tuple(range(2, inputs.ndim))
    elif isinstance(axes, (tuple, list, numpy.ndarray)):
        axes = tuple(int(a) for a in axes)
    else:
        # Bug fix: the format string was previously joined to its argument
        # with a comma, raising ValueError(('invalid axes: %s', '...'))
        # instead of the interpolated message.
        raise ValueError('invalid axes: %s' % str(axes))
    if len(axes) == 0:
        raise ValueError('there should be at least one normalization axis')
    if min(axes) < 0 or max(axes) >= ndim:
        raise ValueError('axes should be less than ndim (<%d), but %s given' % (ndim, str(axes)))

    # NOTE(review): unlike batch_normalization_train, `inputs` is not passed
    # through as_tensor_variable here — presumably the Op's make_node handles
    # it; confirm before relying on non-Variable inputs.
    gamma = as_tensor_variable(gamma)
    beta = as_tensor_variable(beta)
    mean = as_tensor_variable(mean)
    var = as_tensor_variable(var)
    # Mark the normalized axes broadcastable so the elementwise formula
    # broadcasts the per-feature parameters and statistics over them.
    gamma = T.addbroadcast(gamma, *axes)
    beta = T.addbroadcast(beta, *axes)
    mean = T.addbroadcast(mean, *axes)
    var = T.addbroadcast(var, *axes)

    batchnorm_op = AbstractBatchNormInference(axes=axes)
    return batchnorm_op(inputs, gamma, beta, mean, var, epsilon=epsilon)
class AbstractBatchNormTrain(Op):
    """
    Abstract Op for Batch Normalization.

    Computes the normalized output together with the batch mean and inverse
    standard deviation; device-specific optimizations replace this abstract
    node with a concrete implementation.

    Parameters
    ----------
    axes : a tuple of ints
        The axes along which the input should be normalized.
    x : tensor
        The input to be normalized along `axes`.
    scale : tensor
        `scale` should have the same number of dimensions as `x`.
        All dimensions listed in `axes` should have length 1.
    bias : tensor
        `bias` should have the same number of dimensions as `x`.
        All dimensions listed in `axes` should have length 1.
    epsilon
        Epsilon value used in the batch normalization formula. Minimum allowed
        value is 1e-5 (imposed by cuDNN).
    """

    __props__ = ('axes',)

    def __init__(self, axes=(0,)):
        assert isinstance(axes, (tuple, list))
        assert len(axes) > 0
        # Canonicalize to a tuple of plain ints so __props__-based equality
        # and hashing behave consistently.
        axes = tuple(int(a) for a in axes)
        self.axes = axes

    def infer_shape(self, node, shape):
        # out has the shape of x; mean and stdinv have the shape of scale.
        return [shape[0], shape[1], shape[1]]

    def make_node(self, x, scale, bias, epsilon=1e-4):
        assert x.ndim == scale.ndim == bias.ndim
        if not isinstance(epsilon, theano.Variable):
            epsilon = as_tensor_variable(epsilon)
        # Three outputs: normalized x, batch mean, batch inverse std.
        return Apply(self, [x, scale, bias, epsilon], [x.type(), scale.type(), scale.type()])

    def grad(self, inputs, grads):
        x, scale, bias, epsilon = inputs
        dy = grads[0]
        # Re-apply this Op to obtain the saved mean/invstd outputs, which the
        # gradient Op consumes instead of recomputing them from x.
        _, x_mean, x_invstd = self(x, scale, bias, epsilon)
        # epsilon is disconnected from the outputs (see connection_pattern).
        return AbstractBatchNormTrainGrad(self.axes)(
            x, dy, scale, x_mean, x_invstd, epsilon) + [theano.gradient.DisconnectedType()()]

    def connection_pattern(self, node):
        # Specify that epsilon is not connected to outputs.
        return [[True, True, True], [True, True, True], [True, True, True],
                [False, False, False]]

    def perform(self, node, inputs, output_storage):
        # Pure-numpy reference implementation, used when no optimized
        # replacement has been substituted.
        x, scale, bias, epsilon = inputs
        axes = self.axes
        if min(axes) < 0 or max(axes) >= x.ndim:
            raise ValueError('axes should be less than ndim (<%d), but %s given' % (x.ndim, str(axes)))
        mean = x.mean(axes, keepdims=True)
        stdinv = 1.0 / numpy.sqrt(x.var(axes, keepdims=True) + epsilon)
        out = (x - mean) * (scale * stdinv) + bias
        output_storage[0][0] = out
        output_storage[1][0] = mean
        output_storage[2][0] = stdinv
class AbstractBatchNormInference(Op):
    """
    Abstract Op for Batch Normalization.

    Normalizes the input with externally supplied (e.g. running-average)
    mean and variance; device-specific optimizations replace this abstract
    node with a concrete implementation.

    Parameters
    ----------
    axes : a tuple of ints
        The axes along which the input is normalized.
    epsilon
        Epsilon value used in the batch normalization formula. Minimum allowed
        value is 1e-5 (imposed by cuDNN).
    """

    __props__ = ('axes',)

    def __init__(self, axes=(0,)):
        assert isinstance(axes, (tuple, list))
        assert len(axes) > 0
        # Canonicalize to a tuple of plain ints so __props__-based equality
        # and hashing behave consistently.
        axes = tuple(int(a) for a in axes)
        self.axes = axes

    def infer_shape(self, node, shape):
        # Single output with the same shape as x.
        return [shape[0]]

    def make_node(self, x, scale, bias, estimated_mean, estimated_variance, epsilon=1e-4):
        assert x.ndim == scale.ndim == bias.ndim == estimated_mean.ndim == estimated_variance.ndim
        if not isinstance(epsilon, theano.Variable):
            epsilon = as_tensor_variable(epsilon)
        return Apply(self, [x, scale, bias, estimated_mean, estimated_variance, epsilon], [x.type()])

    def grad(self, inputs, grads):
        # Symbolic gradients of
        #   out = (x - mean) * (scale / sqrt(var + eps)) + bias
        # with respect to every input; epsilon is disconnected.
        x, scale, bias, est_mean, est_var, epsilon = inputs
        dy = grads[0]
        axes = self.axes
        if min(axes) < 0 or max(axes) >= x.ndim:
            raise ValueError('axes should be less than ndim (<%d), but %s given' % (x.ndim, str(axes)))
        scale, bias, est_mean, est_var = (theano.tensor.addbroadcast(t, *axes)
                                          for t in (scale, bias, est_mean, est_var))
        # define helper expressions
        est_var_eps = est_var + epsilon
        est_std = theano.tensor.sqrt(est_var_eps)
        two = theano.tensor.constant(2.)
        # define and return gradients
        dx = dy * (scale / est_std)
        dscale = (dy * (x - est_mean)).sum(axes, keepdims=True) / est_std
        dbias = dy.sum(axes, keepdims=True)
        dmean = -dy.sum(axes, keepdims=True) * (scale / est_std)
        # d/dvar of 1/sqrt(var+eps) is -1/(2*(var+eps)^(3/2)), hence the
        # extra factor 1/(2 * est_var_eps * est_std).
        dvar = -(dy * (x - est_mean)).sum(axes, keepdims=True) * (scale / (two * est_var_eps * est_std))
        return [dx, dscale, dbias, dmean, dvar, theano.gradient.DisconnectedType()()]

    def connection_pattern(self, node):
        # Specify that epsilon is not connected to outputs.
        return [[True], [True], [True], [True], [True], [False]]

    def perform(self, node, inputs, output_storage):
        # Pure-numpy reference implementation, used when no optimized
        # replacement has been substituted.
        x, scale, bias, estimated_mean, estimated_variance, epsilon = inputs
        out = (x - estimated_mean) * (scale / numpy.sqrt(estimated_variance + epsilon)) + bias
        output_storage[0][0] = out
class AbstractBatchNormTrainGrad(Op):
    """
    Gradient Op for AbstractBatchNormTrain.

    Consumes the forward input `x`, the incoming gradient `dy`, the `scale`
    parameter and the saved batch statistics (`x_mean`, `x_invstd`), and
    produces the gradients with respect to x, scale and bias.
    """

    __props__ = ('axes',)

    def __init__(self, axes=(0,)):
        assert isinstance(axes, (tuple, list))
        assert len(axes) > 0
        # Canonicalize to a tuple of plain ints for __props__ comparison.
        self.axes = tuple(int(axis) for axis in axes)

    def make_node(self, x, dy, scale, x_mean, x_invstd, epsilon=1e-4):
        assert x.ndim == dy.ndim == scale.ndim == x_mean.ndim == x_invstd.ndim
        if not isinstance(epsilon, theano.Variable):
            epsilon = as_tensor_variable(epsilon)
        node_inputs = [x, dy, scale, x_mean, x_invstd, epsilon]
        node_outputs = [x.type(), scale.type(), scale.type()]
        return Apply(self, node_inputs, node_outputs)

    def infer_shape(self, node, shape):
        # g_wrt_inputs matches x; g_wrt_scale and g_wrt_bias match scale.
        return [shape[0], shape[2], shape[2]]

    def perform(self, node, inputs, output_storage):
        # Pure-numpy backward pass of batch normalization.
        x, dy, scale, x_mean, x_invstd, epsilon = inputs
        axes = self.axes
        if min(axes) < 0 or max(axes) >= x.ndim:
            raise ValueError('axes should be less than ndim (<%d), but %s given' % (x.ndim, str(axes)))

        centered = x - x_mean
        mean_dy_centered = numpy.mean(dy * centered, axis=axes, keepdims=True)
        # c is dy in whitened space, minus the variance-coupling term.
        c = (dy * x_invstd) - (centered * mean_dy_centered * (x_invstd ** 3))

        output_storage[0][0] = scale * (c - numpy.mean(c, axis=axes, keepdims=True))
        output_storage[1][0] = numpy.sum(dy * x_invstd * centered, axis=axes, keepdims=True)
        output_storage[2][0] = numpy.sum(dy, axis=axes, keepdims=True)
@local_optimizer([AbstractBatchNormTrain])
def local_abstract_batch_norm_train(node):
    """Lower AbstractBatchNormTrain to plain tensor expressions.

    Returns the replacement list ``[out, mean, stdinv]``, or None when the
    node cannot be handled (non-tensor inputs or out-of-range axes).
    """
    if not isinstance(node.op, AbstractBatchNormTrain):
        return None

    x, scale, bias, epsilon = node.inputs
    axes = node.op.axes
    # Bug fix: an axis equal to x.ndim is invalid too. The previous check
    # (max(axes) > x.ndim) let it through, inconsistent with the Op's own
    # perform() and the public wrappers, which reject max(axes) >= ndim.
    if min(axes) < 0 or max(axes) >= x.ndim:
        return None
    if not isinstance(x.type, TensorType) or \
            not isinstance(scale.type, TensorType) or \
            not isinstance(bias.type, TensorType) or \
            not isinstance(epsilon.type, TensorType):
        return None

    mean = x.mean(axes, keepdims=True)
    stdinv = T.inv(T.sqrt(x.var(axes, keepdims=True) + epsilon))
    out = (x - mean) * (scale * stdinv) + bias
    # TODO copy_stack_trace?
    return [out, mean, stdinv]
@local_optimizer([AbstractBatchNormTrainGrad])
def local_abstract_batch_norm_train_grad(node):
    """Lower AbstractBatchNormTrainGrad to plain tensor expressions.

    Returns the replacement list ``[g_wrt_inputs, g_wrt_scale, g_wrt_bias]``,
    or None when the node cannot be handled (non-tensor inputs or
    out-of-range axes).
    """
    if not isinstance(node.op, AbstractBatchNormTrainGrad):
        return None

    x, dy, scale, x_mean, x_invstd, epsilon = node.inputs
    axes = node.op.axes
    # Bug fix: an axis equal to x.ndim is invalid too. The previous check
    # (max(axes) > x.ndim) let it through, inconsistent with the Op's own
    # perform() and the public wrappers, which reject max(axes) >= ndim.
    if min(axes) < 0 or max(axes) >= x.ndim:
        return None
    if not isinstance(x.type, TensorType) or \
            not isinstance(dy.type, TensorType) or \
            not isinstance(scale.type, TensorType) or \
            not isinstance(x_mean.type, TensorType) or \
            not isinstance(x_invstd.type, TensorType) or \
            not isinstance(epsilon.type, TensorType):
        return None

    x_diff = x - x_mean
    mean_dy_x_diff = T.mean(dy * x_diff, axis=axes, keepdims=True)
    c = (dy * x_invstd) - x_diff * (mean_dy_x_diff * (x_invstd ** 3))

    g_wrt_inputs = scale * (c - T.mean(c, axis=axes, keepdims=True))
    g_wrt_scale = T.sum(dy * x_invstd * x_diff, axis=axes, keepdims=True)
    g_wrt_bias = T.sum(dy, axis=axes, keepdims=True)
    # TODO copy_stack_trace?
    return [g_wrt_inputs, g_wrt_scale, g_wrt_bias]
@local_optimizer([AbstractBatchNormInference])
def local_abstract_batch_norm_inference(node):
    """Lower AbstractBatchNormInference to a plain tensor expression.

    Returns a one-element replacement list, or None when any input is not a
    plain TensorType.
    """
    if not isinstance(node.op, AbstractBatchNormInference):
        return None

    x, scale, bias, estimated_mean, estimated_variance, epsilon = node.inputs
    if any(not isinstance(v.type, TensorType) for v in node.inputs):
        return None

    normalized = (x - estimated_mean) * (scale / T.sqrt(estimated_variance + epsilon)) + bias
    # TODO copy_stack_trace?
    return [normalized]
# Register CPU optimization: group the three abstract-batchnorm lowering
# rules into one LocalGroupDB registered at device-specialization time.
bn_groupopt = theano.gof.optdb.LocalGroupDB()
bn_groupopt.__name__ = 'batchnorm_opts'
register_specialize_device(bn_groupopt, 'fast_compile', 'fast_run')

# All three run at position 30 within the group, for both optimizer modes.
bn_groupopt.register('local_abstract_batch_norm_train',
                     local_abstract_batch_norm_train, 30,
                     'fast_compile', 'fast_run')
bn_groupopt.register('local_abstract_batch_norm_train_grad',
                     local_abstract_batch_norm_train_grad, 30,
                     'fast_compile', 'fast_run')
bn_groupopt.register('local_abstract_batch_norm_inference',
                     local_abstract_batch_norm_inference, 30,
                     'fast_compile', 'fast_run')
from __future__ import absolute_import, print_function, division
import theano
import theano.tensor as T
from theano.tests import unittest_tools as utt
import numpy
from theano.tensor.nnet.bn import batch_normalization
from theano.tensor.nnet import bn
def test_BNComposite():
......@@ -39,7 +40,7 @@ def test_BNComposite():
f_ref = theano.function([x, b, g, m, v], [bn_ref_op])
res_ref = f_ref(X, G, B, M, V)
for mode in ['low_mem', 'high_mem']:
bn_op = batch_normalization(x, g, b, m, v, mode=mode)
bn_op = bn.batch_normalization(x, g, b, m, v, mode=mode)
f = theano.function([x, b, g, m, v], [bn_op])
res = f(X, G, B, M, V)
utt.assert_allclose(res_ref, res)
......@@ -47,7 +48,7 @@ def test_BNComposite():
theano.config.compute_test_value = orig
def test_bn():
def test_batch_normalization():
def bn_ref(x, G, B, M, V):
n = (x - M) / V
......@@ -70,28 +71,28 @@ def test_bn():
f_ref = theano.function([x, b, g, m, v], [bn_ref_op])
res_ref = f_ref(X, G, B, M, V)
for mode in ['low_mem', 'high_mem']:
bn_op = batch_normalization(x, g, b, m, v, mode=mode)
bn_op = bn.batch_normalization(x, g, b, m, v, mode=mode)
f = theano.function([x, b, g, m, v], [bn_op])
res = f(X, G, B, M, V)
utt.assert_allclose(res_ref, res)
def bn(inputs, gamma, beta, mean, std):
return batch_normalization(inputs, gamma, beta, mean, std, mode=mode)
utt.verify_grad(bn, [X, G, B, M, V])
def bn_f(inputs, gamma, beta, mean, std):
return bn.batch_normalization(inputs, gamma, beta, mean, std, mode=mode)
utt.verify_grad(bn_f, [X, G, B, M, V])
bn_ref_op = bn_ref(x, g, b, x.mean(axis=0, keepdims=True), x.std(axis=0, keepdims=True))
f_ref = theano.function([x, b, g], [bn_ref_op])
res_ref = f_ref(X, G, B)
for mode in ['low_mem', 'high_mem']:
bn_op = batch_normalization(x, g, b, x.mean(axis=0, keepdims=True), x.std(axis=0, keepdims=True), mode=mode)
bn_op = bn.batch_normalization(x, g, b, x.mean(axis=0, keepdims=True), x.std(axis=0, keepdims=True), mode=mode)
f = theano.function([x, b, g], [bn_op])
res = f(X, G, B)
utt.assert_allclose(res_ref, res)
def bn(inputs, gamma, beta, mean, std):
return batch_normalization(inputs, gamma, beta, mean, std, mode=mode)
utt.verify_grad(batch_normalization, [X, G, B,
X.mean(axis=0)[numpy.newaxis], X.std(axis=0)[numpy.newaxis]])
def bn_f(inputs, gamma, beta, mean, std):
return bn.batch_normalization(inputs, gamma, beta, mean, std, mode=mode)
utt.verify_grad(bn_f, [X, G, B,
X.mean(axis=0)[numpy.newaxis], X.std(axis=0)[numpy.newaxis]])
def test_bn_feature_maps():
......@@ -122,21 +123,147 @@ def test_bn_feature_maps():
res_ref = f_ref(X, G, B, M, V)
for mode in ['low_mem', 'high_mem']:
bn_op = batch_normalization(x,
g.dimshuffle('x', 0, 'x', 'x'),
b.dimshuffle('x', 0, 'x', 'x'),
m.dimshuffle('x', 0, 'x', 'x'),
v.dimshuffle('x', 0, 'x', 'x'),
mode=mode)
bn_op = bn.batch_normalization(x,
g.dimshuffle('x', 0, 'x', 'x'),
b.dimshuffle('x', 0, 'x', 'x'),
m.dimshuffle('x', 0, 'x', 'x'),
v.dimshuffle('x', 0, 'x', 'x'),
mode=mode)
f = theano.function([x, b, g, m, v], [bn_op])
res = f(X, G, B, M, V)
utt.assert_allclose(res_ref, res)
def conv_bn(inputs, gamma, beta, mean, std):
return batch_normalization(inputs,
gamma.dimshuffle('x', 0, 'x', 'x'),
beta.dimshuffle('x', 0, 'x', 'x'),
mean.dimshuffle('x', 0, 'x', 'x'),
std.dimshuffle('x', 0, 'x', 'x'),
mode=mode)
return bn.batch_normalization(inputs,
gamma.dimshuffle('x', 0, 'x', 'x'),
beta.dimshuffle('x', 0, 'x', 'x'),
mean.dimshuffle('x', 0, 'x', 'x'),
std.dimshuffle('x', 0, 'x', 'x'),
mode=mode)
utt.verify_grad(conv_bn, [X, G, B, M, V])
def test_batch_normalization_train():
    """Compare bn.batch_normalization_train (forward outputs and gradients)
    against a hand-built reference expression, for several axis specs and
    input ranks, and check that the abstract Ops get optimized away."""
    utt.seed_rng()
    for axes in ('per-activation', 'spatial', (1, 2, 3, 4)):
        for vartype in (T.tensor5, T.tensor4, T.tensor3, T.matrix, T.vector):
            x, scale, bias = (vartype(n) for n in ('x', 'scale', 'bias'))
            ndim = x.ndim
            eps = 5e-3  # some non-standard value to test if it's used
            # remove non-existing axes
            # NOTE(review): this rebinds the outer loop variable `axes`; it
            # stays correct only because vartype iterates in decreasing rank
            # order, so the tuple only ever shrinks further.
            if isinstance(axes, tuple):
                axes = tuple(i for i in axes if i < ndim)
                if len(axes) == 0:
                    continue
            # forward pass
            out, x_mean, x_invstd = bn.batch_normalization_train(
                x, scale, bias, axes, eps)
            # reference forward pass
            if axes == 'per-activation':
                axes2 = (0,)
            elif axes == 'spatial':
                axes2 = (0,) + tuple(range(2, ndim))
            else:
                axes2 = axes
            x_mean2 = x.mean(axis=axes2, keepdims=True)
            x_invstd2 = T.inv(T.sqrt(x.var(axis=axes2, keepdims=True) + eps))
            scale2 = T.addbroadcast(scale, *axes2)
            bias2 = T.addbroadcast(bias, *axes2)
            out2 = (x - x_mean2) * (scale2 * x_invstd2) + bias2
            # backward pass
            dy = vartype('dy')
            grads = T.grad(None, wrt=[x, scale, bias], known_grads={out: dy})
            # reference backward pass
            grads2 = T.grad(None, wrt=[x, scale, bias], known_grads={out2: dy})
            # compile
            # outputs: [0:3] tested values, [3:6] references, [6:9] tested
            # grads, [9:12] reference grads — the index arithmetic below
            # relies on this layout.
            f = theano.function([x, scale, bias, dy],
                                [out, x_mean, x_invstd, out2, x_mean2, x_invstd2] +
                                grads + grads2, mode='FAST_RUN')
            # check if the abstract Ops have been replaced
            assert not any([isinstance(n.op, (bn.AbstractBatchNormTrain,
                                              bn.AbstractBatchNormInference,
                                              bn.AbstractBatchNormTrainGrad))
                            for n in f.maker.fgraph.toposort()])
            # run
            for data_shape in ((5, 10, 30, 40, 10), (4, 3, 1, 1, 1), (1, 1, 5, 5, 5)):
                data_shape = data_shape[:ndim]
                # parameters are broadcastable (size 1) on normalized axes
                param_shape = tuple(1 if d in axes2 else s
                                    for d, s in enumerate(data_shape))
                X = 4 + 3 * numpy.random.randn(*data_shape).astype(theano.config.floatX)
                Dy = -1 + 2 * numpy.random.randn(*data_shape).astype(theano.config.floatX)
                Scale = numpy.random.randn(*param_shape).astype(theano.config.floatX)
                Bias = numpy.random.randn(*param_shape).astype(theano.config.floatX)
                outputs = f(X, Scale, Bias, Dy)
                # compare outputs
                utt.assert_allclose(outputs[0], outputs[0 + 3])  # out
                utt.assert_allclose(outputs[1], outputs[1 + 3])  # mean
                utt.assert_allclose(outputs[2], outputs[2 + 3])  # invstd
                # compare gradients
                utt.assert_allclose(outputs[6], outputs[6 + 3], atol=1e-4)  # dx
                utt.assert_allclose(outputs[7], outputs[7 + 3], rtol=2e-4, atol=1e-4)  # dscale
                utt.assert_allclose(outputs[8], outputs[8 + 3])  # dbias
def test_batch_normalization_test():
    """Compare bn.batch_normalization_test (forward output and gradients)
    against a hand-built reference expression, for several axis specs and
    input ranks, and check that the abstract Ops get optimized away.

    NOTE(review): unlike test_batch_normalization_train, this does not call
    utt.seed_rng() — confirm whether the RNG is seeded elsewhere.
    """
    for axes in ('per-activation', 'spatial', (1, 2, 3, 4)):
        for vartype in (T.tensor5, T.tensor4, T.tensor3, T.matrix, T.vector):
            x, scale, bias, mean, var = (vartype(n)
                                         for n in ('x', 'scale', 'bias', 'mean', 'var'))
            ndim = x.ndim
            eps = 5e-3  # some non-standard value to test if it's used
            # remove non-existing axes
            # NOTE(review): this rebinds the outer loop variable `axes`; it
            # stays correct only because vartype iterates in decreasing rank
            # order, so the tuple only ever shrinks further.
            if isinstance(axes, tuple):
                axes = tuple(i for i in axes if i < ndim)
                if len(axes) == 0:
                    continue
            # forward pass
            out = bn.batch_normalization_test(x, scale, bias, mean,
                                              var, axes, eps)
            # reference forward pass
            if axes == 'per-activation':
                axes2 = (0,)
            elif axes == 'spatial':
                axes2 = (0,) + tuple(range(2, ndim))
            else:
                axes2 = axes
            scale2, bias2, mean2, var2 = (T.addbroadcast(t, *axes2)
                                          for t in (scale, bias, mean, var))
            out2 = (x - mean2) * (scale2 / T.sqrt(var2 + eps)) + bias2
            # backward pass
            dy = vartype('dy')
            grads = T.grad(None, wrt=[x, scale, bias, mean, var], known_grads={out: dy})
            # reference backward pass
            grads2 = T.grad(None, wrt=[x, scale, bias, mean, var], known_grads={out2: dy})
            # compile
            # outputs: [0] tested out, [1] reference out, [2:7] tested grads,
            # [7:12] reference grads — the index arithmetic below relies on
            # this layout.
            f = theano.function([x, scale, bias, mean, var, dy],
                                [out, out2] + grads + grads2, mode='FAST_RUN')
            # check if the abstract Ops have been replaced
            assert not any([isinstance(n.op, (bn.AbstractBatchNormTrain,
                                              bn.AbstractBatchNormInference,
                                              bn.AbstractBatchNormTrainGrad))
                            for n in f.maker.fgraph.toposort()])
            # run
            for data_shape in ((10, 20, 30, 40, 10), (4, 3, 1, 1, 1), (1, 1, 5, 5, 5)):
                data_shape = data_shape[:ndim]
                # parameters are broadcastable (size 1) on normalized axes
                param_shape = tuple(1 if d in axes2 else s
                                    for d, s in enumerate(data_shape))
                X = 4 + 3 * numpy.random.randn(*data_shape).astype(theano.config.floatX)
                Dy = -1 + 2 * numpy.random.randn(*data_shape).astype(theano.config.floatX)
                Scale = numpy.random.randn(*param_shape).astype(theano.config.floatX)
                Bias = numpy.random.randn(*param_shape).astype(theano.config.floatX)
                Mean = numpy.random.randn(*param_shape).astype(theano.config.floatX)
                # variance must be non-negative, hence rand() not randn()
                Var = numpy.random.rand(*param_shape).astype(theano.config.floatX)
                outputs = f(X, Scale, Bias, Mean, Var, Dy)
                # compare outputs
                utt.assert_allclose(outputs[0], outputs[1])  # out
                # compare gradients
                utt.assert_allclose(outputs[2], outputs[2 + 5], atol=4e-5)  # dx
                utt.assert_allclose(outputs[3], outputs[3 + 5], atol=4e-5)  # dscale
                utt.assert_allclose(outputs[4], outputs[4 + 5])  # dbias
                utt.assert_allclose(outputs[5], outputs[5 + 5])  # dmean
                utt.assert_allclose(outputs[6], outputs[6 + 5], rtol=2e-3, atol=4e-5)  # dvar
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment