提交 31e6600f authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Refactor the code to handle GpuAllocEmpty and inplace.

上级 3d081cca
......@@ -26,7 +26,7 @@ from .conv import GpuConv
# GpuDownsampleFactorMax, GpuDownsampleFactorMaxGrad
from .nnet import GpuSoftmax
from .opt import gpu_seqopt, register_opt, conv_groupopt, op_lifter
from .opt_util import alpha_merge, output_merge
from .opt_util import alpha_merge, output_merge, inplace_allocempty
......@@ -1242,49 +1242,25 @@ conv_groupopt.register('local_conv_dnn', local_conv_dnn, 20,
'conv_dnn', 'fast_compile', 'fast_run', 'cudnn')
@local_optimizer([GpuDnnConv], inplace=True)
def local_dnn_conv_inplace(node):
if type(node.op) != GpuDnnConv or node.op.inplace:
return
inputs = list(node.inputs)
dest = inputs[2]
if (dest.owner and
isinstance(dest.owner.op, GpuAllocEmpty) and
len(dest.clients) > 1):
inputs[2] = GpuAllocEmpty(dest.owner.op.dtype)(*dest.owner.inputs)
@inplace_allocempty(GpuDnnConv, 2)
def local_dnn_conv_inplace(node, inputs):
return [GpuDnnConv(algo=node.op.algo, inplace=True)(*inputs)]
@local_optimizer([GpuDnnConvGradW], inplace=True)
def local_dnn_convgw_inplace(node):
if type(node.op) != GpuDnnConvGradW or node.op.inplace:
return
inputs = list(node.inputs)
dest = inputs[2]
if (dest.owner and
isinstance(dest.owner.op, GpuAllocEmpty) and
len(dest.clients) > 1):
inputs[2] = GpuAllocEmpty(dest.owner.op.dtype)(*dest.owner.inputs)
@inplace_allocempty(GpuDnnConvGradW, 2)
def local_dnn_convgw_inplace(node, inputs):
return [GpuDnnConvGradW(algo=node.op.algo, inplace=True)(*inputs)]
@local_optimizer([GpuDnnConvGradI], inplace=True)
def local_dnn_convgi_inplace(node):
if type(node.op) != GpuDnnConvGradI or node.op.inplace:
return
inputs = list(node.inputs)
dest = inputs[2]
if (dest.owner and
isinstance(dest.owner.op, GpuAllocEmpty) and
len(dest.clients) > 1):
inputs[2] = GpuAllocEmpty(dest.owner.op.dtype)(*dest.owner.inputs)
@inplace_allocempty(GpuDnnConvGradI, 2)
def local_dnn_convgi_inplace(node, inputs):
return [GpuDnnConvGradI(algo=node.op.algo, inplace=True)(*inputs)]
optdb.register('local_dnna_conv_inplace',
tensor.opt.in2out(local_dnn_conv_inplace,
local_dnn_convgw_inplace,
local_dnn_convgi_inplace,
name="local_dnn_conv_inplace"),
name="local_dnna_conv_inplace"),
70.0, 'fast_run', 'inplace', 'gpuarray', 'cudnn')
......
......@@ -7,7 +7,7 @@ from theano.gof import local_optimizer
from theano.tensor import (DimShuffle, get_scalar_constant_value,
NotScalarConstantError)
from .basic_ops import GpuFromHost, HostFromGpu
from .basic_ops import GpuFromHost, HostFromGpu, GpuAllocEmpty
from .elemwise import GpuDimShuffle, GpuElemwise
_one = scal.constant(numpy.asarray(1.0, dtype='float64'))
......@@ -126,3 +126,22 @@ def output_merge(cls, alpha_in, beta_in, out_in, nd):
return maker(targ, *inputs)
return opt
return wrapper
def inplace_allocempty(op, idx):
def wrapper(maker):
@local_optimizer([op], inplace=True)
@wraps(maker)
def opt(node):
if type(node.op) != op or node.op.inplace:
return
inputs = list(node.inputs)
alloc = inputs[idx]
if (alloc.owner and
isinstance(alloc.owner.op, GpuAllocEmpty) and
len(alloc.clients) > 1):
alloc_op = GpuAllocEmpty(alloc.owner.op.dtype)
inputs[idx] = alloc_op(*alloc.owner.inputs)
return maker(node, inputs)
return opt
return wrapper
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论