提交 bb40791b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix the type hints in aesara.printing

上级 9d360389
......@@ -16,16 +16,20 @@ import sys
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import numpy as np
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
from aesara.graph.basic import Apply, Constant, Variable
from aesara.link.utils import get_destroy_dependencies
if TYPE_CHECKING:
from aesara.graph.fg import FunctionGraph
@contextmanager
def extended_open(filename, mode="r"):
if filename == "<stdout>":
......@@ -39,13 +43,13 @@ def extended_open(filename, mode="r"):
logger = logging.getLogger("aesara.compile.profiling")
aesara_imported_time = time.time()
total_fct_exec_time = 0.0
total_graph_opt_time = 0.0
total_time_linker = 0.0
aesara_imported_time: float = time.time()
total_fct_exec_time: float = 0.0
total_graph_opt_time: float = 0.0
total_time_linker: float = 0.0
_atexit_print_list: List = []
_atexit_registered = False
_atexit_print_list: List["ProfileStats"] = []
_atexit_registered: bool = False
def _atexit_print_fn():
......@@ -180,7 +184,6 @@ def register_profiler_printer(fct):
class ProfileStats:
"""
Object to store runtime and memory profiling information for all of
Aesara's operations: compilation, optimization, execution.
......@@ -215,72 +218,68 @@ class ProfileStats:
#
show_sum: bool = True
compile_time = 0.0
compile_time: float = 0.0
# Total time spent in body of orig_function,
# dominated by graph optimization and compilation of C
#
fct_call_time = 0.0
fct_call_time: float = 0.0
# The total time spent in Function.__call__
#
fct_callcount = 0
fct_callcount: int = 0
# Number of calls to Function.__call__
#
vm_call_time = 0.0
vm_call_time: float = 0.0
# Total time spent in Function.vm.__call__
#
apply_time = None
# dict from `(FunctionGraph, Variable)` to float runtime
#
apply_time: Optional[Dict[Union["FunctionGraph", Variable], float]] = None
apply_callcount = None
# dict from `(FunctionGraph, Variable)` to number of executions
#
apply_callcount: Optional[Dict[Union["FunctionGraph", Variable], int]] = None
apply_cimpl = None
apply_cimpl: Optional[Dict[Apply, bool]] = None
# dict from node -> bool (1 if c, 0 if py)
#
message = None
message: Optional[str] = None
# pretty string to print in summary, to identify this output
#
variable_shape: Dict = {}
variable_shape: Dict[Variable, Any] = {}
# Variable -> shapes
#
variable_strides: Dict = {}
variable_strides: Dict[Variable, Any] = {}
# Variable -> strides
#
variable_offset: Dict = {}
variable_offset: Dict[Variable, Any] = {}
# Variable -> offset
#
optimizer_time = 0.0
optimizer_time: float = 0.0
# time spent optimizing graph (FunctionMaker.__init__)
validate_time = 0.0
validate_time: float = 0.0
# time spent in fgraph.validate
# This is a subset of optimizer_time that is dominated by toposort()
# when the destorymap feature is included.
linker_time = 0.0
linker_time: float = 0.0
# time spent linking graph (FunctionMaker.create)
import_time = 0.0
import_time: float = 0.0
# time spent in importing compiled python module.
linker_node_make_thunks = 0.0
linker_node_make_thunks: float = 0.0
linker_make_thunk_time: Dict = {}
line_width = config.profiling__output_line_width
nb_nodes = -1
nb_nodes: int = -1
# The number of nodes in the graph. We need the information separately in
# case we print the profile when the function wasn't executed, or if there
# is a lazy operation in the graph.
......
......@@ -230,6 +230,7 @@ class VM(ABC):
self.call_counts = [0] * len(nodes)
self.call_times = [0] * len(nodes)
self.time_thunks = False
self.storage_map: Optional[StorageMapType] = None
@abstractmethod
def __call__(self):
......
差异被折叠。
......@@ -227,10 +227,6 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.printing]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.tensor.nnet.conv3d2d]
ignore_errors = True
check_untyped_defs = False
......
......@@ -133,9 +133,8 @@ def test_debugprint():
s = StringIO()
debugprint(G, file=s)
# test ids=int
s = StringIO()
debugprint(G, file=s, ids="int")
debugprint(G, file=s, id_type="int")
s = s.getvalue()
# The additional white space are needed!
reference = (
......@@ -155,9 +154,8 @@ def test_debugprint():
assert s == reference
# test ids=CHAR
s = StringIO()
debugprint(G, file=s, ids="CHAR")
debugprint(G, file=s, id_type="CHAR")
s = s.getvalue()
# The additional white space are needed!
reference = (
......@@ -177,9 +175,8 @@ def test_debugprint():
assert s == reference
# test ids=CHAR, stop_on_name=True
s = StringIO()
debugprint(G, file=s, ids="CHAR", stop_on_name=True)
debugprint(G, file=s, id_type="CHAR", stop_on_name=True)
s = s.getvalue()
# The additional white space are needed!
reference = (
......@@ -197,9 +194,8 @@ def test_debugprint():
assert s == reference
# test ids=
s = StringIO()
debugprint(G, file=s, ids="")
debugprint(G, file=s, id_type="")
s = s.getvalue()
# The additional white space are needed!
reference = (
......@@ -221,7 +217,7 @@ def test_debugprint():
# test print_storage=True
s = StringIO()
debugprint(g, file=s, ids="", print_storage=True)
debugprint(g, file=s, id_type="", print_storage=True)
s = s.getvalue()
reference = (
"\n".join(
......@@ -246,7 +242,7 @@ def test_debugprint():
debugprint(
aesara.function([A, B, D, J], A + (B.dot(J) - D), mode="FAST_RUN"),
file=s,
ids="",
id_type="",
print_destroy_map=True,
print_view_map=True,
)
......@@ -270,7 +266,7 @@ def test_debugprint():
]
def test_debugprint_ids():
def test_debugprint_id_type():
a_at = dvector()
b_at = dmatrix()
......@@ -278,7 +274,7 @@ def test_debugprint_ids():
e_at = d_at + a_at
s = StringIO()
debugprint(e_at, ids="auto", file=s)
debugprint(e_at, id_type="auto", file=s)
s = s.getvalue()
exp_res = f"""Elemwise{{add,no_inplace}} [id {e_at.auto_name}]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论