提交 a0964ac0 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2559 from abergeron/cudnn_v2_ab2

Allow dnn convolution to work inplace
import numpy import numpy
import theano import theano
from theano import Apply, tensor, scalar, Constant from theano import Apply, tensor, scalar
from theano.tensor import DimShuffle, discrete_dtypes from theano.tensor import discrete_dtypes
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
...@@ -12,7 +12,7 @@ if cuda_available: ...@@ -12,7 +12,7 @@ if cuda_available:
opt, GpuFromHost, opt, GpuFromHost,
HostFromGpu, host_from_gpu, HostFromGpu, host_from_gpu,
GpuDimShuffle) GpuDimShuffle)
from theano.sandbox.cuda.opt_util import alpha_merge, output_merge
class SparseBlockGemvSS(GpuOp): class SparseBlockGemvSS(GpuOp):
""" """
...@@ -645,80 +645,19 @@ if cuda_available: ...@@ -645,80 +645,19 @@ if cuda_available:
if node.op == sparse_block_outer_ss: if node.op == sparse_block_outer_ss:
return [sparse_block_outer_ss_inplace(*node.inputs)] return [sparse_block_outer_ss_inplace(*node.inputs)]
def grab_ger(v):
# We need to do some digging because apparently the
# cut_transfers op does not run before us.
if v.owner is not None:
if isinstance(v.owner.op, SparseBlockOuterSS):
return v.owner
elif (isinstance(v.owner.op, GpuFromHost) and
v.owner.inputs[0].owner is not None and
isinstance(v.owner.inputs[0].owner.op, HostFromGpu)):
return grab_ger(v.owner.inputs[0].owner.inputs[0])
else:
return None
# Should be run before elemwise fusion # Should be run before elemwise fusion
@opt.register_opt() @opt.register_opt()
@opt.local_optimizer([GpuElemwise]) @alpha_merge(SparseBlockOuterSS, alpha_in=5, nd=4)
def local_merge_blocksparse_alpha(node): def local_merge_blocksparse_alpha(node, *inputs):
""" """
GpuElemwise{mul}(lr, SparseBlockOuterSS) -> SparseBlockOuterSS(..., alpha=lr) GpuElemwise{mul}(lr, SparseBlockOuterSS) -> SparseBlockOuterSS(..., alpha=lr)
""" """
def grab_lr(v): return [sparse_block_outer_ss(*inputs)]
if v.owner is not None:
n = v.owner
if (isinstance(n.op, GpuDimShuffle) and
n.op.new_order == ('x', 'x', 'x', 'x')):
return host_from_gpu(n.inputs[0])
elif (isinstance(n.op, DimShuffle) and
n.op.new_order == ('x', 'x', 'x', 'x')):
return n.inputs[0]
elif isinstance(n.op, GpuFromHost):
return grab_lr(n.inputs[0])
else:
return None
else:
if (isinstance(v, Constant) and
v.broadcastable == (True, True, True, True)):
return v.dimshuffle(())
if (isinstance(node.op, GpuElemwise) and
node.op.scalar_op == scalar.mul and
node.nin == 2):
ger = grab_ger(node.inputs[0])
if ger is None:
ger = grab_ger(node.inputs[1])
lr = grab_lr(node.inputs[0])
else:
lr = grab_lr(node.inputs[1])
if lr is None or ger is None:
return None
alpha = lr * ger.inputs[5]
return [sparse_block_outer_ss(*(ger.inputs[:5] + [alpha]))]
@opt.register_opt() @opt.register_opt()
@opt.local_optimizer([GpuElemwise]) @output_merge(SparseBlockOuterSS, alpha_in=5, out_in=0, nd=4)
def local_merge_blocksparse_output(node): def local_merge_blocksparse_output(node, *inputs):
if (isinstance(node.op, GpuElemwise) and return [sparse_block_outer_ss(*inputs)]
(node.op.scalar_op == scalar.sub or
node.op.scalar_op == scalar.add) and
node.nin == 2):
ger = grab_ger(node.inputs[0])
W = node.inputs[1]
if ger is None:
ger = grab_ger(node.inputs[1])
W = node.inputs[0]
if ger is None:
return None
if node.op.scalar_op == scalar.sub:
alpha = -ger.inputs[5]
W = W - ger.inputs[0]
else:
alpha = ger.inputs[5]
W = W + ger.inputs[0]
return [sparse_block_outer_ss(*([W] + ger.inputs[1:5] +
[alpha]))]
def sparse_block_dot_SS(W, h, inputIdx, b, outputIdx): def sparse_block_dot_SS(W, h, inputIdx, b, outputIdx):
......
...@@ -103,11 +103,11 @@ cudnnConvolutionForward_v2( ...@@ -103,11 +103,11 @@ cudnnConvolutionForward_v2(
const cudnnTensorDescriptor_t destDesc, const cudnnTensorDescriptor_t destDesc,
void *destData) { void *destData) {
assert(*(float *)alpha == 1.0); assert(*(float *)alpha == 1.0);
assert(*(float *)beta == 0.0); assert(*(float *)beta == 1.0);
return cudnnConvolutionForward(handle, srcDesc, srcData, return cudnnConvolutionForward(handle, srcDesc, srcData,
filterDesc, filterData, filterDesc, filterData,
convDesc, destDesc, destData, convDesc, destDesc, destData,
CUDNN_RESULT_NO_ACCUMULATE); CUDNN_RESULT_ACCUMULATE);
} }
#define cudnnConvolutionForward cudnnConvolutionForward_v2 #define cudnnConvolutionForward cudnnConvolutionForward_v2
...@@ -124,11 +124,11 @@ cudnnConvolutionBackwardFilter_v2( ...@@ -124,11 +124,11 @@ cudnnConvolutionBackwardFilter_v2(
const cudnnFilterDescriptor_t gradDesc, const cudnnFilterDescriptor_t gradDesc,
void *gradData) { void *gradData) {
assert(*(float *)alpha == 1.0); assert(*(float *)alpha == 1.0);
assert(*(float *)beta == 0.0); assert(*(float *)beta == 1.0);
return cudnnConvolutionBackwardFilter(handle, srcDesc, srcData, return cudnnConvolutionBackwardFilter(handle, srcDesc, srcData,
diffDesc, diffData, diffDesc, diffData,
convDesc, gradDesc, gradData, convDesc, gradDesc, gradData,
CUDNN_RESULT_NO_ACCUMULATE); CUDNN_RESULT_ACCUMULATE);
} }
#define cudnnConvolutionBackwardFilter cudnnConvolutionBackwardFilter_v2 #define cudnnConvolutionBackwardFilter cudnnConvolutionBackwardFilter_v2
...@@ -146,7 +146,7 @@ cudnnConvolutionBackwardData_v2( ...@@ -146,7 +146,7 @@ cudnnConvolutionBackwardData_v2(
const cudnnTensorDescriptor_t gradDesc, const cudnnTensorDescriptor_t gradDesc,
void *gradData) { void *gradData) {
assert(*(float *)alpha == 1.0); assert(*(float *)alpha == 1.0);
assert(*(float *)beta == 0.0); assert(*(float *)beta == 1.0);
return cudnnConvolutionBackwardData(handle, return cudnnConvolutionBackwardData(handle,
(cudnnFilterDescriptor_t)filterDesc, (cudnnFilterDescriptor_t)filterDesc,
filterData, filterData,
...@@ -155,7 +155,7 @@ cudnnConvolutionBackwardData_v2( ...@@ -155,7 +155,7 @@ cudnnConvolutionBackwardData_v2(
(cudnnConvolutionDescriptor_t)convDesc, (cudnnConvolutionDescriptor_t)convDesc,
(cudnnTensorDescriptor_t)gradDesc, (cudnnTensorDescriptor_t)gradDesc,
gradData, gradData,
CUDNN_RESULT_NO_ACCUMULATE); CUDNN_RESULT_ACCUMULATE);
} }
#define cudnnConvolutionBackwardData cudnnConvolutionBackwardData_v2 #define cudnnConvolutionBackwardData cudnnConvolutionBackwardData_v2
......
import os import os
import numpy
import theano import theano
from theano import Apply, gof, tensor, config from theano import Apply, gof, tensor, config, Variable
from theano.scalar import as_scalar from theano.scalar import as_scalar, constant
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType, grad_not_implemented
from theano.gof import Optimizer, local_optimizer, COp from theano.gof import Optimizer, local_optimizer, COp
from theano.gof.type import CDataType, Generic from theano.gof.type import CDataType, Generic
from theano.compat import PY3 from theano.compat import PY3
...@@ -18,10 +19,12 @@ from theano.sandbox.cuda import GpuOp ...@@ -18,10 +19,12 @@ from theano.sandbox.cuda import GpuOp
from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable, from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable,
host_from_gpu, host_from_gpu,
gpu_contiguous, HostFromGpu, gpu_contiguous, HostFromGpu,
cp_on_negative_strides) cp_on_negative_strides,
gpu_alloc)
from theano.sandbox.cuda.blas import (GpuConv, GpuDownsampleFactorMax, from theano.sandbox.cuda.blas import (GpuConv, GpuDownsampleFactorMax,
GpuDownsampleFactorMaxGrad) GpuDownsampleFactorMaxGrad)
from theano.sandbox.cuda.nnet import GpuSoftmax from theano.sandbox.cuda.nnet import GpuSoftmax
from theano.sandbox.cuda.opt_util import alpha_merge, output_merge
from theano.sandbox.cuda import gpu_seqopt, register_opt from theano.sandbox.cuda import gpu_seqopt, register_opt
from theano.sandbox.cuda.nvcc_compiler import NVCC_compiler from theano.sandbox.cuda.nvcc_compiler import NVCC_compiler
...@@ -340,6 +343,25 @@ AddConfigVar('dnn.conv.workmem', ...@@ -340,6 +343,25 @@ AddConfigVar('dnn.conv.workmem',
EnumStr('small', 'none', 'large'), EnumStr('small', 'none', 'large'),
in_c_key=False) in_c_key=False)
# scalar constants
_zero = constant(numpy.asarray(0.0, dtype='float32'))
_one = constant(numpy.asarray(1.0, dtype='float32'))
def ensure_float(val, default, name):
if val is None:
return default.clone()
if not isinstance(val, Variable):
val = constant(val)
if hasattr(val, 'ndim') and val.ndim == 0:
val = as_scalar(val)
if not isinstance(val.type, theano.scalar.Scalar):
raise TypeError("%s: expected a scalar value" % (name,))
if not val.type.dtype == 'float32':
raise TypeError("%s: type is not float32" % (name,))
return val
class GpuDnnConv(DnnBase, COp): class GpuDnnConv(DnnBase, COp):
""" """
The forward convolution. The forward convolution.
...@@ -348,9 +370,9 @@ class GpuDnnConv(DnnBase, COp): ...@@ -348,9 +370,9 @@ class GpuDnnConv(DnnBase, COp):
:param kernel: :param kernel:
:param descr: the convolution descriptor :param descr: the convolution descriptor
""" """
__props__ = ('workmem',) __props__ = ('workmem', 'inplace')
def __init__(self, workmem=None): def __init__(self, workmem=None, inplace=False):
""" """
:param workmem: either 'none', 'small' or 'large'. Default is :param workmem: either 'none', 'small' or 'large'. Default is
the value of :attr:`config.dnn.conv.workmem`. the value of :attr:`config.dnn.conv.workmem`.
...@@ -360,88 +382,105 @@ class GpuDnnConv(DnnBase, COp): ...@@ -360,88 +382,105 @@ class GpuDnnConv(DnnBase, COp):
if workmem is None: if workmem is None:
workmem = config.dnn.conv.workmem workmem = config.dnn.conv.workmem
self.workmem = workmem self.workmem = workmem
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [2]}
assert self.workmem in ['none', 'small', 'large'] assert self.workmem in ['none', 'small', 'large']
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
if not hasattr(self, 'workmem'): if not hasattr(self, 'workmem'):
self.workmem = 'small' self.workmem = 'none'
if not hasattr(self, 'inplace'):
self.inplace = False
def get_op_params(self): def get_op_params(self):
if self.inplace:
inpl_def = [('CONV_INPLACE', '1')]
else:
inpl_def = []
if version() == -1: if version() == -1:
return [('CONV_ALGO', "0")] alg_def = ('CONV_ALGO', "0")
else:
if self.workmem == 'none': if self.workmem == 'none':
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM' alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM'
elif self.workmem == 'small': elif self.workmem == 'small':
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM' alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM'
elif self.workmem == 'large': elif self.workmem == 'large':
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_GEMM' alg = 'CUDNN_CONVOLUTION_FWD_ALGO_GEMM'
return [('CONV_ALGO', alg)] alg_def = ('CONV_ALGO', alg)
return [alg_def] + inpl_def
def make_node(self, img, kern, desc): def make_node(self, img, kern, output, desc, alpha=None):
img = as_cuda_ndarray_variable(img) img = as_cuda_ndarray_variable(img)
kern = as_cuda_ndarray_variable(kern) kern = as_cuda_ndarray_variable(kern)
output = as_cuda_ndarray_variable(output)
if img.type.ndim != 4: if img.type.ndim != 4:
raise TypeError('img must be 4D tensor') raise TypeError('img must be 4D tensor')
if kern.type.ndim != 4: if kern.type.ndim != 4:
raise TypeError('kern must be 4D tensor') raise TypeError('kern must be 4D tensor')
if output.type.ndim != 4:
raise TypeError('output must be a 4D tensor')
if not isinstance(desc.type, CDataType) \ if not isinstance(desc.type, CDataType) \
or desc.type.ctype != 'cudnnConvolutionDescriptor_t': or desc.type.ctype != 'cudnnConvolutionDescriptor_t':
raise TypeError('desc must be cudnnConvolutionDescriptor_t') raise TypeError('desc must be cudnnConvolutionDescriptor_t')
broadcastable = (img.type.broadcastable[0], alpha = ensure_float(alpha, _one, 'alpha')
kern.type.broadcastable[0],
False, False) return Apply(self, [img, kern, output, desc, alpha],
return Apply(self, [img, kern, desc], [output.type()])
[CudaNdarrayType(broadcastable)()])
def grad(self, inp, grads): def grad(self, inp, grads):
img, kerns, desc = inp img, kerns, output, desc, alpha = inp
top, = grads top, = grads
top = cp_on_negative_strides(top) top = cp_on_negative_strides(top)
d_img = GpuDnnConvGradI()(kerns, top, desc, d_img = GpuDnnConvGradI()(kerns, top, img.zeros_like(), desc)
img.shape[2], img.shape[3]) d_kerns = GpuDnnConvGradW()(img, top, kerns.zeros_like(), desc)
d_kerns = GpuDnnConvGradW()(img, top, desc, d_alpha = grad_not_implemented(self, 4, alpha)
kerns.shape[2], kerns.shape[3])
return d_img, d_kerns, theano.gradient.DisconnectedType()() return [d_img, d_kerns, top * alpha, DisconnectedType()(), d_alpha]
def connection_pattern(self, node): def connection_pattern(self, node):
# not connected to desc # not connected to desc
return [[1], [1], [0]] return [[1], [1], [1], [0], [1]]
def infer_shape(self, node, shape): @staticmethod
b = shape[0][0] # Number of inputs def get_out_shape(ishape, kshape, border_mode, subsample):
h = shape[0][2] # Height of input feature maps """
w = shape[0][3] # Width of input feature maps This function computes the output shape for a convolution with
nb = shape[1][0] # Number of output feature maps the specified parameters. `ishape` and `kshape` can be symbolic
kh = shape[1][2] # Height of each filter or scalar.
kw = shape[1][3] # Width of each filter """
padh = 0 b = ishape[0] # Number of inputs
padw = 0 h = ishape[2] # Height of input feature maps
if ( w = ishape[3] # Width of input feature maps
not node.inputs[2].owner nb = kshape[0] # Number of output feature maps
or not isinstance(node.inputs[2].owner.op, GpuDnnConvDesc) kh = kshape[2] # Height of each filter
): kw = kshape[3] # Width of each filter
raise theano.tensor.basic.ShareError("case not implemented and probably not needed")
desc = node.inputs[2].owner.op sh, sw = subsample
sh, sw = desc.subsample if border_mode == 'full':
if desc.border_mode == 'full':
padh = kh - 1 padh = kh - 1
padw = kw - 1 padw = kw - 1
elif isinstance(desc.border_mode, tuple): elif isinstance(border_mode, tuple):
padh, padw = desc.border_mode padh, padw = border_mode
else: else:
assert desc.border_mode == 'valid' assert border_mode == 'valid'
padh = 0
padw = 0
return [( return (
b, nb, b, nb,
(h + 2*padh - kh)//sh + 1, (h + 2*padh - kh)//sh + 1,
(w + 2*padw - kw)//sw + 1 (w + 2*padw - kw)//sw + 1
)] )
def infer_shape(self, node, shape):
return [shape[2]]
class GpuDnnConvGradW(DnnBase, COp): class GpuDnnConvGradW(DnnBase, COp):
...@@ -453,58 +492,64 @@ class GpuDnnConvGradW(DnnBase, COp): ...@@ -453,58 +492,64 @@ class GpuDnnConvGradW(DnnBase, COp):
:param descr: the convolution descriptor :param descr: the convolution descriptor
""" """
__props__ = () __props__ = ('inplace',)
def __init__(self): def __init__(self, inplace=False):
COp.__init__(self, ["dnn_base.c", "dnn_conv_base.c", "dnn_gw.c"], COp.__init__(self, ["dnn_base.c", "dnn_conv_base.c", "dnn_gw.c"],
"APPLY_SPECIFIC(conv_gw)") "APPLY_SPECIFIC(conv_gw)")
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [2]}
def __setstate__(self, d):
self.__dict__.update(d)
if not hasattr(self, 'inplace'):
self.inplace = False
def grad(self, inp, grads): def grad(self, inp, grads):
img, top, desc, h, w = inp img, top, output, desc, alpha = inp
kerns, = grads kerns, = grads
kerns = gpu_contiguous(kerns) kerns = gpu_contiguous(kerns)
d_img = GpuDnnConvGradI()(kerns, top, desc, d_img = GpuDnnConvGradI()(kerns, top, img.zeros_like(), desc)
img.shape[2], img.shape[3]) d_top = GpuDnnConv()(img, kerns, top.zeros_like(), desc)
d_top = GpuDnnConv()(img, kerns, desc) d_alpha = grad_not_implemented(self, 4, alpha)
return (d_img, d_top, DisconnectedType()(), DisconnectedType()(), return (d_img, d_top, kerns * alpha, DisconnectedType()(), d_alpha)
DisconnectedType()())
def connection_pattern(self, node): def connection_pattern(self, node):
# not connected to desc, h, w # not connected to desc
return [[1], [1], [0], [0], [0]] return [[1], [1], [1], [0], [1]]
def get_op_params(self):
if self.inplace:
return [('CONV_INPLACE', '1')]
else:
return []
def make_node(self, img, topgrad, desc, h, w): def make_node(self, img, topgrad, output, desc, alpha=None):
img = as_cuda_ndarray_variable(img) img = as_cuda_ndarray_variable(img)
topgrad = as_cuda_ndarray_variable(topgrad) topgrad = as_cuda_ndarray_variable(topgrad)
output = as_cuda_ndarray_variable(output)
if img.type.ndim != 4: if img.type.ndim != 4:
raise TypeError('img must be 4D tensor') raise TypeError('img must be 4D tensor')
if topgrad.type.ndim != 4: if topgrad.type.ndim != 4:
raise TypeError('topgrad must be 4D tensor') raise TypeError('topgrad must be 4D tensor')
if output.type.ndim != 4:
raise TypeError('output must be 4D tensor')
if not isinstance(desc.type, CDataType) \ if not isinstance(desc.type, CDataType) \
or desc.type.ctype != 'cudnnConvolutionDescriptor_t': or desc.type.ctype != 'cudnnConvolutionDescriptor_t':
raise TypeError('desc must be cudnnConvolutionDescriptor_t') raise TypeError('desc must be cudnnConvolutionDescriptor_t')
h = as_scalar(h) alpha = ensure_float(alpha, _one, 'alpha')
w = as_scalar(w)
broadcastable = [topgrad.type.broadcastable[1],
img.type.broadcastable[1],
False, False]
return Apply(self, [img, topgrad, desc, h, w], return Apply(self, [img, topgrad, output, desc, alpha],
[CudaNdarrayType(broadcastable)()]) [output.type()])
def infer_shape(self, node, shape): def infer_shape(self, node, shape):
return [( return [shape[2]]
shape[1][1],
shape[0][1],
node.inputs[3],
node.inputs[4]
)]
class GpuDnnConvGradI(DnnBase, COp): class GpuDnnConvGradI(DnnBase, COp):
...@@ -516,57 +561,59 @@ class GpuDnnConvGradI(DnnBase, COp): ...@@ -516,57 +561,59 @@ class GpuDnnConvGradI(DnnBase, COp):
:param descr: the convolution descriptor :param descr: the convolution descriptor
""" """
__props__ = () __props__ = ('inplace',)
def __init__(self): def __init__(self, inplace=False):
COp.__init__(self, ["dnn_base.c", "dnn_conv_base.c", "dnn_gi.c"], COp.__init__(self, ["dnn_base.c", "dnn_conv_base.c", "dnn_gi.c"],
"APPLY_SPECIFIC(conv_gi)") "APPLY_SPECIFIC(conv_gi)")
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [2]}
def grad(self, inp, grads): def grad(self, inp, grads):
kerns, top, desc, h, w = inp kerns, top, output, desc, alpha = inp
img, = grads img, = grads
img = cp_on_negative_strides(img) img = cp_on_negative_strides(img)
d_kerns = GpuDnnConvGradW()(img, top, desc, d_kerns = GpuDnnConvGradW()(img, top, kerns.zeros_like(), desc)
kerns.shape[2], kerns.shape[3]) d_top = GpuDnnConv()(img, kerns, top.zeros_like(), desc)
d_top = GpuDnnConv()(img, kerns, desc) d_alpha = grad_not_implemented(self, 4, alpha)
return (d_kerns, d_top, DisconnectedType()(), DisconnectedType()(),
DisconnectedType()()) return (d_kerns, d_top, img * alpha, DisconnectedType()(), d_alpha)
def connection_pattern(self, node): def connection_pattern(self, node):
# not connected to desc, h, w # not connected to desc
return [[1], [1], [0], [0], [0]] return [[1], [1], [1], [0], [1]]
def get_op_params(self):
if self.inplace:
return [('CONV_INPLACE', '1')]
else:
return []
def make_node(self, kern, topgrad, desc, h, w): def make_node(self, kern, topgrad, output, desc, alpha=None):
kern = as_cuda_ndarray_variable(kern) kern = as_cuda_ndarray_variable(kern)
topgrad = as_cuda_ndarray_variable(topgrad) topgrad = as_cuda_ndarray_variable(topgrad)
output = as_cuda_ndarray_variable(output)
if kern.type.ndim != 4: if kern.type.ndim != 4:
raise TypeError('kern must be 4D tensor') raise TypeError('kern must be 4D tensor')
if topgrad.type.ndim != 4: if topgrad.type.ndim != 4:
raise TypeError('topgrad must be 4D tensor') raise TypeError('topgrad must be 4D tensor')
if output.type.ndim != 4:
raise TypeError('output must be 4D tensor')
if not isinstance(desc.type, CDataType) \ if not isinstance(desc.type, CDataType) \
or desc.type.ctype != 'cudnnConvolutionDescriptor_t': or desc.type.ctype != 'cudnnConvolutionDescriptor_t':
raise TypeError('desc must be cudnnConvolutionDescriptor_t') raise TypeError('desc must be cudnnConvolutionDescriptor_t')
h = as_scalar(h) alpha = ensure_float(alpha, _one, 'alpha')
w = as_scalar(w)
broadcastable = [topgrad.type.broadcastable[0],
kern.type.broadcastable[1],
False, False]
return Apply(self, [kern, topgrad, desc, h, w], return Apply(self, [kern, topgrad, output, desc, alpha],
[CudaNdarrayType(broadcastable)()]) [output.type()])
def infer_shape(self, node, shape): def infer_shape(self, node, shape):
return [( return [shape[2]]
shape[1][0],
shape[0][1],
node.inputs[3],
node.inputs[4]
)]
def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
...@@ -595,32 +642,31 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -595,32 +642,31 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
:param workmem: Specify the amount of working memory allowed. :param workmem: Specify the amount of working memory allowed.
More memory is usually faster. One of 'none', 'small' or More memory is usually faster. One of 'none', 'small' or
'large'. (default is None which takes its value from 'large'. (default is None which takes its value from
config.dnn.conv.workmem) :attr:`config.dnn.conv.workmem`)
:warning: The cuDNN library only works with GPU that have a compute :warning: The cuDNN library only works with GPU that have a compute
capability of 3.0 or higer. This means that older GPU will not capability of 3.0 or higer. This means that older GPU will not
work with this Op. work with this Op.
:note: The working memory of the op is influenced by
:attr:`config.dnn.conv.workmem`.
""" """
fgraph = getattr(img, 'fgraph', None) or getattr(kerns, 'fgraph', None) fgraph = getattr(img, 'fgraph', None) or getattr(kerns, 'fgraph', None)
if (border_mode == 'valid' and subsample == (1,1) and if (border_mode == 'valid' and subsample == (1,1) and
direction_hint == 'bprop weights'): direction_hint == 'bprop weights'):
# Special case: We are asked to use GpuDnnConvGradW. We need to set # Special case: We are asked to use GpuDnnConvGradW. We need to set
# up a suitable 'fake' convolution to compute the gradient for. # up a suitable 'fake' convolution to compute the gradient for.
img = gpu_contiguous(img.dimshuffle(1, 0, 2, 3)) img = cp_on_negative_strides(img.dimshuffle(1, 0, 2, 3))
if conv_mode == 'conv': if conv_mode == 'conv':
# We need to flip manually. These 'kerns' are not the kernels # We need to flip manually. These 'kerns' are not the kernels
# that would be flipped by conv_mode='conv' in GpuDnnConvGradW. # that would be flipped by conv_mode='conv' in GpuDnnConvGradW.
kerns = kerns[:, :, ::-1, ::-1] kerns = kerns[:, :, ::-1, ::-1]
kerns = gpu_contiguous(kerns.dimshuffle(1, 0, 2, 3)) kerns = gpu_contiguous(kerns.dimshuffle(1, 0, 2, 3))
shape = theano.tensor.stack(kerns.shape[1], img.shape[1], shape2 = shape_i(img, 2, fgraph) - shape_i(kerns, 2, fgraph) + 1
img.shape[2] - kerns.shape[2] + 1, shape3 = shape_i(img, 3, fgraph) - shape_i(kerns, 3, fgraph) + 1
img.shape[3] - kerns.shape[3] + 1) out = gpu_alloc(_zero.clone(), shape_i(kerns, 1, fgraph),
shape_i(img, 1, fgraph), shape2, shape3)
desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1), desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1),
conv_mode='cross')(img.shape, shape) conv_mode='cross')(img.shape, out.shape)
conv = GpuDnnConvGradW()(img, kerns, desc, shape[2], shape[3]) conv = GpuDnnConvGradW()(img, kerns, out, desc)
return as_cuda_ndarray_variable(conv.dimshuffle(1, 0, 2, 3)) return as_cuda_ndarray_variable(conv.dimshuffle(1, 0, 2, 3))
elif (border_mode == 'full' and subsample == (1, 1) and elif (border_mode == 'full' and subsample == (1, 1) and
...@@ -628,17 +674,16 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -628,17 +674,16 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
# Special case: We can be faster by using GpuDnnConvGradI to compute # Special case: We can be faster by using GpuDnnConvGradI to compute
# the full convolution as the backward pass of a valid convolution. # the full convolution as the backward pass of a valid convolution.
# We just need to set up a suitable 'fake' valid convolution. # We just need to set up a suitable 'fake' valid convolution.
img = gpu_contiguous(img) img = cp_on_negative_strides(img)
kerns = gpu_contiguous(kerns.dimshuffle(1, 0, 2, 3)) kerns = gpu_contiguous(kerns.dimshuffle(1, 0, 2, 3))
conv_mode = 'cross' if conv_mode == 'conv' else 'conv' conv_mode = 'cross' if conv_mode == 'conv' else 'conv'
shape2 = shape_i(img, 2, fgraph) + shape_i(kerns, 2, fgraph) - 1 shape2 = shape_i(img, 2, fgraph) + shape_i(kerns, 2, fgraph) - 1
shape3 = shape_i(img, 3, fgraph) + shape_i(kerns, 3, fgraph) - 1 shape3 = shape_i(img, 3, fgraph) + shape_i(kerns, 3, fgraph) - 1
shape = theano.tensor.stack(shape_i(img, 0, fgraph), out = gpu_alloc(_zero.clone(), shape_i(img, 0, fgraph),
shape_i(kerns, 1, fgraph), shape_i(kerns, 1, fgraph), shape2, shape3)
shape2, shape3)
desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1), desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1),
conv_mode=conv_mode)(shape, kerns.shape) conv_mode=conv_mode)(out.shape, kerns.shape)
return GpuDnnConvGradI()(kerns, img, desc, shape2, shape3) return GpuDnnConvGradI()(kerns, img, out, desc)
# Standard case: We use GpuDnnConv with suitable padding. # Standard case: We use GpuDnnConv with suitable padding.
# cp_on_negative_strides will return a gpu_contiguous copy # cp_on_negative_strides will return a gpu_contiguous copy
...@@ -653,7 +698,12 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -653,7 +698,12 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
# algorithm. # algorithm.
if workmem is None or workmem == 'small': if workmem is None or workmem == 'small':
workmem = 'none' workmem = 'none'
return GpuDnnConv(workmem=workmem)(img, kerns, desc) out_shp = GpuDnnConv.get_out_shape(img.shape, kerns.shape, border_mode,
subsample)
out = gpu_alloc(_zero.clone(),
out_shp[0], out_shp[1],
out_shp[2], out_shp[3])
return GpuDnnConv(workmem=workmem)(img, kerns, out, desc)
class GpuDnnPoolDesc(GpuOp): class GpuDnnPoolDesc(GpuOp):
...@@ -1471,6 +1521,69 @@ if True: ...@@ -1471,6 +1521,69 @@ if True:
rval, node.outputs[0].type.broadcastable) rval, node.outputs[0].type.broadcastable)
return [rval] return [rval]
@register_opt('cudnn')
@local_optimizer([GpuDnnConv], inplace=True)
def local_dnn_conv_inplace(node):
if type(node.op) != GpuDnnConv or node.op.inplace == True:
return
return [GpuDnnConv(workmem=node.op.workmem, inplace=True)(*node.inputs)]
@register_opt('cudnn')
@local_optimizer([GpuDnnConvGradW], inplace=True)
def local_dnn_convgw_inplace(node):
if type(node.op) != GpuDnnConvGradW or node.op.inplace == True:
return
return [GpuDnnConvGradW(inplace=True)(*node.inputs)]
@register_opt('cudnn')
@local_optimizer([GpuDnnConvGradI], inplace=True)
def local_dnn_convgi_inplace(node):
if type(node.op) != GpuDnnConvGradI or node.op.inplace == True:
return
return [GpuDnnConvGradI(inplace=True)(*node.inputs)]
@register_opt('cudnn')
@alpha_merge(GpuDnnConv, alpha_in=4, nd=4)
def local_dnn_conv_alpha_merge(node, *inputs):
if version() == -1:
return None
return [GpuDnnConv(workmem=node.op.workmem)(*inputs)]
@register_opt('cudnn')
@alpha_merge(GpuDnnConvGradW, alpha_in=4, nd=4)
def local_dnn_convw_alpha_merge(node, *inputs):
if version() == -1:
return None
return [GpuDnnConvGradW()(*inputs)]
@register_opt('cudnn')
@alpha_merge(GpuDnnConvGradI, alpha_in=4, nd=4)
def local_dnn_convi_alpha_merge(node, *inputs):
if version() == -1:
return None
return [GpuDnnConvGradI()(*inputs)]
@register_opt('cudnn')
@output_merge(GpuDnnConv, alpha_in=4, out_in=2, nd=4)
def local_dnn_conv_output_merge(node, *inputs):
if version() == -1:
return None
return [GpuDnnConv(workmem=node.op.workmem)(*inputs)]
@register_opt('cudnn')
@output_merge(GpuDnnConvGradW, alpha_in=4, out_in=2, nd=4)
def local_dnn_convw_output_merge(node, *inputs):
if version() == -1:
return None
return [GpuDnnConvGradW()(*inputs)]
@register_opt('cudnn')
@output_merge(GpuDnnConvGradI, alpha_in=4, out_in=2, nd=4)
def local_dnn_convi_output_merge(node, *inputs):
if version() == -1:
return None
return [GpuDnnConvGradI()(*inputs)]
@register_opt('cudnn') @register_opt('cudnn')
@local_optimizer([GpuDownsampleFactorMax]) @local_optimizer([GpuDownsampleFactorMax])
def local_pool_dnn(node): def local_pool_dnn(node):
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
int int
APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns, APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns,
cudnnConvolutionDescriptor_t desc, CudaNdarray *om, cudnnConvolutionDescriptor_t desc,
CudaNdarray **output) { float alpha, CudaNdarray **output) {
cudnnStatus_t err = CUDNN_STATUS_SUCCESS; cudnnStatus_t err = CUDNN_STATUS_SUCCESS;
if (c_set_tensor4d(input, APPLY_SPECIFIC(input)) == -1) if (c_set_tensor4d(input, APPLY_SPECIFIC(input)) == -1)
...@@ -11,23 +11,16 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns, ...@@ -11,23 +11,16 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns,
if (c_set_filter(kerns, APPLY_SPECIFIC(kerns)) == -1) if (c_set_filter(kerns, APPLY_SPECIFIC(kerns)) == -1)
return 1; return 1;
{ #ifdef CONV_INPLACE
int out_dims[4]; Py_XDECREF(*output);
err = cudnnGetConvolution2dForwardOutputDim( *output = om;
desc, Py_INCREF(*output);
APPLY_SPECIFIC(input), #else
APPLY_SPECIFIC(kerns), if (CudaNdarray_prep_output(output, 4, CudaNdarray_HOST_DIMS(om)) != 0)
&out_dims[0], &out_dims[1], &out_dims[2], &out_dims[3]);
if (err != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError,
"GpuDnnConv: error while computing the output shape: %s",
cudnnGetErrorString(err));
return 1; return 1;
} if (CudaNdarray_CopyFromCudaNdarray(*output, om))
if (CudaNdarray_prep_output(output, 4, out_dims) != 0) {
return 1; return 1;
} #endif
}
if (c_set_tensor4d(*output, APPLY_SPECIFIC(output)) == -1) if (c_set_tensor4d(*output, APPLY_SPECIFIC(output)) == -1)
return 1; return 1;
...@@ -54,8 +47,7 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns, ...@@ -54,8 +47,7 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns,
if (workspace == NULL && worksize != 0) if (workspace == NULL && worksize != 0)
return 1; return 1;
const float alpha = 1; const float beta = 1;
const float beta = 0;
err = cudnnConvolutionForward( err = cudnnConvolutionForward(
_handle, _handle,
......
...@@ -2,9 +2,8 @@ ...@@ -2,9 +2,8 @@
int int
APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output, APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output,
cudnnConvolutionDescriptor_t desc, CudaNdarray *im, cudnnConvolutionDescriptor_t desc,
int h, int w, float alpha, CudaNdarray **input) {
CudaNdarray **input) {
cudnnStatus_t err = CUDNN_STATUS_SUCCESS; cudnnStatus_t err = CUDNN_STATUS_SUCCESS;
if (c_set_tensor4d(output, APPLY_SPECIFIC(output)) == -1) if (c_set_tensor4d(output, APPLY_SPECIFIC(output)) == -1)
...@@ -12,23 +11,21 @@ APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output, ...@@ -12,23 +11,21 @@ APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output,
if (c_set_filter(kerns, APPLY_SPECIFIC(kerns)) == -1) if (c_set_filter(kerns, APPLY_SPECIFIC(kerns)) == -1)
return 1; return 1;
{ #ifdef CONV_INPLACE
int out_dims[4]; Py_XDECREF(*input);
out_dims[0] = CudaNdarray_HOST_DIMS(output)[0]; *input = im;
out_dims[1] = CudaNdarray_HOST_DIMS(kerns)[1]; Py_INCREF(*input);
out_dims[2] = h; #else
out_dims[3] = w; if (CudaNdarray_prep_output(input, 4, CudaNdarray_HOST_DIMS(im)) != 0)
if (CudaNdarray_prep_output(input, 4, out_dims) != 0) {
return 1; return 1;
} if (CudaNdarray_CopyFromCudaNdarray(*input, im))
} return 1;
#endif
if (c_set_tensor4d(*input, APPLY_SPECIFIC(input)) == -1) if (c_set_tensor4d(*input, APPLY_SPECIFIC(input)) == -1)
return 1; return 1;
{ const float beta = 1;
const float alpha = 1;
const float beta = 0;
err = cudnnConvolutionBackwardData( err = cudnnConvolutionBackwardData(
_handle, _handle,
...@@ -38,7 +35,6 @@ APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output, ...@@ -38,7 +35,6 @@ APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output,
desc, desc,
(void *)&beta, (void *)&beta,
APPLY_SPECIFIC(input), CudaNdarray_DEV_DATA(*input)); APPLY_SPECIFIC(input), CudaNdarray_DEV_DATA(*input));
}
if (err != CUDNN_STATUS_SUCCESS) { if (err != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError, "GpuDnnConvGradI: error doing operation: %s", PyErr_Format(PyExc_RuntimeError, "GpuDnnConvGradI: error doing operation: %s",
cudnnGetErrorString(err)); cudnnGetErrorString(err));
......
...@@ -2,9 +2,8 @@ ...@@ -2,9 +2,8 @@
int int
APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output, APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output,
cudnnConvolutionDescriptor_t desc, CudaNdarray *km, cudnnConvolutionDescriptor_t desc,
int h, int w, float alpha, CudaNdarray **kerns) {
CudaNdarray **kerns) {
cudnnStatus_t err = CUDNN_STATUS_SUCCESS; cudnnStatus_t err = CUDNN_STATUS_SUCCESS;
if (c_set_tensor4d(input, APPLY_SPECIFIC(input)) == -1) if (c_set_tensor4d(input, APPLY_SPECIFIC(input)) == -1)
...@@ -12,23 +11,21 @@ APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output, ...@@ -12,23 +11,21 @@ APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output,
if (c_set_tensor4d(output, APPLY_SPECIFIC(output)) == -1) if (c_set_tensor4d(output, APPLY_SPECIFIC(output)) == -1)
return 1; return 1;
{ #ifdef CONV_INPLACE
int out_dims[4]; Py_XDECREF(*kerns);
out_dims[0] = CudaNdarray_HOST_DIMS(output)[1]; *kerns = km;
out_dims[1] = CudaNdarray_HOST_DIMS(input)[1]; Py_INCREF(*kerns);
out_dims[2] = h; #else
out_dims[3] = w; if (CudaNdarray_prep_output(kerns, 4, CudaNdarray_HOST_DIMS(km)) != 0)
if (CudaNdarray_prep_output(kerns, 4, out_dims) != 0) {
return 1; return 1;
} if (CudaNdarray_CopyFromCudaNdarray(*kerns, km))
} return 1;
#endif
if (c_set_filter(*kerns, APPLY_SPECIFIC(kerns)) == -1) if (c_set_filter(*kerns, APPLY_SPECIFIC(kerns)) == -1)
return 1; return 1;
{ const float beta = 1;
const float alpha = 1;
const float beta = 0;
err = cudnnConvolutionBackwardFilter( err = cudnnConvolutionBackwardFilter(
_handle, _handle,
...@@ -38,7 +35,6 @@ APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output, ...@@ -38,7 +35,6 @@ APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output,
desc, desc,
(void *)&beta, (void *)&beta,
APPLY_SPECIFIC(kerns), CudaNdarray_DEV_DATA(*kerns)); APPLY_SPECIFIC(kerns), CudaNdarray_DEV_DATA(*kerns));
}
if (err != CUDNN_STATUS_SUCCESS) { if (err != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError, "GpuDnnConvGradW: error doing operation: %s", PyErr_Format(PyExc_RuntimeError, "GpuDnnConvGradW: error doing operation: %s",
cudnnGetErrorString(err)); cudnnGetErrorString(err));
......
...@@ -88,6 +88,7 @@ register_opt()(theano.tensor.opt.local_track_shape_i) ...@@ -88,6 +88,7 @@ register_opt()(theano.tensor.opt.local_track_shape_i)
register_opt(name='gpu_constant_folding')( register_opt(name='gpu_constant_folding')(
tensor.opt.constant_folding) tensor.opt.constant_folding)
# This is a partial list of CPU ops that can be in some circonstance # This is a partial list of CPU ops that can be in some circonstance
# moved to the GPU. This list is used by an optimization. # moved to the GPU. This list is used by an optimization.
# Hopefully, we can keep this list up to date. # Hopefully, we can keep this list up to date.
......
from functools import wraps

import numpy

import theano
from theano import scalar as scal, Constant
from theano.gof import local_optimizer
from theano.tensor import DimShuffle
from theano.sandbox.cuda.basic_ops import (
    GpuFromHost, HostFromGpu, host_from_gpu, GpuDimShuffle, GpuElemwise)
def grab_cpu_scalar(v, nd):
    """Try to recover the 0-d CPU variable that `v` was broadcast from.

    `v` is expected to be an `nd`-dimensional variable obtained by
    dimshuffling a scalar to shape ('x',) * nd, possibly transferred
    to the GPU.  Returns the underlying scalar variable (on the CPU),
    or None when the pattern is not recognized.
    """
    if v.owner is not None:
        n = v.owner
        if (isinstance(n.op, GpuDimShuffle) and
                n.op.new_order == ('x',) * nd):
            # Scalar was dimshuffled on the GPU; bring it back to the
            # host.  (host_from_gpu must be imported from basic_ops --
            # referencing it unimported raised NameError here before.)
            return host_from_gpu(n.inputs[0])
        elif (isinstance(n.op, DimShuffle) and
              n.op.new_order == ('x',) * nd):
            return n.inputs[0]
        elif isinstance(n.op, GpuFromHost):
            # Look through the host->GPU transfer and keep digging.
            return grab_cpu_scalar(n.inputs[0], nd=nd)
        else:
            return None
    else:
        if (isinstance(v, Constant) and
                v.broadcastable == (True,) * nd):
            return v.dimshuffle(())
        # Explicit failure instead of falling off the end of the function.
        return None
def find_node(v, cls):
    """Return the apply node of type `cls` that produces `v`, or None.

    Looks through redundant GpuFromHost(HostFromGpu(x)) transfer pairs
    that the transfer-cutting optimization may not have removed yet.
    """
    while v.owner is not None:
        node = v.owner
        if isinstance(node.op, cls):
            return node
        if (isinstance(node.op, GpuFromHost) and
                node.inputs[0].owner is not None and
                isinstance(node.inputs[0].owner.op, HostFromGpu)):
            # Skip the GPU->host->GPU round trip and keep searching.
            v = node.inputs[0].owner.inputs[0]
            continue
        return None
    return None
def alpha_merge(cls, alpha_in, nd):
    """Decorator factory that merges a scalar multiply into op `cls`.

    The produced local optimizer matches
    GpuElemwise{mul}(scalar, cls(...)) -- in either argument order --
    and calls the wrapped `maker(node, *new_inputs)` with the alpha
    input (position `alpha_in`) pre-multiplied by the scalar.  `nd` is
    the number of dimensions of the elemwise output.
    """
    def wrapper(maker):
        @local_optimizer([GpuElemwise])
        @wraps(maker)
        def opt(node):
            # Only binary elemwise multiplications are candidates.
            if not (isinstance(node.op, GpuElemwise) and
                    node.op.scalar_op == scal.mul and
                    node.nin == 2):
                return
            lhs, rhs = node.inputs
            targ = find_node(lhs, cls)
            if targ is not None:
                lr = grab_cpu_scalar(rhs, nd=nd)
            else:
                targ = find_node(rhs, cls)
                lr = grab_cpu_scalar(lhs, nd=nd)
            if targ is None or lr is None:
                return None
            new_inputs = list(targ.inputs)
            # Fold the scalar into the op's existing alpha input.
            new_inputs[alpha_in] = lr * targ.inputs[alpha_in]
            return maker(targ, *new_inputs)
        return opt
    return wrapper
def output_merge(cls, alpha_in, out_in, nd):
    """Decorator factory that merges an add/sub into op `cls`'s output.

    Matches GpuElemwise{add|sub}(W, cls(...)) (either argument order)
    where cls computes ``alpha * op + out`` (inputs `alpha_in` and
    `out_in`), and calls `maker(node, *new_inputs)` with the output and
    alpha inputs rewritten so the elemwise is absorbed into the op.
    """
    def wrapper(maker):
        @local_optimizer([GpuElemwise])
        @wraps(maker)
        def opt(node):
            if (isinstance(node.op, GpuElemwise) and
                (node.op.scalar_op == scal.sub or
                 node.op.scalar_op == scal.add) and
                node.nin == 2):
                targ = find_node(node.inputs[0], cls)
                W = node.inputs[1]
                targ_is_first = True
                if targ is None:
                    targ = find_node(node.inputs[1], cls)
                    W = node.inputs[0]
                    targ_is_first = False
                if targ is None:
                    return None
                if node.op.scalar_op == scal.sub:
                    # Subtraction is not commutative, so the rewrite
                    # depends on which side the merged op is on.
                    # (The old code always used the W - targ form,
                    # producing a wrong sign for targ - W.)
                    if targ_is_first:
                        # targ - W == alpha * op + (out - W)
                        alpha = targ.inputs[alpha_in]
                        W = targ.inputs[out_in] - W
                    else:
                        # W - targ == -alpha * op + (W - out)
                        alpha = -targ.inputs[alpha_in]
                        W = W - targ.inputs[out_in]
                else:
                    # Addition is commutative: W + targ == targ + W.
                    alpha = targ.inputs[alpha_in]
                    W = W + targ.inputs[out_in]
                inputs = list(targ.inputs)
                inputs[out_in] = W
                inputs[alpha_in] = alpha
                return maker(targ, *inputs)
        return opt
    return wrapper
...@@ -18,7 +18,8 @@ from theano.sandbox.cuda.basic_ops import (GpuDimShuffle, ...@@ -18,7 +18,8 @@ from theano.sandbox.cuda.basic_ops import (GpuDimShuffle,
from theano.sandbox.cuda.blocksparse import (sparse_block_dot_SS, from theano.sandbox.cuda.blocksparse import (sparse_block_dot_SS,
sparse_block_gemv_ss, sparse_block_gemv_ss,
sparse_block_outer_ss, sparse_block_outer_ss,
sparse_block_outer_ss_inplace) sparse_block_outer_ss_inplace,
SparseBlockOuterSS)
from theano.sandbox.cuda.var import float32_shared_constructor from theano.sandbox.cuda.var import float32_shared_constructor
...@@ -186,13 +187,20 @@ def test_blocksparse_grad_merge(): ...@@ -186,13 +187,20 @@ def test_blocksparse_grad_merge():
f1 = theano.function([h, iIdx, b, oIdx], updates=[(W, upd)], f1 = theano.function([h, iIdx, b, oIdx], updates=[(W, upd)],
mode=mode_with_gpu) mode=mode_with_gpu)
# not running with mode=gpu ensures that the elemwise is not merged in
mode = None # Make sure the lr update was merged.
if theano.config.mode == 'FAST_COMPILE': assert isinstance(f1.maker.fgraph.outputs[0].owner.op, SparseBlockOuterSS)
mode = theano.compile.mode.get_mode('FAST_RUN')
# Exclude the merge optimizations.
mode = mode_with_gpu.excluding('local_merge_blocksparse_alpha')
mode = mode.excluding('local_merge_blocksparse_output')
f2 = theano.function([h, iIdx, b, oIdx], updates=[(W, upd)], mode=mode) f2 = theano.function([h, iIdx, b, oIdx], updates=[(W, upd)], mode=mode)
# Make sure the lr update is not merged.
assert not isinstance(f2.maker.fgraph.outputs[0].owner.op,
SparseBlockOuterSS)
f2(h_val, iIdx_val, b_val, oIdx_val) f2(h_val, iIdx_val, b_val, oIdx_val)
W_ref = W.get_value() W_ref = W.get_value()
......
...@@ -260,12 +260,13 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -260,12 +260,13 @@ class TestDnnInferShapes(utt.InferShapeTester):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
img = T.ftensor4('img') img = T.ftensor4('img')
kerns = T.ftensor4('kerns') kerns = T.ftensor4('kerns')
out = T.ftensor4('out')
img_val = numpy.asarray( img_val = numpy.asarray(
numpy.random.rand(3, 4, 5, 6), numpy.random.rand(7, 2, 6, 4),
dtype='float32' dtype='float32'
) )
kern_vals = numpy.asarray( kern_vals = numpy.asarray(
numpy.random.rand(3, 4, 5, 6), numpy.random.rand(8, 2, 4, 3),
dtype='float32' dtype='float32'
) )
...@@ -274,16 +275,21 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -274,16 +275,21 @@ class TestDnnInferShapes(utt.InferShapeTester):
[(1, 1), (2, 2)], [(1, 1), (2, 2)],
['conv', 'cross'] ['conv', 'cross']
): ):
out_vals = numpy.zeros(
dnn.GpuDnnConv.get_out_shape(img_val.shape, kern_vals.shape,
border_mode=params[0],
subsample=params[1]),
dtype='float32')
desc = dnn.GpuDnnConvDesc( desc = dnn.GpuDnnConvDesc(
border_mode=params[0], border_mode=params[0],
subsample=params[1], subsample=params[1],
conv_mode=params[2] conv_mode=params[2]
)(img.shape, kerns.shape) )(img.shape, kerns.shape)
conv = dnn.GpuDnnConv()(img_val, kern_vals, desc) conv = dnn.GpuDnnConv()(img, kerns, out, desc)
self._compile_and_check( self._compile_and_check(
[img, kerns], [img, kerns, out],
[conv], [conv],
[img_val, kern_vals], [img_val, kern_vals, out_vals],
dnn.GpuDnnConv dnn.GpuDnnConv
) )
...@@ -292,14 +298,16 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -292,14 +298,16 @@ class TestDnnInferShapes(utt.InferShapeTester):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
img = T.ftensor4('img') img = T.ftensor4('img')
kerns = T.ftensor4('kerns') kerns = T.ftensor4('kerns')
out = T.ftensor4('out')
img_val = numpy.asarray( img_val = numpy.asarray(
numpy.random.rand(3, 4, 5, 6), numpy.random.rand(2, 5, 6, 8),
dtype='float32' dtype='float32'
) )
kern_vals = numpy.asarray( kern_vals = numpy.asarray(
numpy.random.rand(3, 4, 5, 6), numpy.random.rand(2, 1, 5, 6),
dtype='float32' dtype='float32'
) )
out_vals = numpy.zeros((3, 3, 1, 1), dtype='float32')
for params in product( for params in product(
['valid', 'full'], ['valid', 'full'],
...@@ -311,27 +319,27 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -311,27 +319,27 @@ class TestDnnInferShapes(utt.InferShapeTester):
if params[2] == 'conv': if params[2] == 'conv':
temp_kerns = temp_kerns[:, :, ::-1, ::-1] temp_kerns = temp_kerns[:, :, ::-1, ::-1]
temp_kerns = temp_kerns.dimshuffle(1, 0, 2, 3) temp_kerns = temp_kerns.dimshuffle(1, 0, 2, 3)
shape = theano.tensor.stack( shape = (
temp_kerns.shape[1], temp_img.shape[1], kern_vals.shape[1], img_val.shape[1],
temp_img.shape[2] - temp_kerns.shape[2] + 1, img_val.shape[2] - kern_vals.shape[2] + 1,
temp_img.shape[3] - temp_kerns.shape[3] + 1 img_val.shape[3] - kern_vals.shape[3] + 1
) )
out_vals = numpy.zeros(shape, dtype='float32')
desc = dnn.GpuDnnConvDesc( desc = dnn.GpuDnnConvDesc(
border_mode=params[0], border_mode=params[0],
subsample=params[1], subsample=params[1],
conv_mode=params[2] conv_mode=params[2]
)(temp_img.shape, shape) )(temp_img.shape, out.shape)
conv_grad_w = dnn.GpuDnnConvGradW()( conv_grad_w = dnn.GpuDnnConvGradW()(
temp_img, temp_img,
temp_kerns, temp_kerns,
out,
desc, desc,
shape[2],
shape[3]
) )
self._compile_and_check( self._compile_and_check(
[temp_img, temp_kerns], [temp_img, temp_kerns, out],
[conv_grad_w], [conv_grad_w],
[img_val, kern_vals], [img_val, kern_vals, out_vals],
dnn.GpuDnnConvGradW dnn.GpuDnnConvGradW
) )
...@@ -340,6 +348,7 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -340,6 +348,7 @@ class TestDnnInferShapes(utt.InferShapeTester):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
img = T.ftensor4('img') img = T.ftensor4('img')
kerns = T.ftensor4('kerns') kerns = T.ftensor4('kerns')
out = T.ftensor4('out')
img_val = numpy.asarray( img_val = numpy.asarray(
numpy.random.rand(3, 4, 5, 6), numpy.random.rand(3, 4, 5, 6),
dtype='float32' dtype='float32'
...@@ -354,29 +363,28 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -354,29 +363,28 @@ class TestDnnInferShapes(utt.InferShapeTester):
[(1, 1)], [(1, 1)],
['conv', 'cross'] ['conv', 'cross']
): ):
print params
temp_kerns = kerns.dimshuffle(1, 0, 2, 3) temp_kerns = kerns.dimshuffle(1, 0, 2, 3)
shape = theano.tensor.stack( shape = (
img.shape[0], temp_kerns.shape[1], img_val.shape[0], kern_vals.shape[1],
img.shape[2] + temp_kerns.shape[2] - 1, img_val.shape[2] + kern_vals.shape[2] - 1,
img.shape[3] + temp_kerns.shape[3] - 1 img_val.shape[3] + kern_vals.shape[3] - 1
) )
out_vals = numpy.zeros(shape, dtype='float32')
desc = dnn.GpuDnnConvDesc( desc = dnn.GpuDnnConvDesc(
border_mode=params[0], border_mode=params[0],
subsample=params[1], subsample=params[1],
conv_mode=params[2] conv_mode=params[2]
)(shape, temp_kerns.shape) )(out.shape, temp_kerns.shape)
conv_grad_i = dnn.GpuDnnConvGradI()( conv_grad_i = dnn.GpuDnnConvGradI()(
temp_kerns, temp_kerns,
img, img,
out,
desc, desc,
shape[2],
shape[3]
) )
self._compile_and_check( self._compile_and_check(
[temp_kerns, img], [temp_kerns, img, out],
[conv_grad_i], [conv_grad_i],
[kern_vals, img_val], [kern_vals, img_val, out_vals],
dnn.GpuDnnConvGradI dnn.GpuDnnConvGradI
) )
...@@ -447,6 +455,100 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -447,6 +455,100 @@ class TestDnnInferShapes(utt.InferShapeTester):
dnn.GpuDnnPoolGrad dnn.GpuDnnPoolGrad
) )
def test_dnn_conv_merge():
    """Check that the dnn conv alpha/output merge optimizations fire.

    Builds update expressions of the form ``x - lr * grad`` for the
    forward conv and both gradients, and verifies that with the merge
    optimizations the outputs come straight from the dnn ops, while
    with the optimizations excluded they do not -- and that both
    compiled functions agree numerically.
    """
    # Guard like the other dnn tests do; without cudnn, compilation fails.
    if not dnn.dnn_available():
        raise SkipTest(dnn.dnn_available.msg)
    img = T.ftensor4()
    kern = T.ftensor4()
    out = T.ftensor4()

    b = 1
    c = 4
    f = 3
    ih = 2
    iw = 8
    kh = 2
    kw = 2
    img_val = numpy.random.random((b, c, ih, iw)).astype('float32')
    kern_val = numpy.random.random((f, c, kh, kw)).astype('float32')
    # Valid conv output shape is (b, f, ih - kh + 1, iw - kw + 1).
    # The height term used kw before, which only worked because kh == kw.
    out_val = numpy.random.random((b, f, ih - kh + 1,
                                   iw - kw + 1)).astype('float32')

    conv = dnn.dnn_conv(img, kern)
    gw = theano.grad(conv.sum(), kern)
    gi = theano.grad(conv.sum(), img)

    lr = numpy.asarray(0.05, dtype='float32')

    fr = out - lr * conv
    wr = kern - lr * gw
    ir = img - lr * gi

    f1 = theano.function([img, kern, out], [fr, wr, ir], mode=mode_with_gpu)
    # Each output should be produced (through the host transfer) directly
    # by the corresponding dnn op once the merges have been applied.
    assert isinstance(f1.maker.fgraph.outputs[0].owner.inputs[0].owner.op,
                      dnn.GpuDnnConv)
    assert isinstance(f1.maker.fgraph.outputs[1].owner.inputs[0].owner.op,
                      dnn.GpuDnnConvGradW)
    assert isinstance(f1.maker.fgraph.outputs[2].owner.inputs[0].owner.op,
                      dnn.GpuDnnConvGradI)

    # Exclude all the merge optimizations.
    mode = mode_with_gpu
    mode = mode.excluding('local_dnn_conv_alpha_merge')
    mode = mode.excluding('local_dnn_convw_alpha_merge')
    mode = mode.excluding('local_dnn_convi_alpha_merge')
    mode = mode.excluding('local_dnn_conv_output_merge')
    mode = mode.excluding('local_dnn_convw_output_merge')
    mode = mode.excluding('local_dnn_convi_output_merge')

    f2 = theano.function([img, kern, out], [fr, wr, ir], mode=mode)
    assert not isinstance(f2.maker.fgraph.outputs[0].owner.inputs[0].owner.op,
                          dnn.GpuDnnConv)
    assert not isinstance(f2.maker.fgraph.outputs[1].owner.inputs[0].owner.op,
                          dnn.GpuDnnConvGradW)
    assert not isinstance(f2.maker.fgraph.outputs[2].owner.inputs[0].owner.op,
                          dnn.GpuDnnConvGradI)

    out_f1 = f1(img_val, kern_val, out_val)
    out_f2 = f2(img_val, kern_val, out_val)

    # Merged and unmerged graphs must compute the same values.
    assert len(out_f1) == len(out_f2)
    for v1, v2 in zip(out_f1, out_f2):
        utt.assert_allclose(v1, v2)
def test_dnn_conv_grad():
    """Run verify_grad on the dnn conv op and both of its gradients."""
    # Guard like the other dnn tests do; without cudnn, compilation fails.
    if not dnn.dnn_available():
        raise SkipTest(dnn.dnn_available.msg)
    if dnn.version() == -1:
        raise SkipTest('alpha != 1.0 not supported in cudnn v1')

    b = 1
    c = 4
    f = 3
    ih = 2
    iw = 8
    kh = 2
    kw = 2
    img_val = numpy.random.random((b, c, ih, iw)).astype('float32')
    kern_val = numpy.random.random((f, c, kh, kw)).astype('float32')
    # Valid conv output shape is (b, f, ih - kh + 1, iw - kw + 1).
    # The height term used kw before, which only worked because kh == kw.
    out_val = numpy.random.random((b, f, ih - kh + 1,
                                   iw - kw + 1)).astype('float32')

    def dconv(img, kern, out):
        # Forward convolution with an explicit output buffer.
        desc = dnn.GpuDnnConvDesc(border_mode='valid', subsample=(1, 1),
                                  conv_mode='conv')(img.shape, kern.shape)
        return dnn.GpuDnnConv()(img, kern, out, desc)

    def dconvi(img, kern, out):
        # Gradient with respect to the input image.
        desc = dnn.GpuDnnConvDesc(border_mode='valid', subsample=(1, 1),
                                  conv_mode='conv')(img.shape, kern.shape)
        return dnn.GpuDnnConvGradI()(kern, out, img, desc)

    def dconvw(img, kern, out):
        # Gradient with respect to the kernels.
        desc = dnn.GpuDnnConvDesc(border_mode='valid', subsample=(1, 1),
                                  conv_mode='conv')(img.shape, kern.shape)
        return dnn.GpuDnnConvGradW()(img, out, kern, desc)

    utt.verify_grad(dconv, [img_val, kern_val, out_val])
    utt.verify_grad(dconvi, [img_val, kern_val, out_val])
    utt.verify_grad(dconvw, [img_val, kern_val, out_val])
def test_version(): def test_version():
if not cuda.dnn.dnn_available(): if not cuda.dnn.dnn_available():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论