Commit 4bc49179 authored by Nicolas Ballas

update comment and tests

Parent a74c1b06
@@ -52,9 +52,11 @@ def batch_normalization(inputs, gamma, beta, mean, std,
     mode: 'low_mem' or 'high_mem'
         Specify which batch_normalization implementation that will be
         used.
-        As no intermediate representations are stored for the
-        back-propagation, 'low_mem' implementation lower the memory usage, however,
-        it is 5-10% slower than 'high_mem' implementation.
+        As no intermediate representations are stored for the back-propagation,
+        the 'low_mem' implementation lowers memory usage; however, it is
+        5-10% slower than the 'high_mem' implementation. Note that the 5-10%
+        computation time difference compares the batch_normalization operation only;
+        the difference between implementations is likely to be smaller on the full model fprop/bprop.
     """
     if mode == 'low_mem':
         elm_bn = theano.tensor.elemwise.Elemwise(scalar_op=BNComposite(dtype=inputs.dtype))
...
@@ -41,7 +41,7 @@ def test_bn():
     f_ref = theano.function([x, b, g], [bn_ref_op])
     res_ref = f_ref(X, G, B)
     for mode in ['low_mem', 'high_mem']:
-        bn_op = batch_normalization(x, g, b, x.mean(axis=0, keepdims=True), x.std(axis=0, keepdims=True))
+        bn_op = batch_normalization(x, g, b, x.mean(axis=0, keepdims=True), x.std(axis=0, keepdims=True), mode=mode)
        f = theano.function([x, b, g], [bn_op])
        res = f(X, G, B)
        utt.assert_allclose(res_ref, res)
...