提交 8a6d2aae · 作者: Virgile Andreani · 提交者: Ricardo Vieira

Rewrite for/append as list comprehensions

上级 bf73f8a0
......@@ -906,11 +906,10 @@ def _get_preallocated_maps(
name = f"strided{tuple(steps)}"
for r in considered_outputs:
if r in init_strided:
strides = []
shapes = []
for i, size in enumerate(r_vals[r].shape):
shapes.append(slice(None, size, None))
strides.append(slice(None, None, steps[i]))
shapes = [slice(None, size, None) for size in r_vals[r].shape]
strides = [
slice(None, None, steps[i]) for i in range(r_vals[r].ndim)
]
r_buf = init_strided[r]
......
......@@ -247,18 +247,10 @@ def function(
"""
if isinstance(outputs, dict):
output_items = list(outputs.items())
assert all(isinstance(k, str) for k in outputs)
for item_pair in output_items:
assert isinstance(item_pair[0], str)
output_items_sorted = sorted(output_items)
output_keys = []
outputs = []
for pair in output_items_sorted:
output_keys.append(pair[0])
outputs.append(pair[1])
output_keys = sorted(outputs)
outputs = [outputs[key] for key in output_keys]
else:
output_keys = None
......
......@@ -212,18 +212,14 @@ def std_fgraph(
found_updates.extend(map(SymbolicOutput, updates))
elif fgraph is None:
input_vars = []
# If one of the inputs is non-atomic (i.e. has a non-`None` `Variable.owner`),
# then we need to create/clone the graph starting at these inputs.
# The result will be atomic versions of the given inputs connected to
# the same outputs.
# Otherwise, when all the inputs are already atomic, there's no need to
# clone the graph.
clone = force_clone
for spec in input_specs:
input_vars.append(spec.variable)
clone |= spec.variable.owner is not None
input_vars = [spec.variable for spec in input_specs]
clone = force_clone or any(var.owner is not None for var in input_vars)
fgraph = FunctionGraph(
input_vars,
......
......@@ -1204,8 +1204,7 @@ class ProfileStats:
compute_map[var][0] = 0
for k_remove, v_remove in viewedby_remove.items():
for i in v_remove:
viewed_by[k_remove].append(i)
viewed_by[k_remove].extend(v_remove)
for k_add, v_add in viewedby_add.items():
for i in v_add:
......@@ -1215,15 +1214,16 @@ class ProfileStats:
del view_of[k]
# two data structure used to mimic Python gc
viewed_by = {} # {var1: [vars that view var1]}
# * {var1: [vars that view var1]}
# The len of the list is the value of python ref
# count. But we use a list, not just the ref count value.
# This is more safe to help detect potential bug in the algo
for var in fgraph.variables:
viewed_by[var] = []
view_of = {} # {var1: original var viewed by var1}
viewed_by = {var: [] for var in fgraph.variables}
# * {var1: original var viewed by var1}
# The original mean that we don't keep track of all the intermediate
# relationship in the view.
view_of = {}
min_memory_generator(executable_nodes, viewed_by, view_of)
......
......@@ -1474,9 +1474,8 @@ def general_toposort(
_clients: dict[T, list[T]] = {}
sources: deque[T] = deque()
search_res_len: int = 0
search_res_len = len(search_res)
for snode, children in search_res:
search_res_len += 1
if children:
for child in children:
_clients.setdefault(child, []).append(snode)
......
......@@ -270,8 +270,10 @@ class FunctionGraph(MetaObject):
self.execute_callbacks("on_prune", apply_node, reason)
for i, in_var in enumerate(apply_node.inputs):
removal_stack.append((in_var, (apply_node, i)))
removal_stack.extend(
(in_var, (apply_node, i))
for i, in_var in enumerate(apply_node.inputs)
)
if remove_if_empty:
del clients[var]
......
......@@ -479,9 +479,9 @@ class SequentialGraphRewriter(GraphRewriter, UserList):
new_sub_profile.append(p[6][idx])
new_rewrite = SequentialGraphRewriter(*new_l)
new_nb_nodes = []
for p1, p2 in zip(prof1[8], prof2[8]):
new_nb_nodes.append((p1[0] + p2[0], p1[1] + p2[1]))
new_nb_nodes = [
(p1[0] + p2[0], p1[1] + p2[1]) for p1, p2 in zip(prof1[8], prof2[8])
]
new_nb_nodes.extend(prof1[8][len(new_nb_nodes) :])
new_nb_nodes.extend(prof2[8][len(new_nb_nodes) :])
......@@ -960,9 +960,9 @@ class MetaNodeRewriter(NodeRewriter):
tracks = rewriter.tracks()
if tracks:
self._tracks.extend(tracks)
for c in tracks:
self.track_dict[c].append(rewriter)
self._tracks.append(c)
for tag in tag_list:
self.tag_dict[tag].append(rewriter)
......
......@@ -524,12 +524,13 @@ class WrapLinker(Linker):
thunk_groups = list(zip(*thunk_lists))
order = [x[0] for x in zip(*order_lists)]
to_reset = []
for thunks, node in zip(thunk_groups, order):
for j, output in enumerate(node.outputs):
if output in no_recycling:
for thunk in thunks:
to_reset.append(thunk.outputs[j])
to_reset = [
thunk.outputs[j]
for thunks, node in zip(thunk_groups, order)
for j, output in enumerate(node.outputs)
if output in no_recycling
for thunk in thunks
]
wrapper = self.wrapper
pre = self.pre
......@@ -696,10 +697,7 @@ class JITLinker(PerformLinker):
computed, last_user = gc_helper(nodes)
if self.allow_gc:
post_thunk_old_storage = []
for node in nodes:
post_thunk_old_storage.append(
post_thunk_old_storage = [
[
storage_map[input]
for input in node.inputs
......@@ -707,7 +705,8 @@ class JITLinker(PerformLinker):
and (input not in fgraph.outputs)
and (node == last_user[input])
]
)
for node in nodes
]
else:
post_thunk_old_storage = None
......
......@@ -1129,19 +1129,18 @@ class CLinker(Linker):
)
def get_init_tasks(self):
    """Build the task lists used to drive C code emission.

    Returns
    -------
    (init_tasks, tasks)
        ``init_tasks`` holds ``(variable, "init", id)`` followed by
        ``(node, "init", id)`` entries; ``tasks`` holds the matching
        ``(variable, "get", id)`` and ``(node, "code", id)`` entries.
        Ids start at 1, are unique, and increase in emission order:
        each non-constant variable consumes an (init, get) id pair,
        then each node consumes a (code, init) id pair.
    """
    # Constants need no init/get tasks; only the remaining variables do.
    nonconst_vars = [v for v in self.variables if v not in self.consts]
    next_id = 1
    init_tasks = [(v, "init", next_id + 2 * i) for i, v in enumerate(nonconst_vars)]
    tasks = [(v, "get", next_id + 2 * i + 1) for i, v in enumerate(nonconst_vars)]
    next_id += 2 * len(nonconst_vars)
    # Nodes come after all variables, with "code" preceding "init" per node.
    tasks.extend(
        (node, "code", next_id + 2 * i) for i, node in enumerate(self.node_order)
    )
    init_tasks.extend(
        (node, "init", next_id + 2 * i + 1) for i, node in enumerate(self.node_order)
    )
    return init_tasks, tasks
def make_thunk(
......@@ -1492,12 +1491,11 @@ class CLinker(Linker):
# graph's information used to compute the key. If we mistakenly
# pretend that inputs with clients don't have any, were are only using
# those inputs more than once to compute the key.
for ipos, var in [
(i, var)
for i, var in enumerate(fgraph.inputs)
sig.extend(
(var.type, in_sig(var, -1, ipos))
for ipos, var in enumerate(fgraph.inputs)
if not len(fgraph.clients[var])
]:
sig.append((var.type, in_sig(var, -1, ipos)))
)
# crystalize the signature and version
sig = tuple(sig)
......
......@@ -220,12 +220,7 @@ int main( int argc, const char* argv[] )
def lquote_macro(txt: str) -> str:
    """Join the lines of ``txt`` with C ``\\`` line continuations.

    Every line except the last gets a trailing `` \\`` so the whole text can
    be used as the (multi-line) body of a single C ``#define`` macro.
    """
    # " \\\n".join(...) appends " \" to each line but the last in one pass.
    return " \\\n".join(txt.split("\n"))
def get_sub_macros(sub: dict[str, str]) -> tuple[str, str]:
......@@ -240,21 +235,17 @@ def get_sub_macros(sub: dict[str, str]) -> tuple[str, str]:
return "\n".join(define_macros), "\n".join(undef_macros)
def get_io_macros(inputs: list[str], outputs: list[str]) -> tuple[str, str]:
    """Build the C ``#define``/``#undef`` macros for an Op's inputs and outputs.

    Parameters
    ----------
    inputs, outputs
        C expressions naming each input/output storage slot.

    Returns
    -------
    (defines, undefs)
        Newline-joined ``#define INPUT_i ...``/``#define OUTPUT_i ...`` lines,
        and the matching ``#undef`` lines (inputs first, then outputs).
    """
    define_inputs = [f"#define INPUT_{i} {inp}" for i, inp in enumerate(inputs)]
    define_outputs = [f"#define OUTPUT_{i} {out}" for i, out in enumerate(outputs)]
    undef_inputs = [f"#undef INPUT_{i}" for i in range(len(inputs))]
    undef_outputs = [f"#undef OUTPUT_{i}" for i in range(len(outputs))]
    return (
        "\n".join(define_inputs + define_outputs),
        "\n".join(undef_inputs + undef_outputs),
    )
class ExternalCOp(COp):
......@@ -560,9 +551,10 @@ class ExternalCOp(COp):
define_macros.append(define_template % ("APPLY_SPECIFIC(str)", f"str##_{name}"))
undef_macros.append(undef_template % "APPLY_SPECIFIC")
for n, v in self.__get_op_params():
define_macros.append(define_template % (n, v))
undef_macros.append(undef_template % (n,))
define_macros.extend(
define_template % (n, v) for n, v in self.__get_op_params()
)
undef_macros.extend(undef_template % (n,) for n, _ in self.__get_op_params())
return "\n".join(define_macros), "\n".join(undef_macros)
......
......@@ -131,21 +131,19 @@ def create_numba_signature(
reduce_to_scalar: bool = False,
) -> numba.types.Type:
"""Create a Numba type for the signature of an `Apply` node or `FunctionGraph`."""
input_types = []
for inp in node_or_fgraph.inputs:
input_types.append(
input_types = [
get_numba_type(
inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
)
)
for inp in node_or_fgraph.inputs
]
output_types = []
for out in node_or_fgraph.outputs:
output_types.append(
output_types = [
get_numba_type(
out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
)
)
for out in node_or_fgraph.outputs
]
if len(output_types) > 1:
return numba.types.Tuple(output_types)(*input_types)
......
......@@ -520,9 +520,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):
if length == 1 and shape and iter_length != 1 and not allow_bc:
raise ValueError("Broadcast not allowed.")
outputs = []
for dtype in output_dtypes:
outputs.append(np.empty(shape, dtype=dtype))
outputs = [np.empty(shape, dtype=dtype) for dtype in output_dtypes]
for idx in np.ndindex(shape):
vals = [input[idx] for input in inputs_bc]
......
......@@ -268,14 +268,14 @@ def numba_funcify_Scan(op, node, **kwargs):
output_taps = inner_in_names_to_output_taps.get(
outer_in_name, [tap_storage_size]
)
for out_tap in output_taps:
inner_out_to_outer_in_stmts.append(
inner_out_to_outer_in_stmts.extend(
idx_to_str(
storage_name,
out_tap,
size=storage_size_name,
allow_scalar=True,
)
for out_tap in output_taps
)
add_output_storage_post_proc_stmt(
......
......@@ -1111,9 +1111,8 @@ class VMLinker(LocalLinker):
# builds the list of prereqs induced by e.g. destroy_handler
ords = self.fgraph.orderings()
node_prereqs = []
node_output_size = []
node_output_size = [0] * len(nodes)
for i, node in enumerate(nodes):
node_output_size.append(0)
prereq_var_idxs = []
for prereq_node in ords.get(node, []):
prereq_var_idxs.extend([vars_idx[v] for v in prereq_node.outputs])
......
......@@ -1575,9 +1575,7 @@ class InRange(LogicalComparison):
def L_op(self, inputs, outputs, gout):
    """Return one gradient term per input (``x``, ``low``, ``hi``).

    Each term is produced by ``self.get_grad``; the output gradient ``gz``
    is unpacked but not used by this implementation.
    """
    (x, low, hi) = inputs
    (gz,) = gout
    return [self.get_grad(elem) for elem in (x, low, hi)]
......
......@@ -646,9 +646,7 @@ def scan(
# Since we've added all sequences now we need to level them up based on
# n_steps or their different shapes
lengths_vec = []
for seq in scan_seqs:
lengths_vec.append(seq.shape[0])
lengths_vec = [seq.shape[0] for seq in scan_seqs]
if not isNaN_or_Inf_or_None(n_steps):
# ^ N_steps should also be considered
......
......@@ -1629,10 +1629,7 @@ class Alloc(COp):
return [node.inputs[1:]]
def connection_pattern(self, node):
    """Mark only the first input as connected to the single output."""
    # One [True]/[False] row per input; everything past input 0 is disconnected.
    return [[True], *([False] for _ in node.inputs[1:])]
......@@ -1859,9 +1856,7 @@ class MakeVector(COp):
if self.dtype in discrete_dtypes:
return [ipt.zeros_like().astype(config.floatX) for ipt in inputs]
grads = []
for i, inp in enumerate(inputs):
grads.append(output_gradients[0][i])
grads = [output_gradients[0][i] for i in range(len(inputs))]
return grads
def R_op(self, inputs, eval_points):
......@@ -2514,12 +2509,10 @@ class Join(COp):
(out,) = outputs
fail = sub["fail"]
adtype = node.inputs[0].type.dtype_specs()[1]
copy_to_list = []
for i, inp in enumerate(tens):
copy_to_list.append(
f"""Py_INCREF({inp});
PyList_SetItem(list, {i}, (PyObject*){inp});"""
copy_to_list = (
f"""Py_INCREF({inp}); PyList_SetItem(list, {i}, (PyObject*){inp});"""
for i, inp in enumerate(tens)
)
copy_inputs_to_list = "\n".join(copy_to_list)
......@@ -3442,9 +3435,7 @@ class PermuteRowElements(Op):
shp_x = in_shapes[0]
shp_y = in_shapes[1]
assert len(shp_x) == len(shp_y)
out_shape = []
for i in range(len(shp_x)):
out_shape.append(maximum(shp_x[i], shp_y[i]))
out_shape = [maximum(sx, sy) for sx, sy in zip(shp_x, shp_y, strict=True)]
return [out_shape]
def grad(self, inp, grads):
......
......@@ -167,9 +167,8 @@ class Blockwise(Op):
batch_ndims = self.batch_ndim(node)
core_dims: dict[str, Any] = {}
batch_shapes = []
batch_shapes = [input_shape[:batch_ndims] for input_shape in input_shapes]
for input_shape, sig in zip(input_shapes, self.inputs_sig):
batch_shapes.append(input_shape[:batch_ndims])
core_shape = input_shape[batch_ndims:]
for core_dim, dim_name in zip(core_shape, sig):
......
......@@ -1161,8 +1161,10 @@ class Elemwise(OpenMPOp):
],
)
version.append(self.scalar_op.c_code_cache_version_apply(scalar_node))
for i in node.inputs + node.outputs:
version.append(get_scalar_type(dtype=i.type.dtype).c_code_cache_version())
version.extend(
get_scalar_type(dtype=i.type.dtype).c_code_cache_version()
for i in node.inputs + node.outputs
)
version.append(("openmp", self.openmp))
version.append(("openmp_elemwise_minsize", config.openmp_elemwise_minsize))
if all(version):
......@@ -1664,8 +1666,10 @@ class CAReduce(COp):
],
)
version.append(self.scalar_op.c_code_cache_version_apply(scalar_node))
for i in node.inputs + node.outputs:
version.append(get_scalar_type(dtype=i.type.dtype).c_code_cache_version())
version.extend(
get_scalar_type(dtype=i.type.dtype).c_code_cache_version()
for i in node.inputs + node.outputs
)
if all(version):
return tuple(version)
else:
......
......@@ -952,10 +952,7 @@ class Subtensor(COp):
return [first] + [DisconnectedType()()] * len(rest)
def connection_pattern(self, node):
    """Mark only the first input as connected to the single output."""
    # One [True]/[False] row per input; everything past input 0 is disconnected.
    return [[True], *([False] for _ in node.inputs[1:])]
......@@ -1963,10 +1960,7 @@ class IncSubtensor(COp):
return self(eval_points[0], eval_points[1], *inputs[2:], return_list=True)
def connection_pattern(self, node):
    """Mark only the first two inputs as connected to the single output."""
    # One [True]/[False] row per input; inputs past index 1 are disconnected.
    return [[True], [True], *([False] for _ in node.inputs[2:])]
......@@ -2765,10 +2759,7 @@ class AdvancedSubtensor(Op):
out[0] = rval
def connection_pattern(self, node):
    """Mark only the first input as connected to the single output."""
    # One [True]/[False] row per input; everything past input 0 is disconnected.
    return [[True], *([False] for _ in node.inputs[1:])]
......@@ -2912,10 +2903,7 @@ class AdvancedIncSubtensor(Op):
return [ishapes[0]]
def connection_pattern(self, node):
    """Mark only the first two inputs as connected to the single output."""
    # One [True]/[False] row per input; inputs past index 1 are disconnected.
    return [[True], [True], *([False] for _ in node.inputs[2:])]
......
......@@ -238,8 +238,7 @@ class Extend(COp):
# need to copy toAppend due to destroy_handler limitation
if toAppend:
o = out[0]
for i in toAppend:
o.append(_lessbroken_deepcopy(i))
o.extend(_lessbroken_deepcopy(i) for i in toAppend)
def __str__(self):
return self.__class__.__name__
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论