提交 23ca6554 authored 作者: Nicolas Ballas's avatar Nicolas Ballas

update tests

上级 36de8dd2
......@@ -34,7 +34,7 @@ def test_bn():
utt.verify_grad(batch_normalization, [X, G, B, M, V])
bn_op = batch_normalization(x, g, b, x.mean(axis=0, keepdims=True), x.std(axis=0, keepdims=True))
bn_ref_op = bn_ref(x, g, b, x.mean(axis=0, keepdims=True), x.var(axis=0, keepdims=True))
bn_ref_op = bn_ref(x, g, b, x.mean(axis=0, keepdims=True), x.std(axis=0, keepdims=True))
f = theano.function([x, b, g], [bn_op])
f_ref = theano.function([x, b, g], [bn_ref_op])
res = f(X, G, B)
......@@ -66,7 +66,7 @@ def test_bn_feature_maps():
g.dimshuffle('x', 0, 'x', 'x'),
b.dimshuffle('x', 0, 'x', 'x'),
m.dimshuffle('x', 0, 'x', 'x'),
v.dimshuffle('x', 0, 'x', 'x'), axis=1)
v.dimshuffle('x', 0, 'x', 'x'))
bn_ref_op = bn_ref(x,
g.dimshuffle('x', 0, 'x', 'x'),
b.dimshuffle('x', 0, 'x', 'x'),
......@@ -78,11 +78,10 @@ def test_bn_feature_maps():
res_ref = f_ref(X, G, B, M, V)
utt.assert_allclose(res_ref, res)
def conv_bn(inputs, gamma, beta, mean, variance):
def conv_bn(inputs, gamma, beta, mean, std):
return batch_normalization(inputs,
gamma.dimshuffle('x', 0, 'x', 'x'),
beta.dimshuffle('x', 0, 'x', 'x'),
mean.dimshuffle('x', 0, 'x', 'x'),
variance.dimshuffle('x', 0, 'x', 'x'),
axis=1)
std.dimshuffle('x', 0, 'x', 'x'))
utt.verify_grad(conv_bn, [X, G, B, M, V])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论