Commit b690578d authored by Gijs van Tulder

Preserve broadcastable pattern in batchnorm optimizations.

Parent 43411345
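Background for the change below: the local optimizers replace the outputs of the abstract batch-norm ops with equivalent elementwise expressions, but elementwise arithmetic can produce a different broadcastable pattern than the node output it replaces, which invalidates the substitution. A minimal sketch of the fix, assuming Theano's T.patternbroadcast (the names r and r_orig are illustrative, not from the commit):

import theano.tensor as T

# r_orig stands in for an output of an abstract batch-norm node whose
# type declares the first axis broadcastable (e.g. a 'spatial' mean).
r_orig = T.row('r_orig')      # broadcastable == (True, False)

# r stands in for the optimizer's replacement expression; elementwise
# arithmetic with a plain matrix loses the broadcastable pattern.
r = T.matrix('a') * 2.0       # broadcastable == (False, False)

# Substituting r for r_orig directly would fail the type check, so the
# original pattern is re-asserted (and verified at run time):
r = T.patternbroadcast(r, r_orig.broadcastable)
assert r.broadcastable == r_orig.broadcastable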
@@ -627,6 +627,9 @@ def local_abstract_batch_norm_train(node):
             (m / (m - 1)) * var * running_average_factor
         results.append(running_var)
 
+    results = [T.patternbroadcast(r, r_orig.broadcastable)
+               for (r, r_orig) in zip(results, node.outputs)]
+
     # TODO copy_stack_trace?
     return results
@@ -655,8 +658,13 @@ def local_abstract_batch_norm_train_grad(node):
     g_wrt_inputs = scale * (c - T.mean(c, axis=axes, keepdims=True))
     g_wrt_scale = T.sum(dy * x_invstd * x_diff, axis=axes, keepdims=True)
     g_wrt_bias = T.sum(dy, axis=axes, keepdims=True)
+
+    results = [g_wrt_inputs, g_wrt_scale, g_wrt_bias]
+    results = [T.patternbroadcast(r, r_orig.broadcastable)
+               for (r, r_orig) in zip(results, node.outputs)]
+
     # TODO copy_stack_trace?
-    return [g_wrt_inputs, g_wrt_scale, g_wrt_bias]
+    return results
 
 
 @local_optimizer([AbstractBatchNormInference])
@@ -674,8 +682,11 @@ def local_abstract_batch_norm_inference(node):
        not isinstance(epsilon.type, TensorType):
         return None
 
+    result = (x - estimated_mean) * (scale / T.sqrt(estimated_variance + epsilon)) + bias
+    result = T.patternbroadcast(result, node.outputs[0].broadcastable)
+
     # TODO copy_stack_trace?
-    return [(x - estimated_mean) * (scale / T.sqrt(estimated_variance + epsilon)) + bias]
+    return [result]
 
 
 # Register Cpu Optmization
...
@@ -392,3 +392,24 @@ def test_batch_normalization_test():
     utt.assert_allclose(outputs[4], outputs[4 + 5])  # dbias
     utt.assert_allclose(outputs[5], outputs[5 + 5])  # dmean
     utt.assert_allclose(outputs[6], outputs[6 + 5], rtol=2e-3, atol=4e-5)  # dvar
+
+
+def test_batch_normalization_broadcastable():
+    # check if the broadcastable pattern is preserved by the optimizations
+    x, dy, scale, bias, mean, var = (T.scalar(n).dimshuffle(['x'] * 5)
+                                     for n in ('x', 'dy', 'scale', 'bias', 'mean', 'var'))
+
+    # forward pass
+    out_train, x_mean, x_invstd = bn.batch_normalization_train(x, scale, bias, 'spatial')
+    out_test = bn.batch_normalization_test(x, scale, bias, mean, var, 'spatial')
+    # backward pass
+    grads_train = T.grad(None, wrt=[x, scale, bias], known_grads={out_train: dy})
+    grads_test = T.grad(None, wrt=[x, scale, bias], known_grads={out_test: dy})
+    # compile
+    f = theano.function([x, scale, bias, mean, var, dy],
+                        [out_train, x_mean, x_invstd, out_test] + grads_train + grads_test,
+                        mode='FAST_RUN')
+    assert not any([isinstance(n.op, (bn.AbstractBatchNormTrain,
+                                      bn.AbstractBatchNormInference,
+                                      bn.AbstractBatchNormTrainGrad))
+                    for n in f.maker.fgraph.toposort()])
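A side note on the test's backward pass: instead of differentiating a scalar cost, it calls T.grad with cost=None and known_grads to backpropagate an explicit output gradient dy through the graph. A minimal sketch of the same pattern (the variables x, y, dy here are illustrative, not from the commit):

import theano
import theano.tensor as T

x = T.vector('x')
y = x ** 2
dy = T.vector('dy')  # externally supplied gradient with respect to y

# equivalent to T.grad(cost, [x]) for any cost whose gradient on y is dy
gx, = T.grad(None, wrt=[x], known_grads={y: dy})

f = theano.function([x, dy], gx)
print(f([1.0, 2.0], [1.0, 1.0]))  # [2.0, 4.0], i.e. dy * 2 * x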