提交 6e7a4310 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Use defaultdict in graph/rewriting/basic.py

上级 d3bc076f
......@@ -10,7 +10,7 @@ import sys
import time
import traceback
import warnings
from collections import UserList, defaultdict, deque
from collections import Counter, UserList, defaultdict, deque
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Iterable as IterableType
from functools import _compose_mro, partial, reduce # type: ignore
......@@ -1153,8 +1153,8 @@ class OpToRewriterTracker:
r"""A container that maps `NodeRewriter`\s to `Op` instances and `Op`-type inheritance."""
def __init__(self):
self.tracked_instances: dict[Op, list[NodeRewriter]] = {}
self.tracked_types: dict[type, list[NodeRewriter]] = {}
self.tracked_instances: dict[Op, list[NodeRewriter]] = defaultdict(list)
self.tracked_types: dict[type, list[NodeRewriter]] = defaultdict(list)
self.untracked_rewrites: list[NodeRewriter] = []
def add_tracker(self, rw: NodeRewriter):
......@@ -1166,9 +1166,9 @@ class OpToRewriterTracker:
else:
for c in tracks:
if isinstance(c, type):
self.tracked_types.setdefault(c, []).append(rw)
self.tracked_types[c].append(rw)
else:
self.tracked_instances.setdefault(c, []).append(rw)
self.tracked_instances[c].append(rw)
def _find_impl(self, cls) -> list[NodeRewriter]:
r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance.
......@@ -1250,22 +1250,16 @@ class SequentialNodeRewriter(NodeRewriter):
self.profile = profile
if self.profile:
self.time_rewrites: dict[Rewriter, float] = {}
self.process_count: dict[Rewriter, int] = {}
self.applied_true: dict[Rewriter, int] = {}
self.node_created: dict[Rewriter, int] = {}
self.time_rewrites: dict[Rewriter, float] = defaultdict(float)
self.process_count: dict[Rewriter, int] = Counter()
self.applied_true: dict[Rewriter, int] = Counter()
self.node_created: dict[Rewriter, int] = Counter()
self.tracker = OpToRewriterTracker()
for o in self.rewrites:
self.tracker.add_tracker(o)
if self.profile:
self.time_rewrites.setdefault(o, 0.0)
self.process_count.setdefault(o, 0)
self.applied_true.setdefault(o, 0)
self.node_created.setdefault(o, 0)
def __str__(self):
return getattr(
self,
......@@ -2316,7 +2310,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
changed = True
max_use_abort = False
rewriter_name = None
global_process_count = {}
global_process_count = Counter()
start_nb_nodes = len(fgraph.apply_nodes)
max_nb_nodes = len(fgraph.apply_nodes)
max_use = max_nb_nodes * self.max_use_ratio
......@@ -2324,22 +2318,21 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
loop_timing = []
loop_process_count = []
global_rewriter_timing = []
time_rewriters = {}
time_rewriters = defaultdict(float)
io_toposort_timing = []
nb_nodes = []
node_created = {}
node_created = Counter()
global_sub_profs = []
final_sub_profs = []
cleanup_sub_profs = []
for rewriter in (
self.global_rewriters
+ list(self.get_node_rewriters())
+ self.final_rewriters
+ self.cleanup_rewriters
):
global_process_count.setdefault(rewriter, 0)
time_rewriters.setdefault(rewriter, 0)
node_created.setdefault(rewriter, 0)
for rewriter in [
*self.global_rewriters,
*self.get_node_rewriters(),
*self.final_rewriters,
*self.cleanup_rewriters,
]:
time_rewriters[rewriter] += 0
def apply_cleanup(profs_dict):
changed = False
......@@ -2351,7 +2344,6 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
time_rewriters[crewriter] += time.perf_counter() - t_rewrite
profs_dict[crewriter].append(sub_prof)
if change_tracker.changed:
process_count.setdefault(crewriter, 0)
process_count[crewriter] += 1
global_process_count[crewriter] += 1
changed = True
......@@ -2359,7 +2351,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
return changed
while changed and not max_use_abort:
process_count = {}
process_count = Counter()
t0 = time.perf_counter()
changed = False
iter_cleanup_sub_profs = {}
......@@ -2376,7 +2368,6 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
time_rewriters[grewrite] += time.perf_counter() - t_rewrite
sub_profs.append(sub_prof)
if change_tracker.changed:
process_count.setdefault(grewrite, 0)
process_count[grewrite] += 1
global_process_count[grewrite] += 1
changed = True
......@@ -2431,7 +2422,6 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
time_rewriters[node_rewriter] += time.perf_counter() - t_rewrite
if not node_rewriter_change:
continue
process_count.setdefault(node_rewriter, 0)
process_count[node_rewriter] += 1
global_process_count[node_rewriter] += 1
changed = True
......@@ -2459,7 +2449,6 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
time_rewriters[grewrite] += time.perf_counter() - t_rewrite
sub_profs.append(sub_prof)
if change_tracker.changed:
process_count.setdefault(grewrite, 0)
process_count[grewrite] += 1
global_process_count[grewrite] += 1
changed = True
......@@ -2514,7 +2503,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
(start_nb_nodes, end_nb_nodes, max_nb_nodes),
global_rewriter_timing,
nb_nodes,
time_rewriters,
dict(time_rewriters),
io_toposort_timing,
node_created,
global_sub_profs,
......@@ -2597,14 +2586,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
count_rewrite = []
not_used = []
not_used_time = 0
process_count = {}
for o in (
rewrite.global_rewriters
+ list(rewrite.get_node_rewriters())
+ list(rewrite.final_rewriters)
+ list(rewrite.cleanup_rewriters)
):
process_count.setdefault(o, 0)
process_count = Counter()
for count in loop_process_count:
for o, v in count.items():
process_count[o] += v
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论