提交 6e7a4310 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Use defaultdict in graph/rewriting/basic.py

上级 d3bc076f
...@@ -10,7 +10,7 @@ import sys ...@@ -10,7 +10,7 @@ import sys
import time import time
import traceback import traceback
import warnings import warnings
from collections import UserList, defaultdict, deque from collections import Counter, UserList, defaultdict, deque
from collections.abc import Callable, Iterable, Sequence from collections.abc import Callable, Iterable, Sequence
from collections.abc import Iterable as IterableType from collections.abc import Iterable as IterableType
from functools import _compose_mro, partial, reduce # type: ignore from functools import _compose_mro, partial, reduce # type: ignore
...@@ -1153,8 +1153,8 @@ class OpToRewriterTracker: ...@@ -1153,8 +1153,8 @@ class OpToRewriterTracker:
r"""A container that maps `NodeRewriter`\s to `Op` instances and `Op`-type inheritance.""" r"""A container that maps `NodeRewriter`\s to `Op` instances and `Op`-type inheritance."""
def __init__(self): def __init__(self):
self.tracked_instances: dict[Op, list[NodeRewriter]] = {} self.tracked_instances: dict[Op, list[NodeRewriter]] = defaultdict(list)
self.tracked_types: dict[type, list[NodeRewriter]] = {} self.tracked_types: dict[type, list[NodeRewriter]] = defaultdict(list)
self.untracked_rewrites: list[NodeRewriter] = [] self.untracked_rewrites: list[NodeRewriter] = []
def add_tracker(self, rw: NodeRewriter): def add_tracker(self, rw: NodeRewriter):
...@@ -1166,9 +1166,9 @@ class OpToRewriterTracker: ...@@ -1166,9 +1166,9 @@ class OpToRewriterTracker:
else: else:
for c in tracks: for c in tracks:
if isinstance(c, type): if isinstance(c, type):
self.tracked_types.setdefault(c, []).append(rw) self.tracked_types[c].append(rw)
else: else:
self.tracked_instances.setdefault(c, []).append(rw) self.tracked_instances[c].append(rw)
def _find_impl(self, cls) -> list[NodeRewriter]: def _find_impl(self, cls) -> list[NodeRewriter]:
r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance. r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance.
...@@ -1250,22 +1250,16 @@ class SequentialNodeRewriter(NodeRewriter): ...@@ -1250,22 +1250,16 @@ class SequentialNodeRewriter(NodeRewriter):
self.profile = profile self.profile = profile
if self.profile: if self.profile:
self.time_rewrites: dict[Rewriter, float] = {} self.time_rewrites: dict[Rewriter, float] = defaultdict(float)
self.process_count: dict[Rewriter, int] = {} self.process_count: dict[Rewriter, int] = Counter()
self.applied_true: dict[Rewriter, int] = {} self.applied_true: dict[Rewriter, int] = Counter()
self.node_created: dict[Rewriter, int] = {} self.node_created: dict[Rewriter, int] = Counter()
self.tracker = OpToRewriterTracker() self.tracker = OpToRewriterTracker()
for o in self.rewrites: for o in self.rewrites:
self.tracker.add_tracker(o) self.tracker.add_tracker(o)
if self.profile:
self.time_rewrites.setdefault(o, 0.0)
self.process_count.setdefault(o, 0)
self.applied_true.setdefault(o, 0)
self.node_created.setdefault(o, 0)
def __str__(self): def __str__(self):
return getattr( return getattr(
self, self,
...@@ -2316,7 +2310,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2316,7 +2310,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
changed = True changed = True
max_use_abort = False max_use_abort = False
rewriter_name = None rewriter_name = None
global_process_count = {} global_process_count = Counter()
start_nb_nodes = len(fgraph.apply_nodes) start_nb_nodes = len(fgraph.apply_nodes)
max_nb_nodes = len(fgraph.apply_nodes) max_nb_nodes = len(fgraph.apply_nodes)
max_use = max_nb_nodes * self.max_use_ratio max_use = max_nb_nodes * self.max_use_ratio
...@@ -2324,22 +2318,21 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2324,22 +2318,21 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
loop_timing = [] loop_timing = []
loop_process_count = [] loop_process_count = []
global_rewriter_timing = [] global_rewriter_timing = []
time_rewriters = {} time_rewriters = defaultdict(float)
io_toposort_timing = [] io_toposort_timing = []
nb_nodes = [] nb_nodes = []
node_created = {} node_created = Counter()
global_sub_profs = [] global_sub_profs = []
final_sub_profs = [] final_sub_profs = []
cleanup_sub_profs = [] cleanup_sub_profs = []
for rewriter in (
self.global_rewriters for rewriter in [
+ list(self.get_node_rewriters()) *self.global_rewriters,
+ self.final_rewriters *self.get_node_rewriters(),
+ self.cleanup_rewriters *self.final_rewriters,
): *self.cleanup_rewriters,
global_process_count.setdefault(rewriter, 0) ]:
time_rewriters.setdefault(rewriter, 0) time_rewriters[rewriter] += 0
node_created.setdefault(rewriter, 0)
def apply_cleanup(profs_dict): def apply_cleanup(profs_dict):
changed = False changed = False
...@@ -2351,7 +2344,6 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2351,7 +2344,6 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
time_rewriters[crewriter] += time.perf_counter() - t_rewrite time_rewriters[crewriter] += time.perf_counter() - t_rewrite
profs_dict[crewriter].append(sub_prof) profs_dict[crewriter].append(sub_prof)
if change_tracker.changed: if change_tracker.changed:
process_count.setdefault(crewriter, 0)
process_count[crewriter] += 1 process_count[crewriter] += 1
global_process_count[crewriter] += 1 global_process_count[crewriter] += 1
changed = True changed = True
...@@ -2359,7 +2351,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2359,7 +2351,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
return changed return changed
while changed and not max_use_abort: while changed and not max_use_abort:
process_count = {} process_count = Counter()
t0 = time.perf_counter() t0 = time.perf_counter()
changed = False changed = False
iter_cleanup_sub_profs = {} iter_cleanup_sub_profs = {}
...@@ -2376,7 +2368,6 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2376,7 +2368,6 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
time_rewriters[grewrite] += time.perf_counter() - t_rewrite time_rewriters[grewrite] += time.perf_counter() - t_rewrite
sub_profs.append(sub_prof) sub_profs.append(sub_prof)
if change_tracker.changed: if change_tracker.changed:
process_count.setdefault(grewrite, 0)
process_count[grewrite] += 1 process_count[grewrite] += 1
global_process_count[grewrite] += 1 global_process_count[grewrite] += 1
changed = True changed = True
...@@ -2431,7 +2422,6 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2431,7 +2422,6 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
time_rewriters[node_rewriter] += time.perf_counter() - t_rewrite time_rewriters[node_rewriter] += time.perf_counter() - t_rewrite
if not node_rewriter_change: if not node_rewriter_change:
continue continue
process_count.setdefault(node_rewriter, 0)
process_count[node_rewriter] += 1 process_count[node_rewriter] += 1
global_process_count[node_rewriter] += 1 global_process_count[node_rewriter] += 1
changed = True changed = True
...@@ -2459,7 +2449,6 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2459,7 +2449,6 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
time_rewriters[grewrite] += time.perf_counter() - t_rewrite time_rewriters[grewrite] += time.perf_counter() - t_rewrite
sub_profs.append(sub_prof) sub_profs.append(sub_prof)
if change_tracker.changed: if change_tracker.changed:
process_count.setdefault(grewrite, 0)
process_count[grewrite] += 1 process_count[grewrite] += 1
global_process_count[grewrite] += 1 global_process_count[grewrite] += 1
changed = True changed = True
...@@ -2514,7 +2503,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2514,7 +2503,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
(start_nb_nodes, end_nb_nodes, max_nb_nodes), (start_nb_nodes, end_nb_nodes, max_nb_nodes),
global_rewriter_timing, global_rewriter_timing,
nb_nodes, nb_nodes,
time_rewriters, dict(time_rewriters),
io_toposort_timing, io_toposort_timing,
node_created, node_created,
global_sub_profs, global_sub_profs,
...@@ -2597,14 +2586,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2597,14 +2586,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
count_rewrite = [] count_rewrite = []
not_used = [] not_used = []
not_used_time = 0 not_used_time = 0
process_count = {} process_count = Counter()
for o in (
rewrite.global_rewriters
+ list(rewrite.get_node_rewriters())
+ list(rewrite.final_rewriters)
+ list(rewrite.cleanup_rewriters)
):
process_count.setdefault(o, 0)
for count in loop_process_count: for count in loop_process_count:
for o, v in count.items(): for o, v in count.items():
process_count[o] += v process_count[o] += v
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论