提交 52b98b42 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Minor style fixes

上级 ffbff57b
...@@ -553,25 +553,25 @@ class NaiveAlgo(object): ...@@ -553,25 +553,25 @@ class NaiveAlgo(object):
for(int i=0;i<%(nd)s;i++){//init new dim for(int i=0;i<%(nd)s;i++){//init new dim
local_dims[i]=dims[i]; local_dims[i]=dims[i];
} }
"""%locals() """ % locals()
for ipos in xrange(len(node.inputs)): for ipos in xrange(len(node.inputs)):
print >> sio, """ print >> sio, """
for(int i=0;i<%(nd)s;i++){//init new strides for(int i=0;i<%(nd)s;i++){//init new strides
local_str[%(ipos)s][i]=i%(ipos)s_str[i]; local_str[%(ipos)s][i]=i%(ipos)s_str[i];
} }
"""%locals() """ % locals()
for ipos in xrange(len(node.outputs)): for ipos in xrange(len(node.outputs)):
print >> sio, """ print >> sio, """
for(int i=0;i<%(nd)s;i++){//init new strides for(int i=0;i<%(nd)s;i++){//init new strides
local_ostr[%(ipos)s][i]=o%(ipos)s_str[i]; local_ostr[%(ipos)s][i]=o%(ipos)s_str[i];
} }
"""%locals() """ % locals()
if self.verbose>2: if self.verbose>2:
print >>sio, 'std::cerr <<"before broadcast collapse\\n";' print >>sio, 'std::cerr <<"before broadcast collapse\\n";'
print >>sio, 'std::cerr<< "nd_collapse "<< nd_collapse << "\\n"; ' print >>sio, 'std::cerr<< "nd_collapse "<< nd_collapse << "\\n"; '
print >> sio, 'std::cerr << "local_dims";' print >> sio, 'std::cerr << "local_dims";'
for d in xrange(nd): for d in xrange(nd):
print >> sio, 'std::cerr << " " << local_dims[%(d)s]; '%locals() print >> sio, 'std::cerr << " " << local_dims[%(d)s]; ' % locals()
print >> sio, 'std::cerr << "\\n";' print >> sio, 'std::cerr << "\\n";'
for ipos in xrange(len(node.inputs)): for ipos in xrange(len(node.inputs)):
...@@ -858,9 +858,13 @@ nd_collapse_[i]=0; ...@@ -858,9 +858,13 @@ nd_collapse_[i]=0;
//standard elemwise size checks //standard elemwise size checks
""" %locals() """ %locals()
if nd > 0: if nd > 0:
print >> sio, """int dims[%(nd)s] = {%(initial_dims)s};""" %locals() print >> sio, """
int dims[%(nd)s] = {%(initial_dims)s};
""" % locals()
else: else:
print >> sio, """int *dims = NULL;""" print >> sio, """
int *dims = NULL;
"""
#check that all inputs have valid dimensions #check that all inputs have valid dimensions
emitted_inames = {} emitted_inames = {}
...@@ -871,9 +875,13 @@ nd_collapse_[i]=0; ...@@ -871,9 +875,13 @@ nd_collapse_[i]=0;
broadcasts = ', '.join(map(str,map(int,node.inputs[id].broadcastable))) broadcasts = ', '.join(map(str,map(int,node.inputs[id].broadcastable)))
nd = node.inputs[id].ndim nd = node.inputs[id].ndim
if nd > 0: if nd > 0:
print >> sio, """int broadcasts_%(iname)s[%(nd)s] = {%(broadcasts)s};""" %locals() print >> sio, """
int broadcasts_%(iname)s[%(nd)s] = {%(broadcasts)s};
""" % locals()
else: else:
print >> sio, """int *broadcasts_%(iname)s = NULL;""" %locals() print >> sio, """
int *broadcasts_%(iname)s = NULL;
""" % locals()
emitted_inames[iname] = node.inputs[id] emitted_inames[iname] = node.inputs[id]
#check that all inputs have valid dimensions #check that all inputs have valid dimensions
emitted_inames = {} emitted_inames = {}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论