Commit 3f57f90b authored by Gijs van Tulder

Accept non-broadcasted inputs for batch normalization.

Parent 9c90f281
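A minimal sketch (not part of this diff) of what the change enables, using the same bn module API that the test below exercises; before this commit, scale and bias had to be dimshuffled to the input's full ndim before the call:

    import numpy as np
    import theano
    import theano.tensor as T
    from theano.tensor.nnet import bn

    x = T.tensor4('x')          # e.g. (batch, channels, rows, cols)
    scale = T.vector('scale')   # one value per channel, not broadcasted
    bias = T.vector('bias')

    # previously the parameters had to match x.ndim, e.g.:
    #   scale_bc = scale.dimshuffle('x', 0, 'x', 'x')
    # now the 1-D vectors can be passed directly for axes='spatial':
    out, mean, invstd = bn.batch_normalization_train(x, scale, bias, 'spatial')

    f = theano.function([x, scale, bias], out)
    f(np.random.randn(2, 3, 4, 5).astype(theano.config.floatX),
      np.ones(3, dtype=theano.config.floatX),
      np.zeros(3, dtype=theano.config.floatX))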
@@ -252,6 +252,86 @@ def test_batch_normalization_train_without_running_averages():
    f(X, Scale, Bias, Dy)


def test_batch_normalization_train_broadcast():
    for axes in ('per-activation', 'spatial', (1, 2, 3, 4)):
        for vartype in (T.tensor5, T.tensor4, T.tensor3, T.matrix, T.vector):
            x = vartype('x')
            ndim = x.ndim
            eps = 5e-3  # some non-standard value to test if it's used
            running_average_factor = 0.3

            # remove non-existing axes
            if isinstance(axes, tuple):
                axes = tuple(i for i in axes if i < ndim)
            if len(axes) == 0:
                continue

            # convert axes to explicit list
            if axes == 'per-activation':
                axes2 = (0,)
            elif axes == 'spatial':
                axes2 = (0,) + tuple(range(2, ndim))
            else:
                axes2 = axes

            # compute axes for parameter tensors
            non_bc_axes = tuple(i for i in range(ndim) if i not in axes2)
            params_dimshuffle = ['x'] * ndim
            for i, axis in enumerate(non_bc_axes):
                params_dimshuffle[axis] = i
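            # (for example, a tensor4 with axes='spatial' gives
            #  non_bc_axes == (1,) and params_dimshuffle == ['x', 0, 'x', 'x'],
            #  i.e. a per-channel parameter vector broadcast to (1, C, 1, 1))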
            # construct non-broadcasted parameter variables
            param_type = T.TensorType(x.dtype, (False,) * len(non_bc_axes))
            scale, bias, running_mean, running_var = (param_type(n)
                                                      for n in ('scale', 'bias',
                                                                'running_mean',
                                                                'running_var'))

            # broadcast parameter variables
            scale_bc = scale.dimshuffle(params_dimshuffle)
            bias_bc = bias.dimshuffle(params_dimshuffle)
            running_mean_bc = running_mean.dimshuffle(params_dimshuffle)
            running_var_bc = running_var.dimshuffle(params_dimshuffle)

            # batch_normalization_train with original, non-broadcasted variables
            train_non_bc = \
                bn.batch_normalization_train(
                    x, scale, bias, axes, eps,
                    running_average_factor, running_mean, running_var)
            # batch_normalization_train with broadcasted variables
            train_bc = \
                bn.batch_normalization_train(
                    x, scale_bc, bias_bc, axes, eps,
                    running_average_factor, running_mean_bc, running_var_bc)
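            # dimshuffle the auxiliary outputs back to the non-broadcasted
            # shape (the broadcasted call returns them with the input's ndim)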
            train_bc = tuple([train_bc[0]] +  # out
                             [r.dimshuffle(non_bc_axes) for r in train_bc[1:]])

            # batch_normalization_test with original, non-broadcasted variables
            test_non_bc = \
                bn.batch_normalization_test(
                    x, scale, bias, running_mean, running_var, axes, eps)
            # batch_normalization_test with broadcasted variables
            test_bc = \
                bn.batch_normalization_test(
                    x, scale_bc, bias_bc, running_mean_bc, running_var_bc,
                    axes, eps)

            # subtract the results of the non-broadcasted and broadcasted calls
            results_non_bc = train_non_bc + (test_non_bc,)
            results_bc = train_bc + (test_bc,)
            results = [abs(r - r_bc) for (r, r_bc) in
                       zip(results_non_bc, results_bc)]

            # compile to compute all differences
            f = theano.function([x, scale, bias, running_mean, running_var],
                                T.sum(sum(results)), mode='FAST_RUN')

            # the paired ops are exactly the same, so the optimizer should have
            # collapsed the sum of differences to a constant zero
            nodes = f.maker.fgraph.toposort()
            assert len(nodes) == 1
            assert isinstance(nodes[0].op, theano.compile.DeepCopyOp)
            assert 0.0 == theano.tensor.get_scalar_constant_value(
                nodes[0].inputs[0])


def test_batch_normalization_test():
    for axes in ('per-activation', 'spatial', (1, 2, 3, 4)):
        for vartype in (T.tensor5, T.tensor4, T.tensor3, T.matrix, T.vector):
......