提交 25df3acb authored 作者: Chiheb Trabelsi's avatar Chiheb Trabelsi

elemwise.py has been modified in order to respect the flake8 style.

elemwise.py do not contain long lines.
上级 27c8da22
...@@ -97,7 +97,7 @@ class NaiveAlgo(object): ...@@ -97,7 +97,7 @@ class NaiveAlgo(object):
self.scalar_op.__class__.__name__, nodename, nd), file=sio) self.scalar_op.__class__.__name__, nodename, nd), file=sio)
if (nd): if (nd):
print("\t,", ", ".join("const int dim%i" % i print("\t,", ", ".join("const int dim%i" % i
for i in xrange(nd)), file=sio) for i in xrange(nd)), file=sio)
# declare inputs # declare inputs
for ipos, i in enumerate(node.inputs): for ipos, i in enumerate(node.inputs):
s = ", ".join(["const float * i%i_data" % ipos] + s = ", ".join(["const float * i%i_data" % ipos] +
...@@ -108,8 +108,8 @@ class NaiveAlgo(object): ...@@ -108,8 +108,8 @@ class NaiveAlgo(object):
s = ", ".join(["float * o%i_data" % ipos] + s = ", ".join(["float * o%i_data" % ipos] +
["int o%i_str_%i" % (ipos, d) for d in xrange(nd)]) ["int o%i_str_%i" % (ipos, d) for d in xrange(nd)])
print("\t,", s, file=sio) print("\t,", s, file=sio)
#print >> sio, "\t,", ", ".join("int o%i_str_%i" % (ipos, d) for d in xrange(nd)) # print >> sio, "\t,", ", ".join("int o%i_str_%i" % (ipos, d) for d in xrange(nd))
#print >> sio, "\t,", "float * o%i_data" % ipos # print >> sio, "\t,", "float * o%i_data" % ipos
print("\t)\n{", file=sio) print("\t)\n{", file=sio)
print(" const int idx = blockIdx.x * blockDim.x + threadIdx.x;", file=sio) print(" const int idx = blockIdx.x * blockDim.x + threadIdx.x;", file=sio)
print(" const int numThreads = blockDim.x * gridDim.x;", file=sio) print(" const int numThreads = blockDim.x * gridDim.x;", file=sio)
...@@ -129,7 +129,7 @@ class NaiveAlgo(object): ...@@ -129,7 +129,7 @@ class NaiveAlgo(object):
print(" const float * ii_i%i_data = i%i_data;" % (ipos, ipos), file=sio) print(" const float * ii_i%i_data = i%i_data;" % (ipos, ipos), file=sio)
for ipos, i in enumerate(node.outputs): for ipos, i in enumerate(node.outputs):
print(" float * ii_o%i_data = o%i_data;" % (ipos, ipos), file=sio) print(" float * ii_o%i_data = o%i_data;" % (ipos, ipos), file=sio)
for d in xrange(nd-1, -1, -1): for d in xrange(nd - 1, -1, -1):
if d > 0: if d > 0:
print(" int pos%i = ii %% dim%i;" % (d, d), file=sio) print(" int pos%i = ii %% dim%i;" % (d, d), file=sio)
print(" ii = ii / dim%i;" % d, file=sio) print(" ii = ii / dim%i;" % d, file=sio)
...@@ -161,9 +161,9 @@ class NaiveAlgo(object): ...@@ -161,9 +161,9 @@ class NaiveAlgo(object):
print("ii_o%i_data[0] = o%i_i;" % (ipos, ipos), file=sio) print("ii_o%i_data[0] = o%i_i;" % (ipos, ipos), file=sio)
print(" }", file=sio) print(" }", file=sio)
#indent = " "*(4*d+7) # indent = " "*(4*d+7)
# for ipos, i in enumerate(node.inputs): # for ipos, i in enumerate(node.inputs):
#print >> sio, indent, "const float * i%i" % ipos, '= i%i_data', '' # print >> sio, indent, "const float * i%i" % ipos, '= i%i_data', ''
print("}", file=sio) print("}", file=sio)
# print sio.getvalue() # print sio.getvalue()
...@@ -211,10 +211,11 @@ class NaiveAlgo(object): ...@@ -211,10 +211,11 @@ class NaiveAlgo(object):
print("// Input ", ipos, str(i.type), file=sio) print("// Input ", ipos, str(i.type), file=sio)
for ipos, i in enumerate(node.outputs): for ipos, i in enumerate(node.outputs):
print("// Output ", ipos, str(i.type), file=sio) print("// Output ", ipos, str(i.type), file=sio)
print("static __global__ void kernel_%s_%s_%s(unsigned int numEls" % ( print(
self.scalar_op.__class__.__name__, "static __global__ void kernel_%s_%s_%s(unsigned int numEls" %
nodename, (self.scalar_op.__class__.__name__,
'tiling%i'%nd), file=sio) nodename,
'tiling%i' % nd), file=sio)
if (nd): if (nd):
print("\t,", ", ".join("const int dim%i" % i for i in xrange(nd)), file=sio) print("\t,", ", ".join("const int dim%i" % i for i in xrange(nd)), file=sio)
# declare inputs # declare inputs
...@@ -225,15 +226,15 @@ class NaiveAlgo(object): ...@@ -225,15 +226,15 @@ class NaiveAlgo(object):
for ipos, i in enumerate(node.outputs): for ipos, i in enumerate(node.outputs):
s = ", ".join(["float * o%i_data" % ipos] + list("int o%i_str_%i" % (ipos, d) for d in xrange(nd))) s = ", ".join(["float * o%i_data" % ipos] + list("int o%i_str_%i" % (ipos, d) for d in xrange(nd)))
print("\t,", s, file=sio) print("\t,", s, file=sio)
#print >> sio, "\t,", ", ".join("int o%i_str_%i" % (ipos, d) for d in xrange(nd)) # print >> sio, "\t,", ", ".join("int o%i_str_%i" % (ipos, d) for d in xrange(nd))
#print >> sio, "\t,", "float * o%i_data" % ipos # print >> sio, "\t,", "float * o%i_data" % ipos
print("\t)\n{", file=sio) print("\t)\n{", file=sio)
# For each input that is a scalar which has been broadcasted to a tensor, # For each input that is a scalar which has been broadcasted to a tensor,
# load it into a local variable # load it into a local variable
print(" __shared__ float value0[%i];" % len(node.inputs), file=sio) print(" __shared__ float value0[%i];" % len(node.inputs), file=sio)
print(" __shared__ int shared_dims[%(nd)s];" % locals(), file=sio) print(" __shared__ int shared_dims[%(nd)s];" % locals(), file=sio)
#print >> sio, " __shared__ int shared_i_str[%(n_in)s][%(nd)s]" # print >> sio, " __shared__ int shared_i_str[%(n_in)s][%(nd)s]"
print(" if ((threadIdx.x == 0) && (threadIdx.y == 0)) {", file=sio) print(" if ((threadIdx.x == 0) && (threadIdx.y == 0)) {", file=sio)
for ipos, i in enumerate(node.inputs): for ipos, i in enumerate(node.inputs):
if _logical_scalar(i): if _logical_scalar(i):
...@@ -274,15 +275,18 @@ class NaiveAlgo(object): ...@@ -274,15 +275,18 @@ class NaiveAlgo(object):
# perform the scalar operation on the input and output references # perform the scalar operation on the input and output references
# TODO: What if the scalar_op needs support_code?? # TODO: What if the scalar_op needs support_code??
task_code = self.scalar_op.c_code( task_code = self.scalar_op.c_code(
Apply(self.scalar_op, Apply(
[scalar.Scalar(dtype=input.type.dtype).make_variable() self.scalar_op,
for input in node.inputs], [scalar.Scalar(
[scalar.Scalar(dtype=output.type.dtype).make_variable() dtype=input.type.dtype).make_variable()
for output in node.outputs]) for input in node.inputs],
, nodename + '_scalar_' [scalar.Scalar(
, get_str_list_logical_scalar(node, value_str='value0[%i]') dtype=output.type.dtype).make_variable()
, ['ii_o%i_data[0]'%ipos for ipos, i in enumerate(node.outputs)] for output in node.outputs]),
, sub=dict(fail='return;')) # TODO: set a failure code somehow!!! nodename + '_scalar_',
get_str_list_logical_scalar(node, value_str='value0[%i]'),
['ii_o%i_data[0]' % ipos for ipos, i in enumerate(node.outputs)],
sub=dict(fail='return;')) # TODO: set a failure code somehow!!!
print(" ", task_code, file=sio) print(" ", task_code, file=sio)
print(" }" * nd, file=sio) print(" }" * nd, file=sio)
...@@ -290,9 +294,9 @@ class NaiveAlgo(object): ...@@ -290,9 +294,9 @@ class NaiveAlgo(object):
# TODO: insert runtime stride checks that select the best loop order either here, or in # TODO: insert runtime stride checks that select the best loop order either here, or in
# the host code that launched the kernel (host code probably better spot) # the host code that launched the kernel (host code probably better spot)
#indent = " "*(4*d+7) # indent = " "*(4*d+7)
# for ipos, i in enumerate(node.inputs): # for ipos, i in enumerate(node.inputs):
#print >> sio, indent, "const float * i%i" % ipos, '= i%i_data', '' # print >> sio, indent, "const float * i%i" % ipos, '= i%i_data', ''
print("}", file=sio) print("}", file=sio)
print(sio.getvalue()) print(sio.getvalue())
...@@ -319,10 +323,11 @@ class NaiveAlgo(object): ...@@ -319,10 +323,11 @@ class NaiveAlgo(object):
print("// Input ", ipos, str(i.type), file=sio) print("// Input ", ipos, str(i.type), file=sio)
for ipos, i in enumerate(node.outputs): for ipos, i in enumerate(node.outputs):
print("// Output ", ipos, str(i.type), file=sio) print("// Output ", ipos, str(i.type), file=sio)
print("static __global__ void kernel_%s_%s_%s(unsigned int numEls" % ( print(
self.scalar_op.__class__.__name__, "static __global__ void kernel_%s_%s_%s(unsigned int numEls" %
nodename, (self.scalar_op.__class__.__name__,
'tiling%i_less_registers'%nd), file=sio) nodename,
'tiling%i_less_registers' % nd), file=sio)
if (nd): if (nd):
print("\t,", ", ".join("const int dim%i" % i for i in xrange(nd)), file=sio) print("\t,", ", ".join("const int dim%i" % i for i in xrange(nd)), file=sio)
# declare inputs # declare inputs
...@@ -333,8 +338,8 @@ class NaiveAlgo(object): ...@@ -333,8 +338,8 @@ class NaiveAlgo(object):
for ipos, i in enumerate(node.outputs): for ipos, i in enumerate(node.outputs):
s = ", ".join(["float * o%i_data_0" % ipos] + list("int o%i_str_%i" % (ipos, d) for d in xrange(nd))) s = ", ".join(["float * o%i_data_0" % ipos] + list("int o%i_str_%i" % (ipos, d) for d in xrange(nd)))
print("\t,", s, file=sio) print("\t,", s, file=sio)
#print >> sio, "\t,", ", ".join("int o%i_str_%i" % (ipos, d) for d in xrange(nd)) # print >> sio, "\t,", ", ".join("int o%i_str_%i" % (ipos, d) for d in xrange(nd))
#print >> sio, "\t,", "float * o%i_data" % ipos # print >> sio, "\t,", "float * o%i_data" % ipos
print("\t)\n{", file=sio) print("\t)\n{", file=sio)
# TODO: Setting these to true makes the function fail SOMETIMES. I don't know why yet. # TODO: Setting these to true makes the function fail SOMETIMES. I don't know why yet.
...@@ -350,6 +355,7 @@ class NaiveAlgo(object): ...@@ -350,6 +355,7 @@ class NaiveAlgo(object):
return "s%s_str[%i][%i]" % (io, p, d) return "s%s_str[%i][%i]" % (io, p, d)
else: else:
return "%s%i_str_%i" % (io, p, d) return "%s%i_str_%i" % (io, p, d)
def limits(d): def limits(d):
if use_shared_limits: if use_shared_limits:
return "limits[%i]" % d return "limits[%i]" % d
...@@ -417,15 +423,19 @@ class NaiveAlgo(object): ...@@ -417,15 +423,19 @@ class NaiveAlgo(object):
def task_code(d): def task_code(d):
print(self.scalar_op.c_code( print(self.scalar_op.c_code(
Apply(self.scalar_op, Apply(
self.scalar_op,
[scalar.Scalar(dtype=input.type.dtype).make_variable() [scalar.Scalar(dtype=input.type.dtype).make_variable()
for input in node.inputs], for input in node.inputs],
[scalar.Scalar(dtype=output.type.dtype).make_variable() [scalar.Scalar(dtype=output.type.dtype).make_variable()
for output in node.outputs]) for output in node.outputs]),
, nodename + '_scalar_' nodename + '_scalar_',
, ['i%i_data_%i[0]'%(ipos, d) for ipos, i in enumerate(node.inputs)] ['i%i_data_%i[0]' % (ipos, d) for ipos,
, ['o%i_data_%i[0]'%(ipos, d) for ipos, i in enumerate(node.outputs)] i in enumerate(node.inputs)],
, sub=dict(fail='return;')), file=sio) # TODO: set a failure code somehow!!! ['o%i_data_%i[0]' % (ipos, d) for ipos,
i in enumerate(node.outputs)],
sub=dict(fail='return;')), file=sio)
# TODO: set a failure code somehow!!!
if nd == 4: if nd == 4:
decl_shared_stride(n_in, n_out, nd) decl_shared_stride(n_in, n_out, nd)
...@@ -495,16 +505,19 @@ class NaiveAlgo(object): ...@@ -495,16 +505,19 @@ class NaiveAlgo(object):
for ipos, i in enumerate(node.outputs): for ipos, i in enumerate(node.outputs):
print("npy_%s o%d_i;" % (i.dtype, ipos), file=sio) print("npy_%s o%d_i;" % (i.dtype, ipos), file=sio)
task_code = self.scalar_op.c_code( task_code = self.scalar_op.c_code(
Apply(self.scalar_op, Apply(
[scalar.Scalar(dtype=input.type.dtype).make_variable() self.scalar_op,
for input in node.inputs], [scalar.Scalar(dtype=input.type.dtype).make_variable()
[scalar.Scalar(dtype=output.type.dtype).make_variable() for input in node.inputs],
for output in node.outputs]) [scalar.Scalar(dtype=output.type.dtype).make_variable()
, nodename + '_scalar_' for output in node.outputs]),
#, ['i%i_data[i]'%ipos for ipos, i in enumerate(node.inputs)] nodename + '_scalar_',
, get_str_list_logical_scalar(node, data_str='i%i_data[i]') # , ['i%i_data[i]'%ipos for ipos,
, ['o%i_i'%ipos for ipos, i in enumerate(node.outputs)] # i in enumerate(node.inputs)]
, sub=dict(fail='return;')) # TODO: set a failure code somehow!!! get_str_list_logical_scalar(node, data_str='i%i_data[i]'),
['o%i_i' % ipos for ipos, i in enumerate(node.outputs)],
sub=dict(fail='return;'))
# TODO: set a failure code somehow!!!
print(" ", task_code, file=sio) print(" ", task_code, file=sio)
for ipos, _ in enumerate(node.outputs): for ipos, _ in enumerate(node.outputs):
print("o%i_data[i] = o%i_i;" % (ipos, ipos), file=sio) print("o%i_data[i] = o%i_i;" % (ipos, ipos), file=sio)
...@@ -539,18 +552,21 @@ class NaiveAlgo(object): ...@@ -539,18 +552,21 @@ class NaiveAlgo(object):
nb_outputs = len(node.outputs) nb_outputs = len(node.outputs)
d = dict() d = dict()
# input_params and output_params go into the function declaration/definition # input_params and output_params go into the function declaration/definition
input_params = ", ".join("const float * i%i_data, const int * i%i_str"%(ipos, ipos) input_params = ", ".join(
for ipos in xrange(len(node.inputs))) "const float * i%i_data, const int * i%i_str" % (ipos, ipos)
output_params = ", ".join("float * o%i_data, const int * o%i_str"%(ipos, ipos) for ipos in xrange(len(node.inputs)))
for ipos in xrange(len(node.outputs))) output_params = ", ".join(
"float * o%i_data, const int * o%i_str" % (ipos, ipos)
for ipos in xrange(len(node.outputs)))
# input_args and output_args go into the recursive call. # input_args and output_args go into the recursive call.
input_args = ", ".join("i%i_data, i%i_str"%(ipos, ipos) input_args = ", ".join("i%i_data, i%i_str" % (ipos, ipos)
for ipos in xrange(len(node.inputs))) for ipos in xrange(len(node.inputs)))
output_args = ", ".join("o%i_data, o%i_str"%(ipos, ipos) output_args = ", ".join("o%i_data, o%i_str" % (ipos, ipos)
for ipos in xrange(len(node.outputs))) for ipos in xrange(len(node.outputs)))
prod_dims = '*'.join(["dims[%i]"%di for di in xrange(nd)]+['1']) prod_dims = '*'.join(
["dims[%i]" % di for di in xrange(nd)] + ['1'])
scalar_op = self.scalar_op.__class__.__name__ scalar_op = self.scalar_op.__class__.__name__
...@@ -578,20 +594,30 @@ class NaiveAlgo(object): ...@@ -578,20 +594,30 @@ class NaiveAlgo(object):
print(""" print("""
std::cerr << "calling kernel_%(scalar_op)s_%(nodename)s w numEls" << numEls << " dims"<< d << "\\n"; std::cerr << "calling kernel_%(scalar_op)s_%(nodename)s w numEls" << numEls << " dims"<< d << "\\n";
""" % locals(), file=sio) """ % locals(), file=sio)
print('std::cerr << ' + " << ' ' << ".join(['" "']+list("dims[%i]"%di print(
for di in xrange(nd)) + ["'\\n';"]), file=sio) 'std::cerr << ' + " << ' ' << ".join(
['" "'] +
list("dims[%i]" % di for di in xrange(nd)) +
["'\\n';"]),
file=sio)
if self.verbose > 1: if self.verbose > 1:
for ipos in xrange(len(node.inputs)): for ipos in xrange(len(node.inputs)):
istrings = [
"i%s_str[%i]" % (ipos, di) for di in xrange(nd)]
ipositions = " << ' ' << ".join(
["i%s_data" % ipos] + istrings)
print(""" print("""
std::cerr << " %(ipos)s data strides" << std::cerr << " %(ipos)s data strides" << %(ipositions)s << "\\n";
""" % locals() + " << ' ' << ".join(["i%s_data"%ipos] """ % dict(ipos=ipos, ipositions=ipositions), file=sio)
+ list("i%s_str[%i]"%(ipos, di) for di in xrange(nd))) + ''' << "\\n"; ''', file=sio)
for ipos in xrange(len(node.outputs)): for ipos in xrange(len(node.outputs)):
print(""" print("""
std::cerr << " %(ipos)s data strides" << std::cerr << " %(ipos)s data strides" <<
""" % locals() + " << ' ' << ".join(["o%s_data"%ipos] """ % locals() + " << ' ' << ".join(
+ list("o%s_str[%i]"%(ipos, di) for di in xrange(nd))) + ''' << "\\n"; ''', file=sio) ["o%s_data" % ipos] +
list(
"o%s_str[%i]" % (ipos, di) for di in xrange(nd)
)) +
''' << "\\n"; ''', file=sio)
# collapse dimension that are broadcast in all inputs. # collapse dimension that are broadcast in all inputs.
# need to be done before contiguous collapse as it will break it. # need to be done before contiguous collapse as it will break it.
# do the dimensions and the strides # do the dimensions and the strides
...@@ -636,11 +662,19 @@ class NaiveAlgo(object): ...@@ -636,11 +662,19 @@ class NaiveAlgo(object):
print('std::cerr << "\\n";', file=sio) print('std::cerr << "\\n";', file=sio)
if nd > 0: if nd > 0:
for ipos in xrange(len(node.inputs)): for ipos in xrange(len(node.inputs)):
print('std::cerr << " local_str inputs %(ipos)s: " <<'%locals() + \ print(
' << " " << '.join(["local_str[%s][%s]" % (ipos, x) for x in xrange(nd)])+'<<"\\n";', file=sio) 'std::cerr << " local_str inputs %(ipos)s: " <<' % locals() +
' << " " << '.join(["local_str[%s][%s]" % (ipos, x)
for x in xrange(nd)]) +
'<<"\\n";', file=sio)
for ipos in xrange(len(node.outputs)): for ipos in xrange(len(node.outputs)):
print('std::cerr << " local_ostr inputs %(ipos)s: " <<'%locals() + \ print(
' << " " << '.join(["local_ostr[%s][%s]" % (ipos, x) for x in xrange(nd)])+'<<"\\n";', file=sio) 'std::cerr << " local_ostr inputs %(ipos)s: " <<' %
locals() +
' << " " << '.join(
["local_ostr[%s][%s]" %
(ipos, x) for x in xrange(nd)]) +
'<<"\\n";', file=sio)
print(""" print("""
for(int id=0;id<nd_collapse;id++){ for(int id=0;id<nd_collapse;id++){
...@@ -668,35 +702,51 @@ class NaiveAlgo(object): ...@@ -668,35 +702,51 @@ class NaiveAlgo(object):
nd_collapse--; id--; nd_collapse--; id--;
} }
} }
"""%locals(), file=sio) """ % locals(), file=sio)
if self.verbose > 2: if self.verbose > 2:
print('std::cerr <<"after broadcast collapse\\n";', file=sio) print('std::cerr <<"after broadcast collapse\\n";', file=sio)
print('std::cerr<< "nd_collapse "<< nd_collapse << "\\n"; ', file=sio) print('std::cerr<< "nd_collapse "<< nd_collapse << "\\n"; ', file=sio)
print('std::cerr << "local_dims";', file=sio) print('std::cerr << "local_dims";', file=sio)
for d in xrange(nd): for d in xrange(nd):
print('std::cerr << " " << local_dims[%(d)s]; '%locals(), file=sio) print('std::cerr << " " << local_dims[%(d)s]; ' %
locals(), file=sio)
print('std::cerr << "\\n";', file=sio) print('std::cerr << "\\n";', file=sio)
if nd > 0: if nd > 0:
for ipos in xrange(len(node.inputs)): for ipos in xrange(len(node.inputs)):
print('std::cerr << " local_str %(ipos)s: " <<'%locals()+' << " " << '.join(["local_str[%s][%s]" % (ipos, x) for x in xrange(nd)])+'<<"\\n";', file=sio) print('std::cerr << " local_str %(ipos)s: " <<' %
locals() + ' << " " << '.join(
["local_str[%s][%s]" %
(ipos, x) for x in xrange(nd)]) +
'<<"\\n";', file=sio)
for ipos in xrange(len(node.outputs)): for ipos in xrange(len(node.outputs)):
print('std::cerr << " local_ostr %(ipos)s: " <<'%locals()+' << " " << '.join(["local_ostr[%s][%s]" % (ipos, x) for x in xrange(nd)])+'<<"\\n";', file=sio) print(
'std::cerr << " local_ostr %(ipos)s: " <<' %
locals() + ' << " " << '.join(
["local_ostr[%s][%s]" %
(ipos, x) for x in xrange(nd)]) +
'<<"\\n";', file=sio)
# collapse contiguous dimensions (ignoring scalars, generic version(collapse any dimensions, right, left, middle)) # collapse contiguous dimensions (ignoring scalars, generic version(collapse any dimensions, right, left, middle))
# this is a good idea because we make less index calculation in the gpu. # this is a good idea because we make less index calculation in the gpu.
if nd > 0: if nd > 0:
print("int nd_collapse_[%(nd)s] = {"%locals() + ','.join(['1' for x in xrange(nd)]) + "};", file=sio) print("int nd_collapse_[%(nd)s] = {" %
locals() + ','.join(
['1' for x in xrange(nd)]) + "};", file=sio)
else: else:
print("int *nd_collapse_ = NULL;", file=sio) print("int *nd_collapse_ = NULL;", file=sio)
for ipos in xrange(len(node.inputs)): for ipos in xrange(len(node.inputs)):
if not _logical_scalar(node.inputs[ipos]): if not _logical_scalar(node.inputs[ipos]):
if nd > 0: if nd > 0:
print(""" print("""
int nd_collapse_%(ipos)s[%(nd)s] = {"""%locals() + ','.join(['1' for x in xrange(nd)]) + "};", file=sio) int nd_collapse_%(ipos)s[%(nd)s] = {""" %
locals() +
','.join(['1' for x in xrange(nd)]) +
"};", file=sio)
else: else:
print(""" print("""
int *nd_collapse_%(ipos)s = NULL;"""%locals(), file=sio) int * nd_collapse_%(ipos)s = NULL;""" %
locals(), file=sio)
print(""" print("""
can_collapse_%(nodename)s(nd_collapse, local_dims, local_str[%(ipos)s], nd_collapse_%(ipos)s); can_collapse_%(nodename)s(nd_collapse, local_dims, local_str[%(ipos)s], nd_collapse_%(ipos)s);
for(int i=0;i<nd_collapse;i++){ for(int i=0;i<nd_collapse;i++){
...@@ -707,8 +757,10 @@ nd_collapse_[i]=0; ...@@ -707,8 +757,10 @@ nd_collapse_[i]=0;
if self.verbose > 1: if self.verbose > 1:
print(""" print("""
std::cerr<< "nd_collapse_%(ipos)s "<< std::cerr<< "nd_collapse_%(ipos)s "<<
"""%locals(), file=sio) """ % locals(), file=sio)
print(' << " " << '.join(["nd_collapse_%s[" % ipos + str(i)+"]" for i in xrange(nd)]), file=sio) print(' << " " << '.join(["nd_collapse_ %s[" %
ipos + str(i) + "]" for i in xrange(nd)]),
file=sio)
print('<< "\\n";', file=sio) print('<< "\\n";', file=sio)
# update the local stride. # update the local stride.
...@@ -721,7 +773,7 @@ nd_collapse_[i]=0; ...@@ -721,7 +773,7 @@ nd_collapse_[i]=0;
local_str[%(ipos)s][j-1]=local_str[%(ipos)s][j]; local_str[%(ipos)s][j-1]=local_str[%(ipos)s][j];
} }
} }
"""%locals(), file=sio) """ % locals(), file=sio)
for ipos in xrange(len(node.outputs)): for ipos in xrange(len(node.outputs)):
print(""" print("""
...@@ -732,7 +784,7 @@ nd_collapse_[i]=0; ...@@ -732,7 +784,7 @@ nd_collapse_[i]=0;
local_ostr[%(ipos)s][j-1]=local_ostr[%(ipos)s][j]; local_ostr[%(ipos)s][j-1]=local_ostr[%(ipos)s][j];
} }
} }
"""%locals(), file=sio) """ % locals(), file=sio)
# update the local dims. # update the local dims.
print(""" print("""
...@@ -743,16 +795,20 @@ nd_collapse_[i]=0; ...@@ -743,16 +795,20 @@ nd_collapse_[i]=0;
local_dims[j-1]=local_dims[j]; local_dims[j-1]=local_dims[j];
} }
} }
"""%locals(), file=sio) """ % locals(), file=sio)
# update the new number of dim # update the new number of dim
print(""" print("""
for(int i=1, end=nd_collapse;i<end;i++){ for(int i=1, end=nd_collapse;i<end;i++){
if(nd_collapse_[i]==1)nd_collapse--; if(nd_collapse_[i]==1)nd_collapse--;
} }
if(nd_collapse == 1 """%locals(), file=sio) if(nd_collapse == 1 """ % locals(), file=sio)
l = ["local_str[%s][nd_collapse-1]==1 "%ipos for ipos in xrange(len(node.inputs)) if not _logical_scalar(node.inputs[ipos])] l = ["local_str[%s][nd_collapse-1]==1 " %
l += ["local_ostr[%s][nd_collapse-1]==1 "%ipos for ipos in xrange(len(node.outputs)) if not _logical_scalar(node.outputs[ipos])] ipos for ipos in xrange(len(node.inputs)) if not
_logical_scalar(node.inputs[ipos])]
l += ["local_ostr[%s][nd_collapse-1]==1 " %
ipos for ipos in xrange(len(node.outputs)) if not
_logical_scalar(node.outputs[ipos])]
if len(l) > 0: if len(l) > 0:
print(" && ", " && ".join(l), file=sio) print(" && ", " && ".join(l), file=sio)
print("""){nd_collapse=0;} """, file=sio) print("""){nd_collapse=0;} """, file=sio)
...@@ -762,20 +818,31 @@ nd_collapse_[i]=0; ...@@ -762,20 +818,31 @@ nd_collapse_[i]=0;
print("""std::cerr << "nd_collapse " << nd_collapse << "\\n"; """ % locals(), file=sio) print("""std::cerr << "nd_collapse " << nd_collapse << "\\n"; """ % locals(), file=sio)
if self.verbose > 1: if self.verbose > 1:
for d in xrange(nd): for d in xrange(nd):
print('std::cerr << " " << local_dims[%(d)s]; '%locals(), file=sio) print('std::cerr << " " << local_dims[%(d)s]; ' %
locals(),
file=sio)
print('std::cerr << "\\n";', file=sio) print('std::cerr << "\\n";', file=sio)
if nd > 0: if nd > 0:
for ipos in xrange(len(node.inputs)): for ipos in xrange(len(node.inputs)):
print('std::cerr << " local_str %(ipos)s: " <<'%locals()+' << " " << '.join(["local_str[%s][%s]"%(ipos, x) for x in xrange(nd)])+'<<"\\n";', file=sio) print(
'std::cerr << " local_str % (ipos)s: " <<' %
locals() + ' << " " << '.join(
["local_str[%s][%s]" %
(ipos, x) for x in xrange(nd)]) +
'<<"\\n";', file=sio)
for ipos in xrange(len(node.outputs)): for ipos in xrange(len(node.outputs)):
print('std::cerr << " local_ostr %(ipos)s: " <<'%locals()+' << " " << '.join(["local_ostr[%s][%s]"%(ipos, x) for x in xrange(nd)])+'<<"\\n";', file=sio) print('std::cerr << " local_ostr % (ipos)s: " <<' %
locals() + ' << " " << '.join(
["local_ostr[%s][%s]" %
(ipos, x) for x in xrange(nd)]) +
'<<"\\n";', file=sio)
def launch_Ccontiguous(nodename, scalar_op, sync=True): def launch_Ccontiguous(nodename, scalar_op, sync=True):
kernel_call_args = ["numEls"] kernel_call_args = ["numEls"]
for ipos in xrange(len(node.inputs)): for ipos in xrange(len(node.inputs)):
kernel_call_args.append("i%i_data"%ipos) kernel_call_args.append("i%i_data" % ipos)
for ipos in xrange(len(node.outputs)): for ipos in xrange(len(node.outputs)):
kernel_call_args.append("o%i_data"%ipos) kernel_call_args.append("o%i_data" % ipos)
kernel_call_args = ", ".join(kernel_call_args) kernel_call_args = ", ".join(kernel_call_args)
verb = "" verb = ""
if self.verbose: if self.verbose:
...@@ -817,20 +884,27 @@ nd_collapse_[i]=0; ...@@ -817,20 +884,27 @@ nd_collapse_[i]=0;
# kernel_call_args are used to invoke the cuda kernel # kernel_call_args are used to invoke the cuda kernel
local = "local_" local = "local_"
kernel_call_args = ["numEls"] kernel_call_args = ["numEls"]
kernel_call_args.extend(local+"dims[%i]"%di for di in xrange(force_nd)) kernel_call_args.extend(
local + "dims[%i]" %
di for di in xrange(force_nd))
for ipos in xrange(len(node.inputs)): for ipos in xrange(len(node.inputs)):
kernel_call_args += ["i%i_data"%ipos] + list(local+"str[%i][%i]"%(ipos, di) for di in xrange(force_nd)) kernel_call_args += ["i%i_data" % ipos] + list(
#strides = ", ".join("i%i_str[%i]"%(ipos, di) for di in xrange(force_nd)) local + "str[%i][%i]" %
#kernel_call_args.append( "%s, i%i_data" % (strides, ipos)) (ipos, di) for di in xrange(force_nd))
# strides = ", ".join("i%i_str[%i]"%(ipos, di) for di in xrange(force_nd))
# kernel_call_args.append( "%s, i%i_data" % (strides, ipos))
for ipos in xrange(len(node.outputs)): for ipos in xrange(len(node.outputs)):
kernel_call_args += ["o%i_data"%ipos] + list(local+"ostr[%i][%i]"%(ipos, di) for di in xrange(force_nd)) kernel_call_args += ["o%i_data" % ipos] + list(
#strides = ", ".join("o%i_str[%i]"%(ipos, di) for di in xrange(force_nd)) local + "ostr[%i][%i]" %
#kernel_call_args.append( "%s, o%i_data" % (strides, ipos)) (ipos, di) for di in xrange(force_nd))
# strides = ", ".join("o%i_str[%i]"%(ipos, di) for di in xrange(force_nd))
# kernel_call_args.append( "%s, o%i_data" % (strides, ipos))
if self.verbose: if self.verbose:
print(""" print("""
std::cerr << " Running general version with %(force_nd)s dims\\n"; std::cerr << " Running general version with %(force_nd)s dims\\n";
"""%locals(), file=sio) """ % locals(), file=sio)
print("std::cerr << " + ' << " " << '.join(kernel_call_args)+' << "\\n";', file=sio) print("std::cerr << " + ' << " " << '.join(
kernel_call_args) + ' << "\\n";', file=sio)
# std::cerr << numEls << dims[0] << i0_data, i0_str[0] << o0_data, o0_str[0]\n; # std::cerr << numEls << dims[0] << i0_data, i0_str[0] << o0_data, o0_str[0]\n;
kernel_call_args = ", ".join(kernel_call_args) kernel_call_args = ", ".join(kernel_call_args)
...@@ -866,12 +940,13 @@ nd_collapse_[i]=0; ...@@ -866,12 +940,13 @@ nd_collapse_[i]=0;
else: else:
print(" return 0; " % locals(), file=sio) print(" return 0; " % locals(), file=sio)
print("if(numEls==0) return 0;", file=sio) print("if(numEls==0) return 0;", file=sio)
print("switch (nd_collapse==0?0:min(%(nd)s,nd_collapse)) {"%locals(), file=sio) print("switch (nd_collapse==0?0:min(%(nd)s,nd_collapse)) {" %
locals(), file=sio)
print("case 0: {", file=sio) print("case 0: {", file=sio)
launch_Ccontiguous(nodename, scalar_op, self.sync) launch_Ccontiguous(nodename, scalar_op, self.sync)
print(" } break;", file=sio) print(" } break;", file=sio)
for i in xrange(1, nd+1): for i in xrange(1, nd + 1):
print("case "+str(i)+": {", file=sio) print("case " + str(i) + ": {", file=sio)
launch_General(nodename, scalar_op, i, self.sync) launch_General(nodename, scalar_op, i, self.sync)
print(" } break;", file=sio) print(" } break;", file=sio)
...@@ -889,9 +964,10 @@ nd_collapse_[i]=0; ...@@ -889,9 +964,10 @@ nd_collapse_[i]=0;
#define INTMOD_POW2(a, b) (a & ((1<<b)-1)) #define INTMOD_POW2(a, b) (a & ((1<<b)-1))
""" """
kernels = "".join( kernels = "".join(
[self.c_src_kernel(node, nodename, x) for x in xrange(1, nd + 1)] [self.c_src_kernel(node, nodename, x)
+ [self.c_src_kernel_Ccontiguous(node, nodename)] for x in xrange(1, nd + 1)] +
+ [self.c_src_callkernel(node, nodename)]) [self.c_src_kernel_Ccontiguous(node, nodename)] +
[self.c_src_callkernel(node, nodename)])
return defines + kernels return defines + kernels
def c_support_code(self): def c_support_code(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论