提交 e3c3d33c authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Copy of ops GpuSoftmax and GpuSoftmaxWithBias from theano/sandbox/cuda/nnet.py…

Copy of ops GpuSoftmax and GpuSoftmaxWithBias from theano/sandbox/cuda/nnet.py to theano/sandbox/gpuarray/nnet.py
上级 94506569
...@@ -440,3 +440,350 @@ class GpuCrossentropySoftmax1HotWithBiasDx(Op): ...@@ -440,3 +440,350 @@ class GpuCrossentropySoftmax1HotWithBiasDx(Op):
return ['cuda_get_ptr = (CUdeviceptr (*)(gpudata *g))compyte_get_extension("cuda_get_ptr");'] return ['cuda_get_ptr = (CUdeviceptr (*)(gpudata *g))compyte_get_extension("cuda_get_ptr");']
gpu_crossentropy_softmax_1hot_with_bias_dx = GpuCrossentropySoftmax1HotWithBiasDx() gpu_crossentropy_softmax_1hot_with_bias_dx = GpuCrossentropySoftmax1HotWithBiasDx()
class GpuSoftmax (GpuOp):
"""
Implement Softmax on the gpu.
"""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def __str__(self):
return self.__class__.__name__
def make_node(self, x):
return Apply(self, [x], [x.type()])
def infer_shape(self, node, shape):
return shape
def c_code_cache_version(self):
return (9,) + inline_softmax.code_version
def c_code(self, node, nodename, inp, out, sub):
x, = inp
z, = out
fail = sub['fail']
return """
if (%(x)s->nd != 2)
{
PyErr_SetString(PyExc_ValueError, "rank error");
%(fail)s;
}
if ((NULL == %(z)s) ||
(CudaNdarray_HOST_DIMS(%(z)s)[0] !=
CudaNdarray_HOST_DIMS(%(x)s)[0]) ||
(CudaNdarray_HOST_DIMS(%(z)s)[1] !=
CudaNdarray_HOST_DIMS(%(x)s)[1]))
{
Py_XDECREF(%(z)s);
%(z)s = (CudaNdarray*)CudaNdarray_New();
if ((NULL == %(z)s)
|| CudaNdarray_alloc_contiguous(%(z)s, 2,
CudaNdarray_HOST_DIMS(%(x)s)))
{
Py_XDECREF(%(z)s);
%(z)s = NULL;
%(fail)s;
}
}
{
int n_blocks = std::min(CudaNdarray_HOST_DIMS(%(x)s)[0],
32 * 1024);
//TODO, detect the maximum number of thread per block.
int n_threads = std::min(CudaNdarray_HOST_DIMS(%(x)s)[1], 512);
int n_shared_bytes = CudaNdarray_HOST_DIMS(%(x)s)[1] *
2 * sizeof(float);
if (CudaNdarray_HOST_DIMS(%(x)s)[0] > 0)
{
//Those numbers are based on not too recent GPU
//to make them compatible with more GPU.
//TODO: read the information from the card.
if(n_shared_bytes < (32 * 1024 - 500)){
kSoftmax_%(nodename)s
<<<
n_blocks,
n_threads,
n_shared_bytes
>>>(
CudaNdarray_HOST_DIMS(%(x)s)[0],
CudaNdarray_HOST_DIMS(%(x)s)[1],
CudaNdarray_DEV_DATA(%(x)s),
CudaNdarray_HOST_STRIDES(%(x)s)[0],
CudaNdarray_HOST_STRIDES(%(x)s)[1],
CudaNdarray_DEV_DATA(%(z)s),
CudaNdarray_HOST_STRIDES(%(z)s)[0],
CudaNdarray_HOST_STRIDES(%(z)s)[1]
);
}else{
kSoftmax_fixed_shared%(nodename)s
<<<
n_blocks,
n_threads,
n_threads * sizeof(float)
>>>(
CudaNdarray_HOST_DIMS(%(x)s)[0],
CudaNdarray_HOST_DIMS(%(x)s)[1],
CudaNdarray_DEV_DATA(%(x)s),
CudaNdarray_HOST_STRIDES(%(x)s)[0],
CudaNdarray_HOST_STRIDES(%(x)s)[1],
CudaNdarray_DEV_DATA(%(z)s),
CudaNdarray_HOST_STRIDES(%(z)s)[0],
CudaNdarray_HOST_STRIDES(%(z)s)[1]
);
}
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
if( cudaSuccess != err)
{
PyErr_Format(PyExc_RuntimeError,
"Cuda error: %%s: %%s.\\n Used %%d blocks,"
" %%d threads %%d bytes of shared memory",
"kSoftmax[_fixed_shared]%(nodename)s",
cudaGetErrorString(err),
n_blocks, n_threads, n_shared_bytes);
%(fail)s;
}
}
}
assert(%(z)s);
""" % locals()
def c_support_code_apply(self, node, nodename):
ret1 = nvcc_kernel("kSoftmax_%s" % nodename,
params=['int M', 'int N',
'const float * x', 'const int sx0', 'const int sx1',
'float * sm', 'const int sm_s0', 'const int sm_s1'],
body=[
"extern __shared__ float buf[]",
"float * buf2 = buf + N",
"for (int blockIDX = blockIdx.x; blockIDX < M;"
" blockIDX += gridDim.x){",
"for (int tx = threadIdx.x; tx< N; tx += blockDim.x){",
"buf[tx] = x[blockIDX * sx0 + tx * sx1]",
"buf2[tx] = buf[tx]",
"}",
"__syncthreads()",
inline_softmax('N', 'buf', 'buf2',
'threadIdx.x', 'blockDim.x'),
"for (int tx = threadIdx.x; tx< N; tx += blockDim.x){",
# This set all value correctly
"sm[blockIDX * sm_s0 + tx * sm_s1] = buf[tx]",
"}",
"__syncthreads()",
"}",
])
ret2 = nvcc_kernel("kSoftmax_fixed_shared%s" % nodename,
params=['int M', 'int N',
'const float * x', 'const int sx0', 'const int sx1',
'float * sm', 'const int sm_s0', 'const int sm_s1'],
body=[
"extern __shared__ float buf[]",
"for (int blockIDX = blockIdx.x; blockIDX < M;"
" blockIDX += gridDim.x){",
"const float *x_ptr = &x[blockIDX * sx0]",
"float *sm_ptr = &sm[blockIDX * sm_s0]",
inline_softmax_fixed_shared('N', 'buf', 'x_ptr', 'sx1',
'sm_ptr', 'sm_s1',
'threadIdx.x', 'blockDim.x'),
"__syncthreads()",
"}",
])
return ret1 + "\n" + ret2
gpu_softmax = GpuSoftmax()
class GpuSoftmaxWithBias (GpuOp):
"""
Implement SoftmaxWithBias on the gpu.
"""
nin = 2
nout = 1
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def __str__(self):
return self.__class__.__name__
def make_node(self, x, b):
return Apply(self, [x, b], [x.type()])
def infer_shape(self, node, shape):
return [shape[0]]
def c_code_cache_version(self):
#return ()
return (8,) + inline_softmax.code_version
def c_code(self, node, nodename, inp, out, sub):
x, b = inp
z, = out
fail = sub['fail']
return """
if (%(x)s->nd != 2)
{
PyErr_SetString(PyExc_ValueError, "rank error input");
%(fail)s;
}
if (%(b)s->nd != 1)
{
PyErr_SetString(PyExc_ValueError, "rank error for the bias");
%(fail)s;
}
if ((CudaNdarray_HOST_DIMS(%(x)s)[1] !=
CudaNdarray_HOST_DIMS(%(b)s)[0]))
{
PyErr_Format(PyExc_ValueError,
"number of columns in x (%%ld)"
" does not match length of b (%%ld)",
(long int)CudaNdarray_HOST_DIMS(%(x)s)[1],
(long int)CudaNdarray_HOST_DIMS(%(b)s)[0]);
%(fail)s;
}
if ((NULL == %(z)s)
|| (CudaNdarray_HOST_DIMS(%(z)s)[0] !=
CudaNdarray_HOST_DIMS(%(x)s)[0])
|| (CudaNdarray_HOST_DIMS(%(z)s)[1] !=
CudaNdarray_HOST_DIMS(%(x)s)[1]))
{
Py_XDECREF(%(z)s);
%(z)s = (CudaNdarray*)CudaNdarray_New();
if ((NULL == %(z)s)
|| CudaNdarray_alloc_contiguous(%(z)s, 2,
CudaNdarray_HOST_DIMS(%(x)s)))
{
Py_XDECREF(%(z)s);
%(z)s = NULL;
%(fail)s;
}
}
{
int n_blocks = std::min(CudaNdarray_HOST_DIMS(%(x)s)[0],32*1024);
//TODO, detect the maximum number of thread per block.
int n_threads = std::min(CudaNdarray_HOST_DIMS(%(x)s)[1], 512);
int n_shared_bytes = CudaNdarray_HOST_DIMS(%(x)s)[1] *
2 * sizeof(float);
if (CudaNdarray_HOST_DIMS(%(x)s)[0] > 0)
{
if(n_shared_bytes < (32 * 1024 - 500)){
kSoftmaxWithBias_%(nodename)s
<<<
n_blocks,
n_threads,
n_shared_bytes
>>>(
CudaNdarray_HOST_DIMS(%(x)s)[0],
CudaNdarray_HOST_DIMS(%(x)s)[1],
CudaNdarray_DEV_DATA(%(x)s),
CudaNdarray_HOST_STRIDES(%(x)s)[0],
CudaNdarray_HOST_STRIDES(%(x)s)[1],
CudaNdarray_DEV_DATA(%(b)s),
CudaNdarray_HOST_STRIDES(%(b)s)[0],
CudaNdarray_DEV_DATA(%(z)s),
CudaNdarray_HOST_STRIDES(%(z)s)[0],
CudaNdarray_HOST_STRIDES(%(z)s)[1]
);
}else{
kSoftmaxWithBias_fixed_shared%(nodename)s
<<<
n_blocks,
n_threads,
n_threads * sizeof(float)
>>>(
CudaNdarray_HOST_DIMS(%(x)s)[0],
CudaNdarray_HOST_DIMS(%(x)s)[1],
CudaNdarray_DEV_DATA(%(x)s),
CudaNdarray_HOST_STRIDES(%(x)s)[0],
CudaNdarray_HOST_STRIDES(%(x)s)[1],
CudaNdarray_DEV_DATA(%(b)s),
CudaNdarray_HOST_STRIDES(%(b)s)[0],
CudaNdarray_DEV_DATA(%(z)s),
CudaNdarray_HOST_STRIDES(%(z)s)[0],
CudaNdarray_HOST_STRIDES(%(z)s)[1]
);
}
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
if( cudaSuccess != err)
{
PyErr_Format(PyExc_RuntimeError,
"Cuda error: %%s: %%s.\\n",
"kSoftmaxWithBias_%(nodename)s",
cudaGetErrorString(err));
%(fail)s;
}
}
}
assert(%(z)s);
""" % locals()
def c_support_code_apply(self, node, nodename):
ret1 = nvcc_kernel("kSoftmaxWithBias_%s" % nodename,
params=['int M', 'int N',
'const float * x', 'const int sx0', 'const int sx1',
'const float * b', 'const int sb0',
'float * sm', 'const int sm_s0', 'const int sm_s1'],
body=[
"extern __shared__ float buf[]",
"float * buf2 = buf + N",
"for (int blockIDX = blockIdx.x; blockIDX < M;"
" blockIDX += gridDim.x){",
"for (int tx = threadIdx.x; tx< N; tx += blockDim.x){",
"buf[tx] = x[blockIDX * sx0 + tx * sx1]",
"buf[tx] += b[tx * sb0]",
"buf2[tx] = buf[tx]",
"}",
"__syncthreads()",
inline_softmax('N', 'buf', 'buf2',
'threadIdx.x', 'blockDim.x'),
"for (int tx = threadIdx.x; tx< N; tx += blockDim.x){",
"sm[blockIDX * sm_s0 + tx * sm_s1] = buf[tx]",
"}",
"__syncthreads()",
"}",
])
ret2 = nvcc_kernel("kSoftmaxWithBias_fixed_shared%s" % nodename,
params=['int M', 'int N',
'const float * x',
'const int sx0', 'const int sx1',
'const float * b', 'const int sb0',
'float * sm',
'const int sm_s0', 'const int sm_s1'],
body=[
"extern __shared__ float buf[]",
"for (int blockIDX = blockIdx.x; blockIDX < M;"
" blockIDX += gridDim.x){",
"const float *x_ptr = &x[blockIDX * sx0]",
"float *sm_ptr = &sm[blockIDX * sm_s0]",
inline_softmax_fixed_shared('N', 'buf',
'x_ptr', 'sx1',
'sm_ptr', 'sm_s1',
'threadIdx.x',
'blockDim.x',
'b', 'sb0'),
"__syncthreads()",
"}",
])
return ret1 + "\n" + ret2
gpu_softmax_with_bias = GpuSoftmaxWithBias()
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论