提交 fac5f7dd authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix and re-enable the collapsing of broadcast dimensions.

上级 3f596d97
...@@ -210,7 +210,7 @@ class RecAlgo(object): ...@@ -210,7 +210,7 @@ class RecAlgo(object):
return self.c_src_kernel(node, nodename) + self.c_src_callkernel(node, nodename) return self.c_src_kernel(node, nodename) + self.c_src_callkernel(node, nodename)
class NaiveAlgo(object): class NaiveAlgo(object):
verbose = 0 # 1 or 2 for more verbose output. verbose = 0 # 1, 2 or 3 for more verbose output.
cache_version = () cache_version = ()
cache_version = ('debug', 6, verbose) cache_version = ('debug', 6, verbose)
...@@ -248,9 +248,6 @@ class NaiveAlgo(object): ...@@ -248,9 +248,6 @@ class NaiveAlgo(object):
if _logical_scalar(i): if _logical_scalar(i):
print >> sio, " const float ii_i%i_value = i%i_data[0];" % (ipos, ipos) print >> sio, " const float ii_i%i_value = i%i_data[0];" % (ipos, ipos)
#TODO: insert code to check for strides of 1, and use a different loop
#loop over the elements to be treated by this kernel call #loop over the elements to be treated by this kernel call
print >> sio, " for (int i = idx; i < numEls; i += numThreads) {" print >> sio, " for (int i = idx; i < numEls; i += numThreads) {"
# calculate the data pointers for all arguments # calculate the data pointers for all arguments
...@@ -286,9 +283,6 @@ class NaiveAlgo(object): ...@@ -286,9 +283,6 @@ class NaiveAlgo(object):
print >> sio, " ", task_code print >> sio, " ", task_code
print >> sio, " }" print >> sio, " }"
#TODO: insert runtime stride checks that select the best loop order either here, or in
# the host code that launched the kernel (host code probably better spot)
#indent = " "*(4*d+7) #indent = " "*(4*d+7)
#for ipos, i in enumerate(node.inputs): #for ipos, i in enumerate(node.inputs):
#print >> sio, indent, "const float * i%i" % ipos, '= i%i_data', '' #print >> sio, indent, "const float * i%i" % ipos, '= i%i_data', ''
...@@ -635,8 +629,9 @@ class NaiveAlgo(object): ...@@ -635,8 +629,9 @@ class NaiveAlgo(object):
# like # like
# float *, int, int, int ... # float *, int, int, int ...
# #
# The second is to recognize when trailing (right-most in numpy) dimensions can be collapsed as # The second is to recognize when any dimensions can be collapsed as
# being contiguous... (confusing... read code) # being contiguous. That means that we can merge that dimension with another
# one for all inputs/outputs and have the same results (confusing... read code)
# #
# The third is to make a special case for scalar elements. We allow the collapsing of them. # The third is to make a special case for scalar elements. We allow the collapsing of them.
# In the ccontiguous and not contiguous case, we use registers to lower the number of memory access. # In the ccontiguous and not contiguous case, we use registers to lower the number of memory access.
...@@ -644,6 +639,8 @@ class NaiveAlgo(object): ...@@ -644,6 +639,8 @@ class NaiveAlgo(object):
#TODO: make a special case for broadcasting, to store the data in shared memory. #TODO: make a special case for broadcasting, to store the data in shared memory.
nd = node.outputs[0].type.ndim nd = node.outputs[0].type.ndim
nb_inputs = len(node.inputs)
nb_outputs = len(node.outputs)
id_self = id(self) id_self = id(self)
d = dict() d = dict()
#input_params and output_params go into the function declaration/definition #input_params and output_params go into the function declaration/definition
...@@ -668,11 +665,7 @@ class NaiveAlgo(object): ...@@ -668,11 +665,7 @@ class NaiveAlgo(object):
{ {
//can we collapse dims[i] and dims[i-1] //can we collapse dims[i] and dims[i-1]
for(int i=nd-1;i>0;i--){ for(int i=nd-1;i>0;i--){
if(false && dims[i]==1 && strides[i]==0){// if(strides[i]*dims[i]==strides[i-1]){//the dims nd-1 are not strided again dimension nd
collapse[i]=1;
}else if(false && dims[i-1]==1 && strides[i-1]==0){
collapse[i]=1;
}else if(strides[i]*dims[i]==strides[i-1]){//the dims nd-1 are not strided again dimension nd
collapse[i]=1; collapse[i]=1;
}else collapse[i]=0; }else collapse[i]=0;
} }
...@@ -704,9 +697,85 @@ class NaiveAlgo(object): ...@@ -704,9 +697,85 @@ class NaiveAlgo(object):
std::cerr << " %(ipos)s data strides" << std::cerr << " %(ipos)s data strides" <<
""" %locals() + " << ' ' << ".join(["o%s_data"%ipos] """ %locals() + " << ' ' << ".join(["o%s_data"%ipos]
+ list("o%s_str[%i]"%(ipos, di) for di in xrange(nd))) + ''' << "\\n"; ''' + list("o%s_str[%i]"%(ipos, di) for di in xrange(nd))) + ''' << "\\n"; '''
# Collapse dimensions that are broadcast in all inputs.
# This needs to be done before the contiguous collapse, as it will break it.
# do the dimensions and the strides
print >> sio, """
int local_dims[%(nd)s];
int local_str[%(nb_inputs)s][%(nd)s];
int local_ostr[%(nb_inputs)s][%(nd)s];
int nd_collapse = %(nd)s;
for(int i=0;i<%(nd)s;i++){//init new dim
local_dims[i]=dims[i];
}
"""%locals()
for ipos in xrange(len(node.inputs)):
print >> sio, """
for(int i=0;i<%(nd)s;i++){//init new strides
local_str[%(ipos)s][i]=i%(ipos)s_str[i];
}
"""%locals()
for ipos in xrange(len(node.outputs)):
print >> sio, """
for(int i=0;i<%(nd)s;i++){//init new strides
local_ostr[%(ipos)s][i]=o%(ipos)s_str[i];
}
"""%locals()
if self.verbose>2:
print >>sio, 'std::cerr <<"before broadcast collapse\\n";'
print >>sio, 'std::cerr<< "nd_collapse "<< nd_collapse << "\\n"; '
print >> sio, 'std::cerr << "local_dims";'
for d in xrange(nd):
print >> sio, 'std::cerr << " " << local_dims[%(d)s]; '%locals()
print >> sio, 'std::cerr << "\\n";'
for ipos in xrange(len(node.inputs)):
print >> sio, 'std::cerr << " local_str inputs %(ipos)s: " <<'%locals()+' << " " << '.join(["local_str[%(ipos)s][%(x)s]"%locals() for x in range(nd)])+'<<"\\n";'
for ipos in xrange(len(node.outputs)):
print >> sio, 'std::cerr << " local_ostr inputs %(ipos)s: " <<'%locals()+' << " " << '.join(["local_ostr[%(ipos)s][%(x)s]"%locals() for x in range(nd)])+'<<"\\n";'
print >> sio, """
for(int id=0;id<nd_collapse;id++){
bool all_broadcast=true;
for(int input_id=0;input_id<%(nb_inputs)s;input_id++){
if(local_str[input_id][id]!=0 || local_dims[id]!=1) all_broadcast= false;
}
for(int input_id=0;input_id<%(nb_outputs)s;input_id++){
if(local_ostr[input_id][id]!=0 || local_dims[id]!=1) all_broadcast= false;
}
if(all_broadcast){
for(int j=id+1;j<nd_collapse;j++)//remove dims i from the array
local_dims[j-1]=local_dims[j];
for(int input_id=0;input_id<%(nb_inputs)s;input_id++){
for(int j=id+1;j<nd_collapse;j++){//remove dims i from the array
local_str[input_id][j-1]=local_str[input_id][j];
}
}
for(int output_id=0;output_id<%(nb_outputs)s;output_id++){
for(int j=id+1;j<nd_collapse;j++){//remove dims i from the array
local_ostr[output_id][j-1]=local_ostr[output_id][j];
}
}
nd_collapse--; id--;
}
}
"""%locals()
if self.verbose>2:
print >>sio, 'std::cerr <<"after broadcast collapse\\n";'
print >>sio, 'std::cerr<< "nd_collapse "<< nd_collapse << "\\n"; '
print >> sio, 'std::cerr << "local_dims";'
for d in xrange(nd):
print >> sio, 'std::cerr << " " << local_dims[%(d)s]; '%locals()
print >> sio, 'std::cerr << "\\n";'
for ipos in xrange(len(node.inputs)):
print >> sio, 'std::cerr << " local_str %(ipos)s: " <<'%locals()+' << " " << '.join(["local_str[%(ipos)s][%(x)s]"%locals() for x in range(nd)])+'<<"\\n";'
for ipos in xrange(len(node.outputs)):
print >> sio, 'std::cerr << " local_ostr %(ipos)s: " <<'%locals()+' << " " << '.join(["local_ostr[%(ipos)s][%(x)s]"%locals() for x in range(nd)])+'<<"\\n";'
# collapse contiguous dimensions (ignoring scalars, generic version(collapse any dimensions, right, left, middle)) # collapse contiguous dimensions (ignoring scalars, generic version(collapse any dimensions, right, left, middle))
# this is a good idea because [we assume that] the output has been allocated c_contiguous # this is a good idea because we make less index calculation in the gpu.
print >> sio, "int nd_collapse_[%(nd)s] = {"%locals() +','.join(['1' for x in range(nd)]) +"};" print >> sio, "int nd_collapse_[%(nd)s] = {"%locals() +','.join(['1' for x in range(nd)]) +"};"
for ipos in xrange(len(node.inputs)): for ipos in xrange(len(node.inputs)):
...@@ -714,8 +783,8 @@ class NaiveAlgo(object): ...@@ -714,8 +783,8 @@ class NaiveAlgo(object):
print >> sio, """ print >> sio, """
int nd_collapse_%(ipos)s[%(nd)s] = {"""%locals() +','.join(['1' for x in range(nd)]) +"};" int nd_collapse_%(ipos)s[%(nd)s] = {"""%locals() +','.join(['1' for x in range(nd)]) +"};"
print >> sio, """ print >> sio, """
can_collapse_%(nodename)s(%(nd)s, dims, i%(ipos)s_str, nd_collapse_%(ipos)s); can_collapse_%(nodename)s(nd_collapse, local_dims, local_str[%(ipos)s], nd_collapse_%(ipos)s);
for(int i=0;i<%(nd)s;i++){ for(int i=0;i<nd_collapse;i++){
if(nd_collapse_%(ipos)s[i]==0) if(nd_collapse_%(ipos)s[i]==0)
nd_collapse_[i]=0; nd_collapse_[i]=0;
} }
...@@ -731,76 +800,64 @@ nd_collapse_[i]=0; ...@@ -731,76 +800,64 @@ nd_collapse_[i]=0;
"""%locals() """%locals()
print >>sio, ' << " " << '.join(["nd_collapse_["%locals()+str(i)+"]" for i in range(nd)]) print >>sio, ' << " " << '.join(["nd_collapse_["%locals()+str(i)+"]" for i in range(nd)])
print >>sio, '<< "\\n";' print >>sio, '<< "\\n";'
print >> sio, """
int nd_collapse=%(nd)s;
for(int i=1;i<%(nd)s;i++){
if(nd_collapse_[i]==1)nd_collapse--;
}
if(nd_collapse==1 && """%locals()
print >> sio, " && ".join([ "i%(ipos)s_str[%(nd)s-1]==1 "%locals()for x in range(len(node.inputs))])
print >> sio,"""){nd_collapse=0;} """
if self.verbose:
print >> sio, """std::cerr << "nd_collapse " << nd_collapse << "\\n"; """ %locals()
# set the new dims.
print >> sio, "int local_dims[%(nd)s];"%locals()
print >> sio, """
for(int i=0;i<%(nd)s;i++){//init new dim
local_dims[i]=dims[i];
}
for(int i=%(nd)s-1;i>0;i--){
if(nd_collapse_[i]==1){
local_dims[i-1]*=local_dims[i];//set new dims
for(int j=i+1;j<%(nd)s;j++)//remove dims i from the array
local_dims[j-1]=local_dims[j];
}
}
"""%locals()
if self.verbose>1:
for d in xrange(nd):
print >> sio, 'std::cerr << "local_dims %(d)s " << local_dims[%(d)s] << "\\n"; '%locals()
# set the new stride. # update the local stride.
for ipos in xrange(len(node.inputs)): for ipos in xrange(len(node.inputs)):
print >> sio, """ print >> sio, """
int local_i%(ipos)s_str[%(nd)s]; for(int i=nd_collapse-1;i>0;i--){
"""%locals()
print >> sio, """
for(int i=0;i<%(nd)s;i++){//init new strides
local_i%(ipos)s_str[i]=i%(ipos)s_str[i];
}
for(int i=%(nd)s-1;i>0;i--){
if(nd_collapse_[i]==1){ if(nd_collapse_[i]==1){
local_i%(ipos)s_str[i-1]=local_i%(ipos)s_str[i];//set new strides local_str[%(ipos)s][i-1]=local_str[%(ipos)s][i];//set new strides
for(int j=i+1;j<%(nd)s;j++)//remove stride i from the array for(int j=i+1;j<nd_collapse;j++)//remove stride i from the array
local_i%(ipos)s_str[j-1]=local_i%(ipos)s_str[j]; local_str[%(ipos)s][j-1]=local_str[%(ipos)s][j];
} }
} }
"""%locals() """%locals()
for ipos in xrange(len(node.outputs)): for ipos in xrange(len(node.outputs)):
print >> sio, "int local_o%(ipos)s_str[%(nd)s];"%locals()
print >> sio, """ print >> sio, """
for(int i=0;i<%(nd)s;i++){//init new strides for(int i=nd_collapse-1;i>0;i--){
local_o%(ipos)s_str[i]=o%(ipos)s_str[i];
}
for(int i=%(nd)s-1;i>0;i--){
if(nd_collapse_[i]==1){ if(nd_collapse_[i]==1){
local_o%(ipos)s_str[i-1]=local_o%(ipos)s_str[i];//set new strides local_ostr[%(ipos)s][i-1]=local_ostr[%(ipos)s][i];//set new strides
for(int j=i+1;j<%(nd)s;j++)//remove stride i from the array for(int j=i+1;j<nd_collapse;j++)//remove stride i from the array
local_o%(ipos)s_str[j-1]=local_o%(ipos)s_str[j]; local_ostr[%(ipos)s][j-1]=local_ostr[%(ipos)s][j];
} }
} }
"""%locals() """%locals()
# update the local dims.
print >> sio, """
for(int i=nd_collapse-1;i>0;i--){
if(nd_collapse_[i]==1){
local_dims[i-1]*=local_dims[i];//set new dims
for(int j=i+1;j<nd_collapse;j++)//remove dims i from the array
local_dims[j-1]=local_dims[j];
}
}
"""%locals()
#update the new number of dim
print >> sio, """
for(int i=1, end=nd_collapse;i<end;i++){
if(nd_collapse_[i]==1)nd_collapse--;
}
if(nd_collapse == 1 && """%locals()
print >> sio, " && ".join(["local_str[%(ipos)s][nd_collapse-1]==1 "%locals()for ipos in range(len(node.inputs)) if not _logical_scalar(node.inputs[ipos])]+
["local_ostr[%(ipos)s][nd_collapse-1]==1 "%locals()for ipos in range(len(node.outputs)) if not _logical_scalar(node.outputs[ipos])])
print >> sio,"""){nd_collapse=0;} """
if self.verbose:
print >> sio, 'std::cerr <<"after can_collapse\\n";'
print >> sio, """std::cerr << "nd_collapse " << nd_collapse << "\\n"; """ %locals()
if self.verbose>1: if self.verbose>1:
for ipos in ["i"+ str(x) for x in xrange(len(node.inputs))]+["o"+ str(x) for x in xrange(len(node.outputs))]: for d in xrange(nd):
print >> sio, 'std::cerr << " local_%(ipos)s_str " <<'%locals()+' << " " << '.join(["local_%(ipos)s_str[%(x)s]"%locals() for x in range(nd)])+'<<"\\n";' print >> sio, 'std::cerr << " " << local_dims[%(d)s]; '%locals()
print >> sio, 'std::cerr << "\\n";'
for ipos in xrange(len(node.inputs)):
print >> sio, 'std::cerr << " local_str %(ipos)s: " <<'%locals()+' << " " << '.join(["local_str[%(ipos)s][%(x)s]"%locals() for x in range(nd)])+'<<"\\n";'
for ipos in xrange(len(node.outputs)):
print >> sio, 'std::cerr << " local_ostr %(ipos)s: " <<'%locals()+' << " " << '.join(["local_ostr[%(ipos)s][%(x)s]"%locals() for x in range(nd)])+'<<"\\n";'
def launch_Ccontiguous(nodename, id_self, scalar_op): def launch_Ccontiguous(nodename, id_self, scalar_op):
...@@ -837,11 +894,11 @@ nd_collapse_[i]=0; ...@@ -837,11 +894,11 @@ nd_collapse_[i]=0;
kernel_call_args = ["numEls"] kernel_call_args = ["numEls"]
kernel_call_args.extend(local+"dims[%i]"%di for di in xrange(force_nd)) kernel_call_args.extend(local+"dims[%i]"%di for di in xrange(force_nd))
for ipos in xrange(len(node.inputs)): for ipos in xrange(len(node.inputs)):
kernel_call_args+=["i%i_data"%ipos] + list(local+"i%i_str[%i]"%(ipos, di) for di in xrange(force_nd)) kernel_call_args+=["i%i_data"%ipos] + list(local+"str[%i][%i]"%(ipos, di) for di in xrange(force_nd))
#strides = ", ".join("i%i_str[%i]"%(ipos, di) for di in xrange(force_nd)) #strides = ", ".join("i%i_str[%i]"%(ipos, di) for di in xrange(force_nd))
#kernel_call_args.append( "%s, i%i_data" % (strides, ipos)) #kernel_call_args.append( "%s, i%i_data" % (strides, ipos))
for ipos in xrange(len(node.outputs)): for ipos in xrange(len(node.outputs)):
kernel_call_args+=["o%i_data"%ipos] + list(local+"o%i_str[%i]"%(ipos, di) for di in xrange(force_nd)) kernel_call_args+=["o%i_data"%ipos] + list(local+"ostr[%i][%i]"%(ipos, di) for di in xrange(force_nd))
#strides = ", ".join("o%i_str[%i]"%(ipos, di) for di in xrange(force_nd)) #strides = ", ".join("o%i_str[%i]"%(ipos, di) for di in xrange(force_nd))
#kernel_call_args.append( "%s, o%i_data" % (strides, ipos)) #kernel_call_args.append( "%s, o%i_data" % (strides, ipos))
if self.verbose: if self.verbose:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论