提交 9c90f281 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Add inplace optimization for GpuBatchNormInference.

上级 8202b103
...@@ -1776,17 +1776,22 @@ class GpuDnnBatchNormInference(DnnBase): ...@@ -1776,17 +1776,22 @@ class GpuDnnBatchNormInference(DnnBase):
value is 1e-5 (imposed by cuDNN). value is 1e-5 (imposed by cuDNN).
""" """
__props__ = ('mode',) __props__ = ('mode', 'inplace')
def __init__(self, mode='per-activation'): def __init__(self, mode='per-activation', inplace=False):
DnnBase.__init__(self, ['dnn_batchnorm_base.c', 'dnn_batchnorm_inf.c'], DnnBase.__init__(self, ['dnn_batchnorm_base.c', 'dnn_batchnorm_inf.c'],
'dnn_batchnorm_op') 'dnn_batchnorm_op')
assert (mode in ('per-activation', 'spatial')) assert (mode in ('per-activation', 'spatial'))
self.mode = mode self.mode = mode
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
def get_op_params(self): def get_op_params(self):
params = [] params = []
if self.inplace:
params.append(('INPLACE_OUTPUT', '1'))
params.append(('MODE', ("CUDNN_BATCHNORM_SPATIAL" params.append(('MODE', ("CUDNN_BATCHNORM_SPATIAL"
if self.mode == "spatial" if self.mode == "spatial"
else "CUDNN_BATCHNORM_PER_ACTIVATION"))) else "CUDNN_BATCHNORM_PER_ACTIVATION")))
...@@ -3167,6 +3172,13 @@ def local_batch_norm_inplace_running_var(node): ...@@ -3167,6 +3172,13 @@ def local_batch_norm_inplace_running_var(node):
inplace_output=node.op.inplace_output)(*node.inputs) inplace_output=node.op.inplace_output)(*node.inputs)
@register_inplace()
@local_optimizer([GpuDnnBatchNormInference], inplace=True)
def local_batch_norm_inference_inplace(node):
    """Graph optimizer: swap a non-inplace GpuDnnBatchNormInference node
    for its inplace counterpart (same mode, ``inplace=True``).

    Returns a one-element list with the replacement apply node's outputs,
    or ``None`` (implicitly) when the node is not a candidate.
    """
    op = node.op
    # Only rewrite matching ops that are not already inplace.
    if not isinstance(op, GpuDnnBatchNormInference):
        return
    if op.inplace:
        return
    inplace_op = GpuDnnBatchNormInference(mode=op.mode, inplace=True)
    return [inplace_op(*node.inputs)]
@local_optimizer([bn.AbstractBatchNormTrainGrad]) @local_optimizer([bn.AbstractBatchNormTrainGrad])
def local_abstract_batch_norm_train_grad_cudnn(node): def local_abstract_batch_norm_train_grad_cudnn(node):
if not isinstance(node.op, bn.AbstractBatchNormTrainGrad): if not isinstance(node.op, bn.AbstractBatchNormTrainGrad):
......
...@@ -14,8 +14,14 @@ int dnn_batchnorm_op(PyGpuArrayObject *inp, PyGpuArrayObject *scale, ...@@ -14,8 +14,14 @@ int dnn_batchnorm_op(PyGpuArrayObject *inp, PyGpuArrayObject *scale,
if (epsilon < 1e-5) if (epsilon < 1e-5)
return 1; return 1;
#ifdef INPLACE_OUTPUT
Py_XDECREF(*outp);
*outp = inp;
Py_INCREF(*outp);
#else
if (theano_prep_output(outp, inp->ga.nd, inp->ga.dimensions, inp->ga.typecode, GA_C_ORDER, c) != 0) if (theano_prep_output(outp, inp->ga.nd, inp->ga.dimensions, inp->ga.typecode, GA_C_ORDER, c) != 0)
return 1; return 1;
#endif
if (c_set_tensorNd(*outp, bn_output) != 0) if (c_set_tensorNd(*outp, bn_output) != 0)
return 1; return 1;
......
...@@ -1656,6 +1656,36 @@ def test_batchnorm_inference(): ...@@ -1656,6 +1656,36 @@ def test_batchnorm_inference():
utt.assert_allclose(outputs_abstract[5], outputs_ref[5], rtol=2e-3, atol=4e-5) # dvar utt.assert_allclose(outputs_abstract[5], outputs_ref[5], rtol=2e-3, atol=4e-5) # dvar
def test_batchnorm_inference_inplace():
    # Check that the graph optimizer replaces the batch-norm inference op
    # with its inplace variant, and that the compiled function runs.
    if not dnn.dnn_available(test_ctx_name):
        raise SkipTest(dnn.dnn_available.msg)
    if dnn.version(raises=False) < 5000:
        raise SkipTest("batch normalization requires cudnn v5+")
    utt.seed_rng()

    # Symbolic inputs and their concrete shapes.
    inputs = [T.tensor4(n) for n in ('x', 'scale', 'bias', 'mean', 'var')]
    x, scale, bias, mean, var = inputs
    data_shape = (5, 10, 30, 25)
    param_shape = (1, 10, 30, 25)

    out = dnn.dnn_batch_normalization_test(x, scale, bias, mean, var)
    f = theano.function(inputs, [out], mode=mode_with_gpu)

    # The optimized graph must contain exactly one batch-norm inference
    # node, and it must be the inplace variant.
    bn_nodes = [node for node in f.maker.fgraph.toposort()
                if isinstance(node.op, dnn.GpuDnnBatchNormInference)]
    assert len(bn_nodes) == 1
    assert bn_nodes[0].op.inplace

    # Run once with random data to exercise the inplace kernel.
    floatX = theano.config.floatX
    X = 4 + 3 * numpy.random.randn(*data_shape).astype(floatX)
    Scale = numpy.random.randn(*param_shape).astype(floatX)
    Bias = numpy.random.randn(*param_shape).astype(floatX)
    Mean = numpy.random.randn(*param_shape).astype(floatX)
    Var = numpy.random.rand(*param_shape).astype(floatX)
    f(X, Scale, Bias, Mean, Var)
def test_dnn_batchnorm_valid_and_invalid_axes(): def test_dnn_batchnorm_valid_and_invalid_axes():
if not dnn.dnn_available(test_ctx_name): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论