提交 d0040637 authored 作者: Frederic's avatar Frederic

Make GpuSoftmax not crash when the input shape contains 0.

上级 defef0a5
...@@ -351,7 +351,7 @@ class GpuSoftmax (GpuOp): ...@@ -351,7 +351,7 @@ class GpuSoftmax (GpuOp):
def c_code_cache_version(self): def c_code_cache_version(self):
#return () #return ()
return (6,) + inline_softmax.code_version return (7,) + inline_softmax.code_version
def c_code(self, node, nodename, inp, out, sub): def c_code(self, node, nodename, inp, out, sub):
x, = inp x, = inp
...@@ -384,33 +384,36 @@ class GpuSoftmax (GpuOp): ...@@ -384,33 +384,36 @@ class GpuSoftmax (GpuOp):
int n_threads = std::min(CudaNdarray_HOST_DIMS(%(x)s)[1], 512); int n_threads = std::min(CudaNdarray_HOST_DIMS(%(x)s)[1], 512);
int n_shared_bytes = CudaNdarray_HOST_DIMS(%(x)s)[1] * 2 * sizeof(float); int n_shared_bytes = CudaNdarray_HOST_DIMS(%(x)s)[1] * 2 * sizeof(float);
kSoftmax_%(nodename)s if (CudaNdarray_HOST_DIMS(%(x)s)[0] > 0)
<<< {
n_blocks, kSoftmax_%(nodename)s
n_threads, <<<
n_shared_bytes n_blocks,
>>>( n_threads,
CudaNdarray_HOST_DIMS(%(x)s)[0], n_shared_bytes
CudaNdarray_HOST_DIMS(%(x)s)[1], >>>(
CudaNdarray_HOST_DIMS(%(x)s)[0],
CudaNdarray_HOST_DIMS(%(x)s)[1],
CudaNdarray_DEV_DATA(%(x)s), CudaNdarray_DEV_DATA(%(x)s),
CudaNdarray_HOST_STRIDES(%(x)s)[0], CudaNdarray_HOST_STRIDES(%(x)s)[0],
CudaNdarray_HOST_STRIDES(%(x)s)[1], CudaNdarray_HOST_STRIDES(%(x)s)[1],
CudaNdarray_DEV_DATA(%(z)s), CudaNdarray_DEV_DATA(%(z)s),
CudaNdarray_HOST_STRIDES(%(z)s)[0], CudaNdarray_HOST_STRIDES(%(z)s)[0],
CudaNdarray_HOST_STRIDES(%(z)s)[1] CudaNdarray_HOST_STRIDES(%(z)s)[1]
); );
CNDA_THREAD_SYNC; CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError(); cudaError_t err = cudaGetLastError();
if( cudaSuccess != err) if( cudaSuccess != err)
{ {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"Cuda error: %%s: %%s.\\n Used %%d blocks," "Cuda error: %%s: %%s.\\n Used %%d blocks,"
" %%d threads %%d bytes of shared memory", " %%d threads %%d bytes of shared memory",
"kSoftmax_%(nodename)s", cudaGetErrorString(err), "kSoftmax_%(nodename)s", cudaGetErrorString(err),
n_blocks, n_threads, n_shared_bytes); n_blocks, n_threads, n_shared_bytes);
%(fail)s; %(fail)s;
}
} }
} }
assert(%(z)s); assert(%(z)s);
......
...@@ -228,7 +228,10 @@ def test_softmax(): ...@@ -228,7 +228,10 @@ def test_softmax():
def cmp(n, m, catch=False): def cmp(n, m, catch=False):
"""Some old card won't accept the configuration arguments of """Some old card won't accept the configuration arguments of
this implementation. For those cases set catch=True to skip those errors.""" this implementation. For those cases set catch=True to skip
those errors.
"""
try: try:
#print "test_softmax",n,m #print "test_softmax",n,m
data = numpy.arange(n * m, dtype='float32').reshape(n, m) data = numpy.arange(n * m, dtype='float32').reshape(n, m)
...@@ -238,6 +241,7 @@ def test_softmax(): ...@@ -238,6 +241,7 @@ def test_softmax():
except RuntimeError, e: except RuntimeError, e:
if not catch: if not catch:
raise raise
# Different CUDA drivers report different error messages
assert (e.args[0].startswith( assert (e.args[0].startswith(
'Cuda error: kSoftmax_node_0: invalid configuration argument.\n') or 'Cuda error: kSoftmax_node_0: invalid configuration argument.\n') or
e.args[0].startswith('Cuda error: kSoftmax_node_0: invalid argument.\n')) e.args[0].startswith('Cuda error: kSoftmax_node_0: invalid argument.\n'))
...@@ -246,6 +250,7 @@ def test_softmax(): ...@@ -246,6 +250,7 @@ def test_softmax():
cmp(2, 5) cmp(2, 5)
cmp(2 << 15, 5) cmp(2 << 15, 5)
cmp(4074, 400) cmp(4074, 400)
cmp(0, 10)
cmp(784, 784) cmp(784, 784)
cmp(4, 1000) cmp(4, 1000)
cmp(4, 1024) cmp(4, 1024)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论