提交 ec8100d0 作者: Ian Goodfellow

marked c_code_reduce_01X as ready for general ops

上级 09a71d05
...@@ -1035,14 +1035,15 @@ class GpuCAReduce(GpuOp): ...@@ -1035,14 +1035,15 @@ class GpuCAReduce(GpuOp):
:param N: the number of 1 in the pattern N=1 -> 01, N=2 -> 011 N=3 ->0111 :param N: the number of 1 in the pattern N=1 -> 01, N=2 -> 011 N=3 ->0111
Work for N=1,2,3 Work for N=1,2,3
""" """
assert N in [1, 2, 3] assert N in [1, 2, 3]
makecall = self._makecall(node, name, x, z, fail) makecall = self._makecall(node, name, x, z, fail)
self._op_guard()
N_pattern = ''.join(['1'] * N) N_pattern = ''.join(['1'] * N)
param_dim = ",".join(["CudaNdarray_HOST_DIMS(%(x)s)[%(i)s]" % locals() param_dim = ",".join(["CudaNdarray_HOST_DIMS(%(x)s)[%(i)s]" % locals()
for i in xrange(N + 1)]) for i in xrange(N + 1)])
strides_dim = ",".join(["CudaNdarray_HOST_STRIDES(%(x)s)[%(i)s]" strides_dim = ",".join(["CudaNdarray_HOST_STRIDES(%(x)s)[%(i)s]"
% locals() for i in xrange(N + 1)]) % locals() for i in xrange(N + 1)])
threads_y = """ threads_y = """
//get as many y threads as we can fit //get as many y threads as we can fit
while (n_threads.x * (n_threads.y+1) <= NUM_VECTOR_OP_THREADS_PER_BLOCK) while (n_threads.x * (n_threads.y+1) <= NUM_VECTOR_OP_THREADS_PER_BLOCK)
...@@ -1051,8 +1052,8 @@ class GpuCAReduce(GpuOp): ...@@ -1051,8 +1052,8 @@ class GpuCAReduce(GpuOp):
n_threads.y += 1; n_threads.y += 1;
else else
break; break;
} }""" % locals()
""" % locals()
threads_z = """ threads_z = """
//get as many z threads as we can fit //get as many z threads as we can fit
while (n_threads.x * n_threads.y * (n_threads.z+1) <= NUM_VECTOR_OP_THREADS_PER_BLOCK) while (n_threads.x * n_threads.y * (n_threads.z+1) <= NUM_VECTOR_OP_THREADS_PER_BLOCK)
...@@ -1061,13 +1062,15 @@ class GpuCAReduce(GpuOp): ...@@ -1061,13 +1062,15 @@ class GpuCAReduce(GpuOp):
n_threads.z += 1; n_threads.z += 1;
else else
break; break;
} }""" % locals()
""" % locals()
if len(self.reduce_mask) == 2: if len(self.reduce_mask) == 2:
threads_y = '' threads_y = ''
threads_z = '' threads_z = ''
if len(self.reduce_mask) == 3: if len(self.reduce_mask) == 3:
threads_z = '' threads_z = ''
print >> sio, """ print >> sio, """
{ {
int verbose = 0; int verbose = 0;
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论