提交 73e46362 作者: James Bergstra

not counting scalars towards nd_collapse

上级 ac703c5b
......@@ -15,6 +15,8 @@ def debug(*msg):
_logger.debug(_logger_name+'DEBUG: '+' '.join(str(m) for m in msg))
def _logical_scalar(x):
    """Return True when every dimension of ``x.type`` is broadcastable.

    :param x: an object exposing ``x.type.broadcastable``, an iterable of
        per-dimension flags.
    :returns: ``True`` if every flag is truthy, ``False`` otherwise.  An
        empty ``broadcastable`` also yields ``True``, because ``all`` of an
        empty iterable is ``True``.
    """
    return all(x.type.broadcastable)
class RecAlgo(object):
def c_src_kernel(self, node, nodename):
......@@ -22,8 +24,6 @@ class RecAlgo(object):
sio = StringIO.StringIO()
#print 'C_SRC_KERNEL', sio.getvalue()
def _logical_scalar(x):
return all(x.type.broadcastable)
for ipos, i in enumerate(node.inputs):
print >> sio, "// Input ", ipos, str(i.type)
......@@ -428,9 +428,6 @@ class NaiveAlgo(object):
if nd not in (4,):
return sio.getvalue()
def _logical_scalar(x):
return all(x.type.broadcastable)
# print some leading comments to make the code easier to read
for ipos, i in enumerate(node.inputs):
print >> sio, "// Input ", ipos, str(i.type)
......@@ -599,6 +596,13 @@ class NaiveAlgo(object):
print >> sio, " const int idx = blockIdx.x * blockDim.x + threadIdx.x;"
print >> sio, " const int numThreads = blockDim.x * gridDim.x;"
# For each input that is a scalar which has been broadcasted to a tensor,
# load it into a local variable
for ipos, i in enumerate(node.inputs):
if _logical_scalar(i):
print >> sio, " const float ii_i%i_value = i%i_data[0];" % (ipos, ipos)
#loop over the elements to be treated by this kernel call
print >> sio, " for (int i = idx; i < numEls; i += numThreads) {"
# perform the scalar operation on the input and output references
......@@ -608,7 +612,8 @@ class NaiveAlgo(object):
[scalar.Scalar(dtype = input.type.dtype)() for input in node.inputs],
[scalar.Scalar(dtype = output.type.dtype)() for output in node.outputs])
, nodename + '_scalar_'
, ['i%i_data[i]'%ipos for ipos, i in enumerate(node.inputs)]
#, ['i%i_data[i]'%ipos for ipos, i in enumerate(node.inputs)]
, [('ii_i%i_value' if _logical_scalar(i) else 'i%i_data[i]')%ipos for ipos, i in enumerate(node.inputs)]
, ['o%i_data[i]'%ipos for ipos, i in enumerate(node.outputs)]
, sub=dict(fail='return;')) #TODO: set a failure code somehow!!!
print >> sio, " ", task_code
......@@ -685,6 +690,7 @@ class NaiveAlgo(object):
print >> sio, "int nd_collapse = 0;" #because the outputs are assumed to be c_contiguous
print >> sio, "int nd_collapse_size = numEls;" #because the outputs are assumed to be c_contiguous
for ipos in xrange(len(node.inputs)):
if not _logical_scalar(node.inputs[ipos]):
print >> sio, """
int nd_collapse_size_%(ipos)s;
int nd_collapse_%(ipos)s = c_contiguous_beyond_%(nodename)s(%(nd)s, dims, i%(ipos)s_str, nd_collapse_size_%(ipos)s);
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论