提交 7d87c82c authored 作者: AlOa's avatar AlOa

Add code for scalar case(No loop)

上级 375b82ef
......@@ -1047,6 +1047,9 @@ class Elemwise(OpenMPOp):
%(undefs)s
}
""" % locals()
loop_orders = orders + [range(nnested)] * len(real_onames)
dtypes = (idtypes + list(real_odtypes))
if all([o.ndim <= 1 for o in node.outputs] or
# Use simpler code when output ndim == 0 or 1
# or for broadcated scalar.
......@@ -1055,17 +1058,45 @@ class Elemwise(OpenMPOp):
all_code = [("", "")] * (nnested - 1) + [("", code)] + [""]
else:
all_code = [code]
loop = cgen.make_loop(
loop_orders=orders + [range(nnested)] * len(real_onames),
dtypes=(idtypes + list(real_odtypes)),
loop_tasks=all_code,
sub=sub, reduce=False, openmp=self.openmp)
if len(all_code) == 1:
#No loops
task_decl = "".join([
"%s& %s_i = *%s_iter;\n" % (dtype, name, name)
for name, dtype in izip(inames + list(real_onames),
idtypes + list(real_odtypes))])
preloops = {}
for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes)):
for j, index in enumerate(loop_order):
if index != 'x':
preloops.setdefault(j, "")
preloops[j] += ("%%(lv%(i)s)s_iter = (%(dtype)s*)(PyArray_DATA(%%(lv%(i)s)s));\n" % locals()) % sub
break
else: # all broadcastable
preloops.setdefault(0, "")
preloops[0] += ("%%(lv%(i)s)s_iter = (%(dtype)s*)(PyArray_DATA(%%(lv%(i)s)s));\n" % locals()) % sub
init_array = preloops.get(0, " ")
loop = """
{
%(defines)s
%(init_array)s
%(task_decl)s
%(task_code)s
%(undefs)s
}
""" % locals()
else:
loop = cgen.make_loop(
loop_orders=loop_orders,
dtypes=dtypes,
loop_tasks=all_code,
sub=sub, reduce=False, openmp=self.openmp)
else:
loop = cgen.make_reordered_loop(
init_loop_orders=orders + [range(nnested)] * len(real_onames),
init_loop_orders=loop_orders,
olv_index=olv_index,
dtypes=(idtypes + list(real_odtypes)),
dtypes=dtypes,
inner_task=code,
sub=sub, openmp=self.openmp)
......
......@@ -254,10 +254,7 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, reduce=False, openmp=None):
preloops.setdefault(0, "")
preloops[0] += ("%%(lv%(i)s)s_iter = (%(dtype)s*)(PyArray_DATA(%%(lv%(i)s)s));\n" % locals()) % sub
if len(loop_tasks) == 1:
s = preloops.get(0, "")
else:
s = ""
s = ""
if reduce:
loop_over = loop_over_reduce
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论