提交 48f4db7f authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Remove py310 only strict arg to zip

上级 f454836b
...@@ -19,9 +19,7 @@ def compute_itershape( ...@@ -19,9 +19,7 @@ def compute_itershape(
ndim = len(in_shapes[0]) ndim = len(in_shapes[0])
shape = [None] * ndim shape = [None] * ndim
for i in range(ndim): for i in range(ndim):
for j, (bc, in_shape) in enumerate( for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)):
zip(broadcast_pattern, in_shapes, strict=True)
):
length = in_shape[i] length = in_shape[i]
if bc[i]: if bc[i]:
with builder.if_then( with builder.if_then(
...@@ -151,14 +149,10 @@ def make_loop_call( ...@@ -151,14 +149,10 @@ def make_loop_call(
# input_scope_set = mod.add_metadata([input_scope, output_scope]) # input_scope_set = mod.add_metadata([input_scope, output_scope])
# output_scope_set = mod.add_metadata([input_scope, output_scope]) # output_scope_set = mod.add_metadata([input_scope, output_scope])
inputs = tuple( inputs = tuple(extract_array(aryty, ary) for aryty, ary in zip(input_types, inputs))
extract_array(aryty, ary)
for aryty, ary in zip(input_types, inputs, strict=True)
)
outputs = tuple( outputs = tuple(
extract_array(aryty, ary) extract_array(aryty, ary) for aryty, ary in zip(output_types, outputs)
for aryty, ary in zip(output_types, outputs, strict=True)
) )
zero = ir.Constant(ir.IntType(64), 0) zero = ir.Constant(ir.IntType(64), 0)
...@@ -189,8 +183,8 @@ def make_loop_call( ...@@ -189,8 +183,8 @@ def make_loop_call(
# Load values from input arrays # Load values from input arrays
input_vals = [] input_vals = []
for array_info, bc in zip(inputs, input_bc, strict=True): for array_info, bc in zip(inputs, input_bc):
idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)]
ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe) ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe)
val = builder.load(ptr) val = builder.load(ptr)
# val.set_metadata("alias.scope", input_scope_set) # val.set_metadata("alias.scope", input_scope_set)
...@@ -210,9 +204,7 @@ def make_loop_call( ...@@ -210,9 +204,7 @@ def make_loop_call(
output_values = [output_values] output_values = [output_values]
# Update output value or accumulators respectively # Update output value or accumulators respectively
for i, ((accu, _), value) in enumerate( for i, ((accu, _), value) in enumerate(zip(output_accumulator, output_values)):
zip(output_accumulator, output_values, strict=True)
):
if accu is not None: if accu is not None:
load = builder.load(accu) load = builder.load(accu)
# load.set_metadata("alias.scope", output_scope_set) # load.set_metadata("alias.scope", output_scope_set)
...@@ -223,9 +215,7 @@ def make_loop_call( ...@@ -223,9 +215,7 @@ def make_loop_call(
# store.set_metadata("alias.scope", output_scope_set) # store.set_metadata("alias.scope", output_scope_set)
# store.set_metadata("noalias", input_scope_set) # store.set_metadata("noalias", input_scope_set)
else: else:
idxs_bc = [ idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, output_bc[i])]
zero if bc else idx for idx, bc in zip(idxs, output_bc[i], strict=True)
]
ptr = cgutils.get_item_pointer2(context, builder, *outputs[i], idxs_bc) ptr = cgutils.get_item_pointer2(context, builder, *outputs[i], idxs_bc)
# store = builder.store(value, ptr) # store = builder.store(value, ptr)
arrayobj.store_item(context, builder, output_types[i], value, ptr) arrayobj.store_item(context, builder, output_types[i], value, ptr)
...@@ -237,8 +227,7 @@ def make_loop_call( ...@@ -237,8 +227,7 @@ def make_loop_call(
for output, (accu, accu_depth) in enumerate(output_accumulator): for output, (accu, accu_depth) in enumerate(output_accumulator):
if accu_depth == depth: if accu_depth == depth:
idxs_bc = [ idxs_bc = [
zero if bc else idx zero if bc else idx for idx, bc in zip(idxs, output_bc[output])
for idx, bc in zip(idxs, output_bc[output], strict=True)
] ]
ptr = cgutils.get_item_pointer2( ptr = cgutils.get_item_pointer2(
context, builder, *outputs[output], idxs_bc context, builder, *outputs[output], idxs_bc
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论