提交 5e93e6b8 authored 作者: Frederic's avatar Frederic

Make GpuSoftmax don't crash on GTX285 GPU in somes cases.

上级 7d18f6a6
...@@ -321,7 +321,7 @@ class GpuSoftmax (GpuOp): ...@@ -321,7 +321,7 @@ class GpuSoftmax (GpuOp):
return shape return shape
def c_code_cache_version(self): def c_code_cache_version(self):
#return () #return ()
return (5,) + inline_softmax.code_version return (6,) + inline_softmax.code_version
def c_code(self, node, nodename, inp, out, sub): def c_code(self, node, nodename, inp, out, sub):
x, = inp x, = inp
z, = out z, = out
...@@ -347,9 +347,9 @@ class GpuSoftmax (GpuOp): ...@@ -347,9 +347,9 @@ class GpuSoftmax (GpuOp):
} }
} }
{ {
int n_blocks = std::min(CudaNdarray_HOST_DIMS(%(x)s)[0],32*1024); int n_blocks = std::min(CudaNdarray_HOST_DIMS(%(x)s)[0], 32 * 1024);
//TODO, detect the maximum number of thread per block. //TODO, detect the maximum number of thread per block.
int n_threads = std::min(CudaNdarray_HOST_DIMS(%(x)s)[1], 1024); int n_threads = std::min(CudaNdarray_HOST_DIMS(%(x)s)[1], 512);
int n_shared_bytes = CudaNdarray_HOST_DIMS(%(x)s)[1] * 2 * sizeof(float); int n_shared_bytes = CudaNdarray_HOST_DIMS(%(x)s)[1] * 2 * sizeof(float);
kSoftmax_%(nodename)s kSoftmax_%(nodename)s
......
...@@ -238,15 +238,18 @@ def test_softmax(): ...@@ -238,15 +238,18 @@ def test_softmax():
except RuntimeError, e: except RuntimeError, e:
if not catch: if not catch:
raise raise
assert (e.args[0] == assert (e.args[0].startswith(
'Cuda error: kSoftmax_node_0: invalid configuration argument.\n') 'Cuda error: kSoftmax_node_0: invalid configuration argument.\n') or
e.args[0].startswith('Cuda error: kSoftmax_node_0: invalid argument.\n'))
#we need to test n>32*1024 to check that we make the block loop. #we need to test n>32*1024 to check that we make the block loop.
cmp(2, 5) cmp(2, 5)
cmp(2 << 15, 5) cmp(2 << 15, 5)
cmp(4074, 400) cmp(4074, 400)
cmp(4, 1000, True) cmp(784, 784)
cmp(4, 1024, True) cmp(4, 1000)
cmp(4, 2000, True) cmp(4, 1024)
cmp(4, 2024, True) cmp(4, 2000)
cmp(4, 2024)
#GTX285 don't have enought shared mem for this case.
cmp(4, 4074, True) cmp(4, 4074, True)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论