提交 cc1a1cbb authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Use defaultdict and Counter in profiling.py

上级 6e7a4310
......@@ -14,9 +14,9 @@ import logging
import operator
import sys
import time
from collections import defaultdict
from collections import Counter, defaultdict
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any
import numpy as np
......@@ -204,8 +204,8 @@ class ProfileStats:
self.fct_call_time = 0.0
self.fct_callcount = 0
self.vm_call_time = 0.0
self.apply_time = {}
self.apply_callcount = {}
self.apply_time = defaultdict(float)
self.apply_callcount = Counter()
# self.apply_cimpl = None
# self.message = None
......@@ -234,9 +234,9 @@ class ProfileStats:
# Total time spent in Function.vm.__call__
#
apply_time: dict[Union["FunctionGraph", Variable], float] | None = None
apply_time: dict[tuple["FunctionGraph", Apply], float]
apply_callcount: dict[Union["FunctionGraph", Variable], int] | None = None
apply_callcount: dict[tuple["FunctionGraph", Apply], int]
apply_cimpl: dict[Apply, bool] | None = None
# dict from node -> bool (1 if c, 0 if py)
......@@ -292,10 +292,9 @@ class ProfileStats:
# param is called flag_time_thunks because most other attributes with time
# in the name are times *of* something, rather than configuration flags.
def __init__(self, atexit_print=True, flag_time_thunks=None, **kwargs):
self.apply_callcount = {}
self.apply_callcount = Counter()
self.output_size = {}
# Keys are `(FunctionGraph, Variable)`
self.apply_time = {}
self.apply_time = defaultdict(float)
self.apply_cimpl = {}
self.variable_shape = {}
self.variable_strides = {}
......@@ -320,12 +319,10 @@ class ProfileStats:
"""
# timing is stored by node, we compute timing by class on demand
rval = {}
for (fgraph, node), t in self.apply_time.items():
typ = type(node.op)
rval.setdefault(typ, 0)
rval[typ] += t
return rval
rval = defaultdict(float)
for (_fgraph, node), t in self.apply_time.items():
rval[type(node.op)] += t
return dict(rval)
def class_callcount(self):
"""
......@@ -333,24 +330,18 @@ class ProfileStats:
"""
# timing is stored by node, we compute timing by class on demand
rval = {}
for (fgraph, node), count in self.apply_callcount.items():
typ = type(node.op)
rval.setdefault(typ, 0)
rval[typ] += count
rval = Counter()
for (_fgraph, node), count in self.apply_callcount.items():
rval[type(node.op)] += count
return rval
def class_nodes(self):
def class_nodes(self) -> Counter:
"""
dict op -> total number of nodes
"""
# timing is stored by node, we compute timing by class on demand
rval = {}
for (fgraph, node), count in self.apply_callcount.items():
typ = type(node.op)
rval.setdefault(typ, 0)
rval[typ] += 1
rval = Counter(type(node.op) for _fgraph, node in self.apply_callcount)
return rval
def class_impl(self):
......@@ -360,12 +351,9 @@ class ProfileStats:
"""
# timing is stored by node, we compute timing by class on demand
rval = {}
for fgraph, node in self.apply_callcount:
for _fgraph, node in self.apply_callcount:
typ = type(node.op)
if self.apply_cimpl[node]:
impl = "C "
else:
impl = "Py"
impl = "C " if self.apply_cimpl[node] else "Py"
rval.setdefault(typ, impl)
if rval[typ] != impl and len(rval[typ]) == 2:
rval[typ] += impl
......@@ -377,11 +365,10 @@ class ProfileStats:
"""
# timing is stored by node, we compute timing by Op on demand
rval = {}
rval = defaultdict(float)
for (fgraph, node), t in self.apply_time.items():
rval.setdefault(node.op, 0)
rval[node.op] += t
return rval
return dict(rval)
def fill_node_total_time(self, fgraph, node, total_times):
"""
......@@ -414,9 +401,8 @@ class ProfileStats:
"""
# timing is stored by node, we compute timing by Op on demand
rval = {}
rval = Counter()
for (fgraph, node), count in self.apply_callcount.items():
rval.setdefault(node.op, 0)
rval[node.op] += count
return rval
......@@ -426,10 +412,7 @@ class ProfileStats:
"""
# timing is stored by node, we compute timing by Op on demand
rval = {}
for (fgraph, node), count in self.apply_callcount.items():
rval.setdefault(node.op, 0)
rval[node.op] += 1
rval = Counter(node.op for _fgraph, node in self.apply_callcount)
return rval
def op_impl(self):
......
......@@ -246,10 +246,8 @@ class VM(ABC):
for node, thunk, t, c in zip(
self.nodes, self.thunks, self.call_times, self.call_counts
):
profile.apply_time.setdefault((self.fgraph, node), 0.0)
profile.apply_time[(self.fgraph, node)] += t
profile.apply_callcount.setdefault((self.fgraph, node), 0)
profile.apply_callcount[(self.fgraph, node)] += c
profile.apply_cimpl[node] = hasattr(thunk, "cthunk")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论