提交 7ae6897c authored 作者: Frederic Bastien's avatar Frederic Bastien

make GpuSoftmax and GpuSoftmaxWithBias loop when there are too many blocks. Add tests for this.

上级 51a4b704
...@@ -303,7 +303,7 @@ class GpuSoftmax (Op): ...@@ -303,7 +303,7 @@ class GpuSoftmax (Op):
return shape return shape
def c_code_cache_version(self): def c_code_cache_version(self):
#return () #return ()
return (1,) + inline_softmax.code_version return (2,) + inline_softmax.code_version
def c_code(self, node, nodename, (x,), (z,), sub): def c_code(self, node, nodename, (x,), (z,), sub):
fail = sub['fail'] fail = sub['fail']
return """ return """
...@@ -330,7 +330,7 @@ class GpuSoftmax (Op): ...@@ -330,7 +330,7 @@ class GpuSoftmax (Op):
kSoftmax_%(nodename)s kSoftmax_%(nodename)s
<<< <<<
// todo: cap these at the card limits, implement loops in kernel // todo: cap these at the card limits, implement loops in kernel
CudaNdarray_HOST_DIMS(%(x)s)[0], std::min(CudaNdarray_HOST_DIMS(%(x)s)[0],32*1024),
CudaNdarray_HOST_DIMS(%(x)s)[1], CudaNdarray_HOST_DIMS(%(x)s)[1],
CudaNdarray_HOST_DIMS(%(x)s)[1] * 2 * sizeof(float) CudaNdarray_HOST_DIMS(%(x)s)[1] * 2 * sizeof(float)
>>>( >>>(
...@@ -362,11 +362,14 @@ class GpuSoftmax (Op): ...@@ -362,11 +362,14 @@ class GpuSoftmax (Op):
body=[ body=[
"extern __shared__ float buf[]", "extern __shared__ float buf[]",
"float * buf2 = buf + N", "float * buf2 = buf + N",
"buf[threadIdx.x] = x[blockIdx.x * sx0 + threadIdx.x * sx1]", "for (int blockIDX = blockIdx.x; blockIDX < M; blockIDX += gridDim.x){",
"buf2[threadIdx.x] = buf[threadIdx.x]", "buf[threadIdx.x] = x[blockIDX * sx0 + threadIdx.x * sx1]",
"__syncthreads()", "buf2[threadIdx.x] = buf[threadIdx.x]",
inline_softmax('N', 'buf', 'buf2', 'threadIdx.x', 'blockDim.x'), "__syncthreads()",
"sm[blockIdx.x * N + threadIdx.x] = buf[threadIdx.x]" inline_softmax('N', 'buf', 'buf2', 'threadIdx.x', 'blockDim.x'),
"sm[blockIDX * N + threadIdx.x] = buf[threadIdx.x]",
"__syncthreads()",
"}",
]) ])
...@@ -386,7 +389,7 @@ class GpuSoftmaxWithBias (Op): ...@@ -386,7 +389,7 @@ class GpuSoftmaxWithBias (Op):
return [shape[0]] return [shape[0]]
def c_code_cache_version(self): def c_code_cache_version(self):
#return () #return ()
return (1,) + inline_softmax.code_version return (2,) + inline_softmax.code_version
def c_code(self, node, nodename, (x,b), (z,), sub): def c_code(self, node, nodename, (x,b), (z,), sub):
fail = sub['fail'] fail = sub['fail']
...@@ -425,7 +428,7 @@ class GpuSoftmaxWithBias (Op): ...@@ -425,7 +428,7 @@ class GpuSoftmaxWithBias (Op):
kSoftmaxWithBias_%(nodename)s kSoftmaxWithBias_%(nodename)s
<<< <<<
// todo: cap these at the card limits, implement loops in kernel // todo: cap these at the card limits, implement loops in kernel
CudaNdarray_HOST_DIMS(%(x)s)[0], std::min(CudaNdarray_HOST_DIMS(%(x)s)[0],32*1024),
CudaNdarray_HOST_DIMS(%(x)s)[1], CudaNdarray_HOST_DIMS(%(x)s)[1],
CudaNdarray_HOST_DIMS(%(x)s)[1] * 2 * sizeof(float) CudaNdarray_HOST_DIMS(%(x)s)[1] * 2 * sizeof(float)
>>>( >>>(
...@@ -461,10 +464,14 @@ class GpuSoftmaxWithBias (Op): ...@@ -461,10 +464,14 @@ class GpuSoftmaxWithBias (Op):
body=[ body=[
"extern __shared__ float buf[]", "extern __shared__ float buf[]",
"float * buf2 = buf + N", "float * buf2 = buf + N",
"buf[threadIdx.x] = x[blockIdx.x * sx0 + threadIdx.x * sx1]", "for (int blockIDX = blockIdx.x; blockIDX < M; blockIDX += gridDim.x){",
"buf[threadIdx.x] += b[threadIdx.x * sb0]", "buf[threadIdx.x] = x[blockIDX * sx0 + threadIdx.x * sx1]",
"buf2[threadIdx.x] = buf[threadIdx.x]", "buf[threadIdx.x] += b[threadIdx.x * sb0]",
"__syncthreads()", "buf2[threadIdx.x] = buf[threadIdx.x]",
inline_softmax('N', 'buf', 'buf2', 'threadIdx.x', 'blockDim.x'), "__syncthreads()",
"sm[blockIdx.x * N + threadIdx.x] = buf[threadIdx.x]" inline_softmax('N', 'buf', 'buf2', 'threadIdx.x', 'blockDim.x'),
"sm[blockIDX * N + threadIdx.x] = buf[threadIdx.x]",
"__syncthreads()",
"}",
]) ])
#for (int i = blockIdx.x; i < N; i += gridDim.x)
...@@ -17,6 +17,10 @@ else: ...@@ -17,6 +17,10 @@ else:
def test_GpuCrossentropySoftmax1HotWithBiasDx(): def test_GpuCrossentropySoftmax1HotWithBiasDx():
""" """
This is basic test for GpuCrossentropySoftmaxArgmax1HotWithBias and GpuCrossentropySoftmax1HotWithBiasDx This is basic test for GpuCrossentropySoftmaxArgmax1HotWithBias and GpuCrossentropySoftmax1HotWithBiasDx
We check that we loop when there are too many threads
TODO: check that we loop when there are too many blocks (>32*1024)
""" """
n_in = 1000 n_in = 1000
...@@ -61,3 +65,53 @@ def test_GpuCrossentropySoftmax1HotWithBiasDx(): ...@@ -61,3 +65,53 @@ def test_GpuCrossentropySoftmax1HotWithBiasDx():
assert numpy.allclose(out[1],gout[1]) assert numpy.allclose(out[1],gout[1])
assert numpy.allclose(out[2],gout[2],atol=2e-6) assert numpy.allclose(out[2],gout[2],atol=2e-6)
def test_softmax_with_bias():
    """Basic test for GpuSoftmaxWithBias.

    Uses more rows than the 32*1024 CUDA grid limit so the kernel's
    block loop is exercised, and checks the GPU result against the
    CPU implementation.
    TODO: also cover the too-many-threads case (not implemented).
    """
    rows, cols = 2 << 15, 5  # rows > 32*1024 forces the block loop
    inp = numpy.arange(rows * cols, dtype='float32').reshape(rows, cols)
    x = T.fmatrix('x')
    z = T.nnet.softmax_with_bias(x, T.zeros_like(x[0, :]))
    f = theano.function([x], z, mode=mode_without_gpu)
    f_gpu = theano.function([x], z, mode=mode_with_gpu)
    # Make sure each function actually uses the implementation we expect.
    assert f.maker.env.toposort()[-1].op == T.nnet.softmax_with_bias
    assert isinstance(f_gpu.maker.env.toposort()[-2].op,
                      cuda.nnet.GpuSoftmaxWithBias)
    cpu_out = f(inp)
    gpu_out = f_gpu(inp)
    assert numpy.allclose(cpu_out, gpu_out), numpy.absolute(cpu_out - gpu_out)
def test_softmax():
    """Basic test for GpuSoftmax.

    Uses more rows than the 32*1024 CUDA grid limit so the kernel's
    block loop is exercised, and checks the GPU result against the
    CPU implementation.
    TODO: also cover the too-many-threads case (not implemented).
    """
    rows, cols = 2 << 15, 5  # rows > 32*1024 forces the block loop
    inp = numpy.arange(rows * cols, dtype='float32').reshape(rows, cols)
    x = T.fmatrix('x')
    z = T.nnet.softmax(x)
    f = theano.function([x], z, mode=mode_without_gpu)
    f_gpu = theano.function([x], z, mode=mode_with_gpu)
    # Make sure each function actually uses the implementation we expect.
    assert f.maker.env.toposort()[-1].op == T.nnet.softmax
    assert isinstance(f_gpu.maker.env.toposort()[-2].op,
                      cuda.nnet.GpuSoftmax)
    cpu_out = f(inp)
    gpu_out = f_gpu(inp)
    assert numpy.allclose(cpu_out, gpu_out), numpy.absolute(cpu_out - gpu_out)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论