提交 d7c78755 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Fix PERF403: dict comprehensions when appropriate

上级 f2bf051c
...@@ -1055,12 +1055,7 @@ class VMLinker(LocalLinker): ...@@ -1055,12 +1055,7 @@ class VMLinker(LocalLinker):
for v in self.fgraph.inputs + self.fgraph.outputs: for v in self.fgraph.inputs + self.fgraph.outputs:
vars_idx.setdefault(v, len(vars_idx)) vars_idx.setdefault(v, len(vars_idx))
nodes_idx_inv = {} vars_idx_inv = {i: var for var, i in vars_idx.items()}
vars_idx_inv = {}
for node, i in nodes_idx.items():
nodes_idx_inv[i] = node
for var, i in vars_idx.items():
vars_idx_inv[i] = var
# put storage_map and compute_map into a int-based scheme # put storage_map and compute_map into a int-based scheme
storage_map_list = [ storage_map_list = [
......
...@@ -4434,8 +4434,7 @@ class Compositef32: ...@@ -4434,8 +4434,7 @@ class Compositef32:
) )
# make sure we don't produce any float16. # make sure we don't produce any float16.
assert not any(o.dtype == "float16" for o in new_node.outputs) assert not any(o.dtype == "float16" for o in new_node.outputs)
for o, no in zip(node.outputs, new_node.outputs): mapping.update(zip(node.outputs, new_node.outputs))
mapping[o] = no
new_ins = [mapping[inp] for inp in fgraph.inputs] new_ins = [mapping[inp] for inp in fgraph.inputs]
new_outs = [mapping[out] for out in fgraph.outputs] new_outs = [mapping[out] for out in fgraph.outputs]
......
...@@ -2240,8 +2240,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2240,8 +2240,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Non-sequences have a direct equivalent from self.inner_inputs in # Non-sequences have a direct equivalent from self.inner_inputs in
# node.inputs # node.inputs
inner_non_sequences = self.inner_inputs[len(seqs_shape) + len(outs_shape) :] inner_non_sequences = self.inner_inputs[len(seqs_shape) + len(outs_shape) :]
for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]): out_equivalent.update(zip(inner_non_sequences, node.inputs[offset:]))
out_equivalent[in_ns] = out_ns
if info.as_while: if info.as_while:
self_outs = self.inner_outputs[:-1] self_outs = self.inner_outputs[:-1]
......
...@@ -1623,9 +1623,7 @@ def scan_save_mem(fgraph, node): ...@@ -1623,9 +1623,7 @@ def scan_save_mem(fgraph, node):
(inps, outs, info, node_ins, compress_map) = compress_outs( (inps, outs, info, node_ins, compress_map) = compress_outs(
op, not_required, nw_inputs op, not_required, nw_inputs
) )
inv_compress_map = {} inv_compress_map = {v: k for k, v in compress_map.items()}
for k, v in compress_map.items():
inv_compress_map[v] = k
# 3.6 Compose the new scan # 3.6 Compose the new scan
# TODO: currently we don't support scan with 0 step. So # TODO: currently we don't support scan with 0 step. So
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论