提交 568cfdd4 authored 作者: Frederic Bastien's avatar Frederic Bastien

Allow GpuSoftmax and GpuSoftmaxWithBias to work with bigger input.

上级 829da692
......@@ -88,6 +88,7 @@ GPU:
* Fixed a bug if you created a view of a manually created CudaNdarray that is a view of a GPUArray.
* Removed a warning when nvcc is not available and the user did not request it.
* renamed config option cuda.nvccflags -> nvcc.flags
* Allow GpuSoftmax and GpuSoftmaxWithBias to work with bigger input.
Bugs fixed:
......
......@@ -105,13 +105,20 @@ def inline_reduce_prod(N, buf, pos, count):
return inline_reduce(N, buf, pos, count, lambda a, b: "%s * %s"%(a,b))
@code_version((1,) + inline_reduce_max.code_version + inline_reduce_sum.code_version)
@code_version((2,) + inline_reduce_max.code_version + inline_reduce_sum.code_version)
def inline_softmax(N, buf, buf2, threadPos, threadCount):
"""
:param N: length of the buffer
:param threadPos: index of executing thread
:param threadCount: number of executing threads
:Precondition: buf and buf2 contain two identical copies of the input to softmax
:Postcondition: buf contains the softmax, buf2 contains un-normalized softmax
:note: buf and buf2 should be in gpu shared memory, we access it many times.
:note2: We use __i as an int variable in a loop
"""
return [
#get max of buf (trashing all but buf[0])
......@@ -119,14 +126,18 @@ def inline_softmax(N, buf, buf2, threadPos, threadCount):
'__syncthreads()',
'float row_max = '+buf+'[0]',
'__syncthreads()',
buf+'['+threadPos+'] = exp('+buf2+'['+threadPos+'] - row_max)',
buf2+'['+threadPos+'] = '+buf+'['+threadPos+']',
'for(int __i='+threadPos+'; __i<'+N+'; __i+='+threadCount+'){',
buf+'[__i] = exp('+buf2+'[__i] - row_max)',
buf2+'[__i] = '+buf+'[__i]',
'}',
'__syncthreads()',
inline_reduce_sum(N, buf, threadPos, threadCount),
'__syncthreads()',
'float row_sum = '+buf+'[0]',
'__syncthreads()',
# divide each exp() result by the sum to complete the job.
buf+'['+threadPos+'] = '+buf2+'['+threadPos+'] / row_sum'
'for(int __i='+threadPos+'; __i<'+N+'; __i+='+threadCount+'){',
buf+'[__i] = '+buf2+'[__i] / row_sum',
'}',
'__syncthreads()',
]
......@@ -309,7 +309,7 @@ class GpuSoftmax (Op):
return shape
def c_code_cache_version(self):
#return ()
return (3,) + inline_softmax.code_version
return (4,) + inline_softmax.code_version
def c_code(self, node, nodename, inp, out, sub):
x, = inp
z, = out
......@@ -335,12 +335,17 @@ class GpuSoftmax (Op):
}
}
{
int n_blocks = std::min(CudaNdarray_HOST_DIMS(%(x)s)[0],32*1024);
//TODO, detect the maximum number of thread per block.
int n_threads = std::min(CudaNdarray_HOST_DIMS(%(x)s)[1], 1024);
int n_shared_bytes = CudaNdarray_HOST_DIMS(%(x)s)[1] * 2 * sizeof(float);
kSoftmax_%(nodename)s
<<<
// todo: cap these at the card limits, implement loops in kernel
std::min(CudaNdarray_HOST_DIMS(%(x)s)[0],32*1024),
CudaNdarray_HOST_DIMS(%(x)s)[1],
CudaNdarray_HOST_DIMS(%(x)s)[1] * 2 * sizeof(float)
n_blocks,
n_threads,
n_shared_bytes
>>>(
CudaNdarray_HOST_DIMS(%(x)s)[0],
CudaNdarray_HOST_DIMS(%(x)s)[1],
......@@ -371,11 +376,15 @@ class GpuSoftmax (Op):
"extern __shared__ float buf[]",
"float * buf2 = buf + N",
"for (int blockIDX = blockIdx.x; blockIDX < M; blockIDX += gridDim.x){",
"buf[threadIdx.x] = x[blockIDX * sx0 + threadIdx.x * sx1]",
"buf2[threadIdx.x] = buf[threadIdx.x]",
"for (int tx = threadIdx.x; tx< N; tx += blockDim.x){",
"buf[tx] = x[blockIDX * sx0 + tx * sx1]",
"buf2[tx] = buf[tx]",
"}",
"__syncthreads()",
inline_softmax('N', 'buf', 'buf2', 'threadIdx.x', 'blockDim.x'),
"sm[blockIDX * N + threadIdx.x] = buf[threadIdx.x]",
"for (int tx = threadIdx.x; tx< N; tx += blockDim.x){",
"sm[blockIDX * N + tx] = buf[tx]",# This set all value correctly
"}",
"__syncthreads()",
"}",
])
......@@ -398,7 +407,7 @@ class GpuSoftmaxWithBias (Op):
return [shape[0]]
def c_code_cache_version(self):
#return ()
return (3,) + inline_softmax.code_version
return (4,) + inline_softmax.code_version
def c_code(self, node, nodename, inp, out, sub):
x, b = inp
......@@ -436,12 +445,17 @@ class GpuSoftmaxWithBias (Op):
}
}
{
int n_blocks = std::min(CudaNdarray_HOST_DIMS(%(x)s)[0],32*1024);
//TODO, detect the maximum number of thread per block.
int n_threads = std::min(CudaNdarray_HOST_DIMS(%(x)s)[1], 1024);
int n_shared_bytes = CudaNdarray_HOST_DIMS(%(x)s)[1] * 2 * sizeof(float);
kSoftmaxWithBias_%(nodename)s
<<<
// todo: cap these at the card limits, implement loops in kernel
std::min(CudaNdarray_HOST_DIMS(%(x)s)[0],32*1024),
CudaNdarray_HOST_DIMS(%(x)s)[1],
CudaNdarray_HOST_DIMS(%(x)s)[1] * 2 * sizeof(float)
n_blocks,
n_threads,
n_shared_bytes
>>>(
CudaNdarray_HOST_DIMS(%(x)s)[0],
CudaNdarray_HOST_DIMS(%(x)s)[1],
......@@ -476,13 +490,17 @@ class GpuSoftmaxWithBias (Op):
"extern __shared__ float buf[]",
"float * buf2 = buf + N",
"for (int blockIDX = blockIdx.x; blockIDX < M; blockIDX += gridDim.x){",
"buf[threadIdx.x] = x[blockIDX * sx0 + threadIdx.x * sx1]",
"buf[threadIdx.x] += b[threadIdx.x * sb0]",
"buf2[threadIdx.x] = buf[threadIdx.x]",
"for (int tx = threadIdx.x; tx< N; tx += blockDim.x){",
"buf[tx] = x[blockIDX * sx0 + tx * sx1]",
"buf[tx] += b[tx * sb0]",
"buf2[tx] = buf[tx]",
"}",
"__syncthreads()",
inline_softmax('N', 'buf', 'buf2', 'threadIdx.x', 'blockDim.x'),
"sm[blockIDX * N + threadIdx.x] = buf[threadIdx.x]",
"__syncthreads()",
"for (int tx = threadIdx.x; tx< N; tx += blockDim.x){",
"sm[blockIDX * N + tx] = buf[tx]",
"}",
"__syncthreads()",
"}",
])
......
......@@ -142,12 +142,6 @@ def test_softmax_with_bias():
TODO: check that we loop when there are too many threads. (THIS IS NOT IMPLEMENTED)
"""
x = T.fmatrix('x')
#we need to test n>32*1024 to check that we make the block loop.
n,m=2<<15,5
data = numpy.arange(n*m, dtype='float32').reshape(n,m)
z = T.nnet.softmax_with_bias(x, T.zeros_like(x[0,:]))
f = theano.function([x],z, mode=mode_without_gpu)
......@@ -155,9 +149,27 @@ def test_softmax_with_bias():
assert f.maker.env.toposort()[-1].op==T.nnet.softmax_with_bias
assert isinstance(f_gpu.maker.env.toposort()[-2].op,cuda.nnet.GpuSoftmaxWithBias)
out=f(data)
gout=f_gpu(data)
assert numpy.allclose(out,gout),numpy.absolute(out-gout)
def cmp(n,m, catch=False):
    """Compare CPU vs GPU softmax-with-bias output on an (n, m) input.

    :param n: number of rows of the test matrix.
    :param m: number of columns of the test matrix.
    :param catch: if True, tolerate a CUDA "invalid configuration
        argument" launch failure (some old cards won't accept the
        configuration arguments of this implementation); any other
        RuntimeError is still re-raised.
    """
    try:
        #print "test_softmax",n,m
        data = numpy.arange(n*m, dtype='float32').reshape(n,m)
        out=f(data)
        gout=f_gpu(data)
        assert numpy.allclose(out,gout),numpy.absolute(out-gout)
    except RuntimeError, e:
        if not catch:
            raise
        # Only the expected kernel-launch configuration error is
        # swallowed; anything else would fail this assert.
        # NOTE(review): message names 'kSoftmax_node_0' even though this
        # is the with-bias test -- confirm against the generated kernel
        # name (kSoftmaxWithBias_*).
        assert e.args[0]=='Cuda error: kSoftmax_node_0: invalid configuration argument.\n'
cmp(2, 5)
#we need to test n>32*1024 to check that we make the block loop.
cmp(2<<15, 5)
cmp(4074, 400)
cmp(4, 1000, True)
cmp(4, 1024, True)
cmp(4, 2000, True)
cmp(4, 2024, True)
cmp(4, 4074, True)
def test_softmax():
"""
......@@ -168,18 +180,31 @@ def test_softmax():
"""
x = T.fmatrix('x')
#we need to test n>32*1024 to check that we make the block loop.
n,m=2<<15,5
data = numpy.arange(n*m, dtype='float32').reshape(n,m)
z = T.nnet.softmax(x)
f = theano.function([x],z, mode=mode_without_gpu)
f_gpu = theano.function([x],z, mode=mode_with_gpu)
assert f.maker.env.toposort()[-1].op==T.nnet.softmax
assert isinstance(f_gpu.maker.env.toposort()[-2].op,cuda.nnet.GpuSoftmax)
out=f(data)
gout=f_gpu(data)
assert numpy.allclose(out,gout),numpy.absolute(out-gout)
def cmp(n,m, catch=False):
    """Compare CPU vs GPU softmax output on an (n, m) input.

    :param n: number of rows of the test matrix.
    :param m: number of columns of the test matrix.
    :param catch: if True, tolerate a CUDA "invalid configuration
        argument" launch failure (some old cards won't accept the
        configuration arguments of this implementation); any other
        RuntimeError is still re-raised.
    """
    try:
        #print "test_softmax",n,m
        data = numpy.arange(n*m, dtype='float32').reshape(n,m)
        out=f(data)
        gout=f_gpu(data)
        assert numpy.allclose(out,gout),numpy.absolute(out-gout)
    except RuntimeError, e:
        if not catch:
            raise
        # Only the expected kernel-launch configuration error is
        # swallowed; anything else would fail this assert.
        assert e.args[0]=='Cuda error: kSoftmax_node_0: invalid configuration argument.\n'
#we need to test n>32*1024 to check that we make the block loop.
cmp(2, 5)
cmp(2<<15, 5)
cmp(4074, 400)
cmp(4, 1000, True)
cmp(4, 1024, True)
cmp(4, 2000, True)
cmp(4, 2024, True)
cmp(4, 4074, True)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论