提交 bb40791b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix the type hints in aesara.printing

上级 9d360389
...@@ -16,16 +16,20 @@ import sys ...@@ -16,16 +16,20 @@ import sys
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict, List from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import numpy as np import numpy as np
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.link.utils import get_destroy_dependencies from aesara.link.utils import get_destroy_dependencies
if TYPE_CHECKING:
from aesara.graph.fg import FunctionGraph
@contextmanager @contextmanager
def extended_open(filename, mode="r"): def extended_open(filename, mode="r"):
if filename == "<stdout>": if filename == "<stdout>":
...@@ -39,13 +43,13 @@ def extended_open(filename, mode="r"): ...@@ -39,13 +43,13 @@ def extended_open(filename, mode="r"):
logger = logging.getLogger("aesara.compile.profiling") logger = logging.getLogger("aesara.compile.profiling")
aesara_imported_time = time.time() aesara_imported_time: float = time.time()
total_fct_exec_time = 0.0 total_fct_exec_time: float = 0.0
total_graph_opt_time = 0.0 total_graph_opt_time: float = 0.0
total_time_linker = 0.0 total_time_linker: float = 0.0
_atexit_print_list: List = [] _atexit_print_list: List["ProfileStats"] = []
_atexit_registered = False _atexit_registered: bool = False
def _atexit_print_fn(): def _atexit_print_fn():
...@@ -180,7 +184,6 @@ def register_profiler_printer(fct): ...@@ -180,7 +184,6 @@ def register_profiler_printer(fct):
class ProfileStats: class ProfileStats:
""" """
Object to store runtime and memory profiling information for all of Object to store runtime and memory profiling information for all of
Aesara's operations: compilation, optimization, execution. Aesara's operations: compilation, optimization, execution.
...@@ -215,72 +218,68 @@ class ProfileStats: ...@@ -215,72 +218,68 @@ class ProfileStats:
# #
show_sum: bool = True show_sum: bool = True
compile_time = 0.0 compile_time: float = 0.0
# Total time spent in body of orig_function, # Total time spent in body of orig_function,
# dominated by graph optimization and compilation of C # dominated by graph optimization and compilation of C
# #
fct_call_time = 0.0 fct_call_time: float = 0.0
# The total time spent in Function.__call__ # The total time spent in Function.__call__
# #
fct_callcount = 0 fct_callcount: int = 0
# Number of calls to Function.__call__ # Number of calls to Function.__call__
# #
vm_call_time = 0.0 vm_call_time: float = 0.0
# Total time spent in Function.vm.__call__ # Total time spent in Function.vm.__call__
# #
apply_time = None apply_time: Optional[Dict[Union["FunctionGraph", Variable], float]] = None
# dict from `(FunctionGraph, Variable)` to float runtime
#
apply_callcount = None apply_callcount: Optional[Dict[Union["FunctionGraph", Variable], int]] = None
# dict from `(FunctionGraph, Variable)` to number of executions
#
apply_cimpl = None apply_cimpl: Optional[Dict[Apply, bool]] = None
# dict from node -> bool (1 if c, 0 if py) # dict from node -> bool (1 if c, 0 if py)
# #
message = None message: Optional[str] = None
# pretty string to print in summary, to identify this output # pretty string to print in summary, to identify this output
# #
variable_shape: Dict = {} variable_shape: Dict[Variable, Any] = {}
# Variable -> shapes # Variable -> shapes
# #
variable_strides: Dict = {} variable_strides: Dict[Variable, Any] = {}
# Variable -> strides # Variable -> strides
# #
variable_offset: Dict = {} variable_offset: Dict[Variable, Any] = {}
# Variable -> offset # Variable -> offset
# #
optimizer_time = 0.0 optimizer_time: float = 0.0
# time spent optimizing graph (FunctionMaker.__init__) # time spent optimizing graph (FunctionMaker.__init__)
validate_time = 0.0 validate_time: float = 0.0
# time spent in fgraph.validate # time spent in fgraph.validate
# This is a subset of optimizer_time that is dominated by toposort() # This is a subset of optimizer_time that is dominated by toposort()
# when the destorymap feature is included. # when the destorymap feature is included.
linker_time = 0.0 linker_time: float = 0.0
# time spent linking graph (FunctionMaker.create) # time spent linking graph (FunctionMaker.create)
import_time = 0.0 import_time: float = 0.0
# time spent in importing compiled python module. # time spent in importing compiled python module.
linker_node_make_thunks = 0.0 linker_node_make_thunks: float = 0.0
linker_make_thunk_time: Dict = {} linker_make_thunk_time: Dict = {}
line_width = config.profiling__output_line_width line_width = config.profiling__output_line_width
nb_nodes = -1 nb_nodes: int = -1
# The number of nodes in the graph. We need the information separately in # The number of nodes in the graph. We need the information separately in
# case we print the profile when the function wasn't executed, or if there # case we print the profile when the function wasn't executed, or if there
# is a lazy operation in the graph. # is a lazy operation in the graph.
......
...@@ -230,6 +230,7 @@ class VM(ABC): ...@@ -230,6 +230,7 @@ class VM(ABC):
self.call_counts = [0] * len(nodes) self.call_counts = [0] * len(nodes)
self.call_times = [0] * len(nodes) self.call_times = [0] * len(nodes)
self.time_thunks = False self.time_thunks = False
self.storage_map: Optional[StorageMapType] = None
@abstractmethod @abstractmethod
def __call__(self): def __call__(self):
......
差异被折叠。
...@@ -227,10 +227,6 @@ check_untyped_defs = False ...@@ -227,10 +227,6 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.printing]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.tensor.nnet.conv3d2d] [mypy-aesara.tensor.nnet.conv3d2d]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
......
...@@ -133,9 +133,8 @@ def test_debugprint(): ...@@ -133,9 +133,8 @@ def test_debugprint():
s = StringIO() s = StringIO()
debugprint(G, file=s) debugprint(G, file=s)
# test ids=int
s = StringIO() s = StringIO()
debugprint(G, file=s, ids="int") debugprint(G, file=s, id_type="int")
s = s.getvalue() s = s.getvalue()
# The additional white space are needed! # The additional white space are needed!
reference = ( reference = (
...@@ -155,9 +154,8 @@ def test_debugprint(): ...@@ -155,9 +154,8 @@ def test_debugprint():
assert s == reference assert s == reference
# test ids=CHAR
s = StringIO() s = StringIO()
debugprint(G, file=s, ids="CHAR") debugprint(G, file=s, id_type="CHAR")
s = s.getvalue() s = s.getvalue()
# The additional white space are needed! # The additional white space are needed!
reference = ( reference = (
...@@ -177,9 +175,8 @@ def test_debugprint(): ...@@ -177,9 +175,8 @@ def test_debugprint():
assert s == reference assert s == reference
# test ids=CHAR, stop_on_name=True
s = StringIO() s = StringIO()
debugprint(G, file=s, ids="CHAR", stop_on_name=True) debugprint(G, file=s, id_type="CHAR", stop_on_name=True)
s = s.getvalue() s = s.getvalue()
# The additional white space are needed! # The additional white space are needed!
reference = ( reference = (
...@@ -197,9 +194,8 @@ def test_debugprint(): ...@@ -197,9 +194,8 @@ def test_debugprint():
assert s == reference assert s == reference
# test ids=
s = StringIO() s = StringIO()
debugprint(G, file=s, ids="") debugprint(G, file=s, id_type="")
s = s.getvalue() s = s.getvalue()
# The additional white space are needed! # The additional white space are needed!
reference = ( reference = (
...@@ -221,7 +217,7 @@ def test_debugprint(): ...@@ -221,7 +217,7 @@ def test_debugprint():
# test print_storage=True # test print_storage=True
s = StringIO() s = StringIO()
debugprint(g, file=s, ids="", print_storage=True) debugprint(g, file=s, id_type="", print_storage=True)
s = s.getvalue() s = s.getvalue()
reference = ( reference = (
"\n".join( "\n".join(
...@@ -246,7 +242,7 @@ def test_debugprint(): ...@@ -246,7 +242,7 @@ def test_debugprint():
debugprint( debugprint(
aesara.function([A, B, D, J], A + (B.dot(J) - D), mode="FAST_RUN"), aesara.function([A, B, D, J], A + (B.dot(J) - D), mode="FAST_RUN"),
file=s, file=s,
ids="", id_type="",
print_destroy_map=True, print_destroy_map=True,
print_view_map=True, print_view_map=True,
) )
...@@ -270,7 +266,7 @@ def test_debugprint(): ...@@ -270,7 +266,7 @@ def test_debugprint():
] ]
def test_debugprint_ids(): def test_debugprint_id_type():
a_at = dvector() a_at = dvector()
b_at = dmatrix() b_at = dmatrix()
...@@ -278,7 +274,7 @@ def test_debugprint_ids(): ...@@ -278,7 +274,7 @@ def test_debugprint_ids():
e_at = d_at + a_at e_at = d_at + a_at
s = StringIO() s = StringIO()
debugprint(e_at, ids="auto", file=s) debugprint(e_at, id_type="auto", file=s)
s = s.getvalue() s = s.getvalue()
exp_res = f"""Elemwise{{add,no_inplace}} [id {e_at.auto_name}] exp_res = f"""Elemwise{{add,no_inplace}} [id {e_at.auto_name}]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论