提交 791b8bcf authored 作者: Frederic Bastien's avatar Frederic Bastien

small code refactoring

上级 1ae40316
......@@ -1037,6 +1037,7 @@ class GpuSum(Op):
""" % locals()
def c_code_reduce_1011(self, sio, node, name, x, z, fail):
makecall = self._makecall(node, name, x, z, fail)
print >> sio, """
{
int verbose = 0;
......@@ -1044,13 +1045,11 @@ class GpuSum(Op):
std::min(CudaNdarray_HOST_DIMS(%(x)s)[3],
NUM_VECTOR_OP_THREADS_PER_BLOCK));
while (n_threads.y * n_threads.x < NUM_VECTOR_OP_THREADS_PER_BLOCK) ++n_threads.y;
n_threads.y -= 1;
while (n_threads.x * (n_threads.y+1) <= NUM_VECTOR_OP_THREADS_PER_BLOCK) ++n_threads.y;
if (n_threads.y > CudaNdarray_HOST_DIMS(%(x)s)[2])
n_threads.y = CudaNdarray_HOST_DIMS(%(x)s)[2];
while (n_threads.x * n_threads.y * n_threads.z < NUM_VECTOR_OP_THREADS_PER_BLOCK) ++n_threads.z;
n_threads.z -= 1;
while (n_threads.x * n_threads.y * (n_threads.z+1) <= NUM_VECTOR_OP_THREADS_PER_BLOCK) ++n_threads.z;
if (n_threads.z > 64)
n_threads.z = 64;
if (n_threads.z > CudaNdarray_HOST_DIMS(%(x)s)[0])
......@@ -1058,41 +1057,12 @@ class GpuSum(Op):
dim3 n_blocks(CudaNdarray_HOST_DIMS(%(x)s)[1]);
if (verbose) printf("running kernel_reduce_sum_1011_%(name)s\\n");
if (verbose) fprint_CudaNdarray(stdout, %(x)s);
if (verbose) fprint_CudaNdarray(stdout, %(z)s);
int n_shared = sizeof(float) * n_threads.x * n_threads.y * n_threads.z;
kernel_reduce_sum_1011_%(name)s<<<n_blocks, n_threads, n_shared>>>(
CudaNdarray_HOST_DIMS(%(x)s)[0],
CudaNdarray_HOST_DIMS(%(x)s)[1],
CudaNdarray_HOST_DIMS(%(x)s)[2],
CudaNdarray_HOST_DIMS(%(x)s)[3],
CudaNdarray_DEV_DATA(%(x)s),
CudaNdarray_HOST_STRIDES(%(x)s)[0],
CudaNdarray_HOST_STRIDES(%(x)s)[1],
CudaNdarray_HOST_STRIDES(%(x)s)[2],
CudaNdarray_HOST_STRIDES(%(x)s)[3],
CudaNdarray_DEV_DATA(%(z)s),
CudaNdarray_HOST_STRIDES(%(z)s)[0]);
CNDA_THREAD_SYNC;
cudaError_t sts = cudaGetLastError();
if (cudaSuccess != sts)
{
PyErr_Format(PyExc_RuntimeError, "Cuda error: %%s: %%s. (grid: %%i x %%i; block: %%i x %%i x %%i)\\n",
"kernel_reduce_sum_1011_%(name)s",
cudaGetErrorString(sts),
n_blocks.x,
n_blocks.y,
n_threads.x,
n_threads.y,
n_threads.z);
%(fail)s;
}
%(makecall)s
}
""" %locals()
def c_code_cache_version(self):
return (14,)
return (15,)
def c_support_code_apply(self, node, nodename):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论