提交 ac703c5b，作者：James Bergstra

converted to switch structure on nd_collapse

上级 7a6b5233
...@@ -455,9 +455,12 @@ class NaiveAlgo(object): ...@@ -455,9 +455,12 @@ class NaiveAlgo(object):
#print >> sio, "\t,", "float * o%i_data" % ipos #print >> sio, "\t,", "float * o%i_data" % ipos
print >> sio, "\t)\n{" print >> sio, "\t)\n{"
use_shared_stride = True # TODO: Setting these to true makes the function fail SOMETIMES. I don't know why yet.
use_shared_stride = False
use_shared_limits = False
def decl_limits(nd): def decl_limits(nd):
if use_shared_limits:
print >> sio, "__shared__ float * limits[%(nd)s];" % locals() print >> sio, "__shared__ float * limits[%(nd)s];" % locals()
def stride(io, p, d): def stride(io, p, d):
...@@ -465,6 +468,11 @@ class NaiveAlgo(object): ...@@ -465,6 +468,11 @@ class NaiveAlgo(object):
return "s%s_str[%i][%i]" %(io, p, d) return "s%s_str[%i][%i]" %(io, p, d)
else: else:
return "%s%i_str_%i" %(io, p, d) return "%s%i_str_%i" %(io, p, d)
def limits(d):
if use_shared_limits:
return "limits[%i]" % d
else:
return "limits%i" % d
def decl_shared_stride(nin, nout, nd): def decl_shared_stride(nin, nout, nd):
if not use_shared_stride: if not use_shared_stride:
...@@ -484,13 +492,21 @@ class NaiveAlgo(object): ...@@ -484,13 +492,21 @@ class NaiveAlgo(object):
def calc_limit(d): def calc_limit(d):
s = stride('o', 0, d) s = stride('o', 0, d)
lname = limits(d)
if use_shared_limits:
print >> sio, "if ((threadIdx.x == 0) && (threadIdx.y == 0)) {" print >> sio, "if ((threadIdx.x == 0) && (threadIdx.y == 0)) {"
if d == 0: if d == 0:
print >> sio, "limits[%(d)s] = o0_data_0 + dim%(d)s * %(s)s;" % locals() print >> sio, "%(lname)s = o0_data_0 + dim%(d)s * %(s)s;" % locals()
else: else:
dm1 = d - 1 dm1 = d - 1
print >> sio, "limits[%(d)s] = o0_data_%(dm1)s + dim%(d)s * %(s)s;" % locals() print >> sio, "%(lname)s = o0_data_%(dm1)s + dim%(d)s * %(s)s;" % locals()
print >> sio, "} __syncthreads();" print >> sio, "} __syncthreads();"
else:
if d == 0:
print >> sio, "const float * %(lname)s = o0_data_0 + dim%(d)s * %(s)s;" % locals()
else:
dm1 = d - 1
print >> sio, "const float * %(lname)s = o0_data_%(dm1)s + dim%(d)s * %(s)s;" % locals()
def decl_ptrs(d, offset): def decl_ptrs(d, offset):
dm1 = d - 1 dm1 = d - 1
...@@ -511,7 +527,8 @@ class NaiveAlgo(object): ...@@ -511,7 +527,8 @@ class NaiveAlgo(object):
print >> sio, "o%(i)s_data_%(d)s += %(amt)s * %(s)s;" %locals() print >> sio, "o%(i)s_data_%(d)s += %(amt)s * %(s)s;" %locals()
def while_limit(d): def while_limit(d):
print >> sio, "while (o0_data_%(d)s < limits[%(d)s]) { " % locals() lname = limits(d)
print >> sio, "while (o0_data_%(d)s < %(lname)s) { " % locals()
def end_while(d): def end_while(d):
print >> sio, "}" print >> sio, "}"
...@@ -623,23 +640,27 @@ class NaiveAlgo(object): ...@@ -623,23 +640,27 @@ class NaiveAlgo(object):
sio = StringIO.StringIO() sio = StringIO.StringIO()
print >> sio, """ print >> sio, """
static inline bool static inline int
_is_c_contiguous_%(nodename)s(const int nd, const int * dims, const int * strides) c_contiguous_beyond_%(nodename)s(int nd, const int * dims, const int * strides, int &size)
{ {
bool c_contiguous = true; // return the dimension such that it and all greater dimensions are c-contiguous
int size = 1; // if everything is c_contiguous then this function returns 0, and size is left
for (int i = nd-1; (i >= 0) and c_contiguous; --i) // with the number of elements.
size = 1;
while (nd > 0)
{ {
if (dims[i] == 1) if ((dims[nd-1] > 1) && (strides[nd-1] != size))
continue;
if (strides[i] != size)
{ {
c_contiguous = false; return nd;
} }
size = size * dims[i]; size = size * dims[nd-1];
--nd;
} }
return c_contiguous; return nd;
} }
""" %locals()
print >> sio, """
static int callkernel_%(nodename)s(unsigned int numEls, const int d, static int callkernel_%(nodename)s(unsigned int numEls, const int d,
const int * dims, const int * dims,
%(input_params)s, %(input_params)s,
...@@ -658,19 +679,44 @@ class NaiveAlgo(object): ...@@ -658,19 +679,44 @@ class NaiveAlgo(object):
""" %locals() + " << ' ' << ".join(["i%i_data"%ipos] """ %locals() + " << ' ' << ".join(["i%i_data"%ipos]
+ list("i%i_str[%i]"%(ipos, di) for di in xrange(nd))) + ''' << "\\n"; ''' + list("i%i_str[%i]"%(ipos, di) for di in xrange(nd))) + ''' << "\\n"; '''
# Try to launch the Ccontiguous version # collapse contiguous right-most dimensions (ignoring scalars)
# this is a good idea because [we assume that] the output has been allocated c_contiguous
print >> sio, "int nd_collapse = 0;" #because the outputs are assumed to be c_contiguous
print >> sio, "int nd_collapse_size = numEls;" #because the outputs are assumed to be c_contiguous
for ipos in xrange(len(node.inputs)):
print >> sio, """
int nd_collapse_size_%(ipos)s;
int nd_collapse_%(ipos)s = c_contiguous_beyond_%(nodename)s(%(nd)s, dims, i%(ipos)s_str, nd_collapse_size_%(ipos)s);
if (nd_collapse_%(ipos)s > nd_collapse)
{
nd_collapse = nd_collapse_%(ipos)s;
nd_collapse_size = nd_collapse_size_%(ipos)s;
}
""" %locals()
# DEBUGPRINT
print >> sio, 'std::cerr << " nd_collapse " << nd_collapse << " " << nd_collapse_size << "\\n";'
for ipos in xrange(len(node.inputs)):
print >> sio, "int local_i%(ipos)s_str[%(nd)s];"%locals()
for d in xrange(nd):
print >> sio, "local_i%(ipos)s_str[%(d)s] = (%(d)s == nd_collapse) ? 1 : i%(ipos)s_str[%(d)s];"%locals()
for ipos in xrange(len(node.outputs)):
print >> sio, "int local_o%(ipos)s_str[%(nd)s];"%locals()
for d in xrange(nd):
print >> sio, "local_o%(ipos)s_str[%(d)s] = (%(d)s == nd_collapse) ? 1 : o%(ipos)s_str[%(d)s];"%locals()
print >> sio, "int local_dims[%(nd)s];"%locals()
for d in xrange(nd):
print >> sio, "local_dims[%(d)s] = (%(d)s == nd_collapse) ? nd_collapse_size : dims[%(d)s];"%locals()
def launch_Ccontiguous(nodename, id_self, scalar_op):
kernel_call_args = ["numEls"] kernel_call_args = ["numEls"]
for ipos in xrange(len(node.inputs)): for ipos in xrange(len(node.inputs)):
kernel_call_args.append("i%i_data"%ipos) kernel_call_args.append("i%i_data"%ipos)
for ipos in xrange(len(node.outputs)): for ipos in xrange(len(node.outputs)):
kernel_call_args.append("o%i_data"%ipos) kernel_call_args.append("o%i_data"%ipos)
kernel_call_args = ", ".join(kernel_call_args) kernel_call_args = ", ".join(kernel_call_args)
print >> sio, " if (" \
+ " && ".join(["_is_c_contiguous_%s(%i, dims, i%i_str)" % (nodename, nd, ipos) for ipos in xrange(len(node.inputs))]) \
+ ')'
print >> sio, """ print >> sio, """
{
std::cerr << " Running Ccontiguous version \\n";
int threads_per_block = std::min(numEls, (unsigned int)NUM_VECTOR_OP_THREADS_PER_BLOCK); int threads_per_block = std::min(numEls, (unsigned int)NUM_VECTOR_OP_THREADS_PER_BLOCK);
int n_blocks = std::min(numEls/threads_per_block + (numEls %% threads_per_block?1:0), (unsigned int)NUM_VECTOR_OP_BLOCKS); int n_blocks = std::min(numEls/threads_per_block + (numEls %% threads_per_block?1:0), (unsigned int)NUM_VECTOR_OP_BLOCKS);
kernel_%(scalar_op)s_%(nodename)s_Ccontiguous<<<n_blocks, threads_per_block>>>(%(kernel_call_args)s); kernel_%(scalar_op)s_%(nodename)s_Ccontiguous<<<n_blocks, threads_per_block>>>(%(kernel_call_args)s);
...@@ -685,34 +731,11 @@ class NaiveAlgo(object): ...@@ -685,34 +731,11 @@ class NaiveAlgo(object):
} }
return 0; return 0;
}
""" %locals() """ %locals()
def launch_tile4():
# if (False and nd == 4): # tiling kernel
# Try to launch a general version
#
# kernel_call_args are used to invoke the cuda kernel
kernel_call_args = ["numEls"]
kernel_call_args.extend("dims[%i]"%di for di in xrange(nd))
for ipos in xrange(len(node.inputs)):
kernel_call_args.append(
", ".join(["i%i_data"%ipos] + list("i%i_str[%i]"%(ipos, di) for di in xrange(nd)))
)
#strides = ", ".join("i%i_str[%i]"%(ipos, di) for di in xrange(nd))
#kernel_call_args.append( "%s, i%i_data" % (strides, ipos))
for ipos in xrange(len(node.outputs)):
kernel_call_args.append(
", ".join(["o%i_data"%ipos] + list("o%i_str[%i]"%(ipos, di) for di in xrange(nd)))
)
#strides = ", ".join("o%i_str[%i]"%(ipos, di) for di in xrange(nd))
#kernel_call_args.append( "%s, o%i_data" % (strides, ipos))
kernel_call_args = ", ".join(kernel_call_args)
if (nd == 4): # tiling kernel
print >> sio, """ print >> sio, """
else
{ {
std::cerr << " Running tiling 4D \\n"; std::cerr << " Running tiling 4D \\n";
dim3 gridDim(dims[0], dims[1]); dim3 gridDim(dims[0], dims[1]);
...@@ -756,10 +779,25 @@ class NaiveAlgo(object): ...@@ -756,10 +779,25 @@ class NaiveAlgo(object):
} }
} }
""" %locals() """ %locals()
else:
def launch_General(nodename, id_self, scalar_op):
# kernel_call_args are used to invoke the cuda kernel
kernel_call_args = ["numEls"]
kernel_call_args.extend("dims[%i]"%di for di in xrange(nd))
for ipos in xrange(len(node.inputs)):
kernel_call_args.append(
", ".join(["i%i_data"%ipos] + list("local_i%i_str[%i]"%(ipos, di) for di in xrange(nd)))
)
#strides = ", ".join("i%i_str[%i]"%(ipos, di) for di in xrange(nd))
#kernel_call_args.append( "%s, i%i_data" % (strides, ipos))
for ipos in xrange(len(node.outputs)):
kernel_call_args.append(
", ".join(["o%i_data"%ipos] + list("o%i_str[%i]"%(ipos, di) for di in xrange(nd)))
)
#strides = ", ".join("o%i_str[%i]"%(ipos, di) for di in xrange(nd))
#kernel_call_args.append( "%s, o%i_data" % (strides, ipos))
kernel_call_args = ", ".join(kernel_call_args)
print >> sio, """ print >> sio, """
else
{
std::cerr << " Running general version \\n"; std::cerr << " Running general version \\n";
int threads_per_block = std::min(numEls, (unsigned int)NUM_VECTOR_OP_THREADS_PER_BLOCK); int threads_per_block = std::min(numEls, (unsigned int)NUM_VECTOR_OP_THREADS_PER_BLOCK);
int n_blocks = std::min(numEls/threads_per_block + (numEls %% threads_per_block?1:0), (unsigned int)NUM_VECTOR_OP_BLOCKS); int n_blocks = std::min(numEls/threads_per_block + (numEls %% threads_per_block?1:0), (unsigned int)NUM_VECTOR_OP_BLOCKS);
...@@ -773,12 +811,25 @@ class NaiveAlgo(object): ...@@ -773,12 +811,25 @@ class NaiveAlgo(object):
} }
return 0; return 0;
}
}
""" %locals() """ %locals()
print >> sio, "switch (nd_collapse) {"
print >> sio, "case 0: {"
launch_Ccontiguous(nodename, id_self, scalar_op)
print >> sio, " } break;"
#print >> sio, "case 4: {"
#launch_tile4()
#print >> sio, " } break;"
print >> sio, "default: {"
launch_General(nodename, id_self, scalar_op)
print >> sio, " }"
print >> sio, "}"
print >> sio, "}"
#N.B. cudaGetLastError is called by c_code #N.B. cudaGetLastError is called by c_code
return sio.getvalue() return sio.getvalue()
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
return self.c_src_kernel(node, nodename) \ return self.c_src_kernel(node, nodename) \
+ self.c_src_kernel_Ccontiguous(node, nodename) \ + self.c_src_kernel_Ccontiguous(node, nodename) \
......
Markdown 格式
0%
您即将向此讨论添加 0 人。请谨慎行事。
请先完成此评论的编辑!
请注册或登录后发表评论