提交 c15e7aaa authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Make non-strict zips strict in tensor/elemwise_cgen

上级 47c09433
...@@ -209,7 +209,13 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"): ...@@ -209,7 +209,13 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
) )
def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None): def make_loop(
loop_orders: list[tuple[int | str, ...]],
dtypes: list,
loop_tasks: list,
sub: dict[str, str],
openmp: bool = False,
):
""" """
Make a nested loop over several arrays and associate specific code Make a nested loop over several arrays and associate specific code
to each level of nesting. to each level of nesting.
...@@ -227,7 +233,7 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None): ...@@ -227,7 +233,7 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
string is code to be executed before the ith loop starts, the second string is code to be executed before the ith loop starts, the second
one contains code to be executed just before going to the next element one contains code to be executed just before going to the next element
of the ith dimension. of the ith dimension.
The last element if loop_tasks is a single string, containing code The last element of loop_tasks is a single string, containing code
to be executed at the very end. to be executed at the very end.
sub : dictionary sub : dictionary
Maps 'lv#' to a suitable variable name. Maps 'lv#' to a suitable variable name.
...@@ -260,7 +266,7 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None): ...@@ -260,7 +266,7 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
}} }}
""" """
preloops = {} preloops: dict[int, str] = {}
for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes, strict=True)): for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes, strict=True)):
for j, index in enumerate(loop_order): for j, index in enumerate(loop_order):
if index != "x": if index != "x":
...@@ -277,16 +283,8 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None): ...@@ -277,16 +283,8 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
s = "" s = ""
for i, (pre_task, task), indices in reversed( tasks_indices = zip(loop_tasks[:-1], zip(*loop_orders, strict=True), strict=True)
list( for i, ((pre_task, task), indices) in reversed(list(enumerate(tasks_indices))):
zip(
range(len(loop_tasks) - 1),
loop_tasks,
list(zip(*loop_orders, strict=True)),
strict=False,
)
)
):
s = loop_over(preloops.get(i, "") + pre_task, s + task, indices, i) s = loop_over(preloops.get(i, "") + pre_task, s + task, indices, i)
s += loop_tasks[-1] s += loop_tasks[-1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论