提交 5dbb6809 authored 作者: Frederic Bastien's avatar Frederic Bastien

make sur we select the good value for n_thread.{y,z} for the new GpuSum pattern

上级 e39495dd
......@@ -784,7 +784,7 @@ class GpuSum(Op):
strides_dim = ",".join(["CudaNdarray_HOST_STRIDES(%(x)s)[%(i)s]"%locals() for i in range(N+1)])
threads_y = """
//get as many y threads as we can fit
while (n_threads.x * n_threads.y < NUM_VECTOR_OP_THREADS_PER_BLOCK)
while (n_threads.x * (n_threads.y+1) <= NUM_VECTOR_OP_THREADS_PER_BLOCK)
{
if (n_threads.y < CudaNdarray_HOST_DIMS(%(x)s)[%(N)s-1])
n_threads.y += 1;
......@@ -794,13 +794,13 @@ class GpuSum(Op):
"""%locals()
threads_z = """
//get as many z threads as we can fit
while (n_threads.x * n_threads.y * n_threads.z <= NUM_VECTOR_OP_THREADS_PER_BLOCK)
while (n_threads.x * n_threads.y * (n_threads.z+1) <= NUM_VECTOR_OP_THREADS_PER_BLOCK)
{
if (n_threads.z > CudaNdarray_HOST_DIMS(%(x)s)[%(N)s-2])
if (n_threads.z < CudaNdarray_HOST_DIMS(%(x)s)[%(N)s-2])
n_threads.z += 1;
else
break;
n_threads.z += 1;
}
n_threads.z -= 1;
"""%locals()
if len(self.reduce_mask)==2:
threads_y = ''
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论