提交 cc1a1cbb authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Use defaultdict and Counter in profiling.py

上级 6e7a4310
...@@ -14,9 +14,9 @@ import logging ...@@ -14,9 +14,9 @@ import logging
import operator import operator
import sys import sys
import time import time
from collections import defaultdict from collections import Counter, defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Union from typing import TYPE_CHECKING, Any
import numpy as np import numpy as np
...@@ -204,8 +204,8 @@ class ProfileStats: ...@@ -204,8 +204,8 @@ class ProfileStats:
self.fct_call_time = 0.0 self.fct_call_time = 0.0
self.fct_callcount = 0 self.fct_callcount = 0
self.vm_call_time = 0.0 self.vm_call_time = 0.0
self.apply_time = {} self.apply_time = defaultdict(float)
self.apply_callcount = {} self.apply_callcount = Counter()
# self.apply_cimpl = None # self.apply_cimpl = None
# self.message = None # self.message = None
...@@ -234,9 +234,9 @@ class ProfileStats: ...@@ -234,9 +234,9 @@ class ProfileStats:
# Total time spent in Function.vm.__call__ # Total time spent in Function.vm.__call__
# #
apply_time: dict[Union["FunctionGraph", Variable], float] | None = None apply_time: dict[tuple["FunctionGraph", Apply], float]
apply_callcount: dict[Union["FunctionGraph", Variable], int] | None = None apply_callcount: dict[tuple["FunctionGraph", Apply], int]
apply_cimpl: dict[Apply, bool] | None = None apply_cimpl: dict[Apply, bool] | None = None
# dict from node -> bool (1 if c, 0 if py) # dict from node -> bool (1 if c, 0 if py)
...@@ -292,10 +292,9 @@ class ProfileStats: ...@@ -292,10 +292,9 @@ class ProfileStats:
# param is called flag_time_thunks because most other attributes with time # param is called flag_time_thunks because most other attributes with time
# in the name are times *of* something, rather than configuration flags. # in the name are times *of* something, rather than configuration flags.
def __init__(self, atexit_print=True, flag_time_thunks=None, **kwargs): def __init__(self, atexit_print=True, flag_time_thunks=None, **kwargs):
self.apply_callcount = {} self.apply_callcount = Counter()
self.output_size = {} self.output_size = {}
# Keys are `(FunctionGraph, Variable)` self.apply_time = defaultdict(float)
self.apply_time = {}
self.apply_cimpl = {} self.apply_cimpl = {}
self.variable_shape = {} self.variable_shape = {}
self.variable_strides = {} self.variable_strides = {}
...@@ -320,12 +319,10 @@ class ProfileStats: ...@@ -320,12 +319,10 @@ class ProfileStats:
""" """
# timing is stored by node, we compute timing by class on demand # timing is stored by node, we compute timing by class on demand
rval = {} rval = defaultdict(float)
for (fgraph, node), t in self.apply_time.items(): for (_fgraph, node), t in self.apply_time.items():
typ = type(node.op) rval[type(node.op)] += t
rval.setdefault(typ, 0) return dict(rval)
rval[typ] += t
return rval
def class_callcount(self): def class_callcount(self):
""" """
...@@ -333,24 +330,18 @@ class ProfileStats: ...@@ -333,24 +330,18 @@ class ProfileStats:
""" """
# timing is stored by node, we compute timing by class on demand # timing is stored by node, we compute timing by class on demand
rval = {} rval = Counter()
for (fgraph, node), count in self.apply_callcount.items(): for (_fgraph, node), count in self.apply_callcount.items():
typ = type(node.op) rval[type(node.op)] += count
rval.setdefault(typ, 0)
rval[typ] += count
return rval return rval
def class_nodes(self): def class_nodes(self) -> Counter:
""" """
dict op -> total number of nodes dict op -> total number of nodes
""" """
# timing is stored by node, we compute timing by class on demand # timing is stored by node, we compute timing by class on demand
rval = {} rval = Counter(type(node.op) for _fgraph, node in self.apply_callcount)
for (fgraph, node), count in self.apply_callcount.items():
typ = type(node.op)
rval.setdefault(typ, 0)
rval[typ] += 1
return rval return rval
def class_impl(self): def class_impl(self):
...@@ -360,12 +351,9 @@ class ProfileStats: ...@@ -360,12 +351,9 @@ class ProfileStats:
""" """
# timing is stored by node, we compute timing by class on demand # timing is stored by node, we compute timing by class on demand
rval = {} rval = {}
for fgraph, node in self.apply_callcount: for _fgraph, node in self.apply_callcount:
typ = type(node.op) typ = type(node.op)
if self.apply_cimpl[node]: impl = "C " if self.apply_cimpl[node] else "Py"
impl = "C "
else:
impl = "Py"
rval.setdefault(typ, impl) rval.setdefault(typ, impl)
if rval[typ] != impl and len(rval[typ]) == 2: if rval[typ] != impl and len(rval[typ]) == 2:
rval[typ] += impl rval[typ] += impl
...@@ -377,11 +365,10 @@ class ProfileStats: ...@@ -377,11 +365,10 @@ class ProfileStats:
""" """
# timing is stored by node, we compute timing by Op on demand # timing is stored by node, we compute timing by Op on demand
rval = {} rval = defaultdict(float)
for (fgraph, node), t in self.apply_time.items(): for (fgraph, node), t in self.apply_time.items():
rval.setdefault(node.op, 0)
rval[node.op] += t rval[node.op] += t
return rval return dict(rval)
def fill_node_total_time(self, fgraph, node, total_times): def fill_node_total_time(self, fgraph, node, total_times):
""" """
...@@ -414,9 +401,8 @@ class ProfileStats: ...@@ -414,9 +401,8 @@ class ProfileStats:
""" """
# timing is stored by node, we compute timing by Op on demand # timing is stored by node, we compute timing by Op on demand
rval = {} rval = Counter()
for (fgraph, node), count in self.apply_callcount.items(): for (fgraph, node), count in self.apply_callcount.items():
rval.setdefault(node.op, 0)
rval[node.op] += count rval[node.op] += count
return rval return rval
...@@ -426,10 +412,7 @@ class ProfileStats: ...@@ -426,10 +412,7 @@ class ProfileStats:
""" """
# timing is stored by node, we compute timing by Op on demand # timing is stored by node, we compute timing by Op on demand
rval = {} rval = Counter(node.op for _fgraph, node in self.apply_callcount)
for (fgraph, node), count in self.apply_callcount.items():
rval.setdefault(node.op, 0)
rval[node.op] += 1
return rval return rval
def op_impl(self): def op_impl(self):
......
...@@ -246,10 +246,8 @@ class VM(ABC): ...@@ -246,10 +246,8 @@ class VM(ABC):
for node, thunk, t, c in zip( for node, thunk, t, c in zip(
self.nodes, self.thunks, self.call_times, self.call_counts self.nodes, self.thunks, self.call_times, self.call_counts
): ):
profile.apply_time.setdefault((self.fgraph, node), 0.0)
profile.apply_time[(self.fgraph, node)] += t profile.apply_time[(self.fgraph, node)] += t
profile.apply_callcount.setdefault((self.fgraph, node), 0)
profile.apply_callcount[(self.fgraph, node)] += c profile.apply_callcount[(self.fgraph, node)] += c
profile.apply_cimpl[node] = hasattr(thunk, "cthunk") profile.apply_cimpl[node] = hasattr(thunk, "cthunk")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论