提交 bb40791b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix the type hints in aesara.printing

上级 9d360389
...@@ -16,16 +16,20 @@ import sys ...@@ -16,16 +16,20 @@ import sys
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict, List from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import numpy as np import numpy as np
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.link.utils import get_destroy_dependencies from aesara.link.utils import get_destroy_dependencies
if TYPE_CHECKING:
from aesara.graph.fg import FunctionGraph
@contextmanager @contextmanager
def extended_open(filename, mode="r"): def extended_open(filename, mode="r"):
if filename == "<stdout>": if filename == "<stdout>":
...@@ -39,13 +43,13 @@ def extended_open(filename, mode="r"): ...@@ -39,13 +43,13 @@ def extended_open(filename, mode="r"):
logger = logging.getLogger("aesara.compile.profiling") logger = logging.getLogger("aesara.compile.profiling")
aesara_imported_time = time.time() aesara_imported_time: float = time.time()
total_fct_exec_time = 0.0 total_fct_exec_time: float = 0.0
total_graph_opt_time = 0.0 total_graph_opt_time: float = 0.0
total_time_linker = 0.0 total_time_linker: float = 0.0
_atexit_print_list: List = [] _atexit_print_list: List["ProfileStats"] = []
_atexit_registered = False _atexit_registered: bool = False
def _atexit_print_fn(): def _atexit_print_fn():
...@@ -180,7 +184,6 @@ def register_profiler_printer(fct): ...@@ -180,7 +184,6 @@ def register_profiler_printer(fct):
class ProfileStats: class ProfileStats:
""" """
Object to store runtime and memory profiling information for all of Object to store runtime and memory profiling information for all of
Aesara's operations: compilation, optimization, execution. Aesara's operations: compilation, optimization, execution.
...@@ -215,72 +218,68 @@ class ProfileStats: ...@@ -215,72 +218,68 @@ class ProfileStats:
# #
show_sum: bool = True show_sum: bool = True
compile_time = 0.0 compile_time: float = 0.0
# Total time spent in body of orig_function, # Total time spent in body of orig_function,
# dominated by graph optimization and compilation of C # dominated by graph optimization and compilation of C
# #
fct_call_time = 0.0 fct_call_time: float = 0.0
# The total time spent in Function.__call__ # The total time spent in Function.__call__
# #
fct_callcount = 0 fct_callcount: int = 0
# Number of calls to Function.__call__ # Number of calls to Function.__call__
# #
vm_call_time = 0.0 vm_call_time: float = 0.0
# Total time spent in Function.vm.__call__ # Total time spent in Function.vm.__call__
# #
apply_time = None apply_time: Optional[Dict[Union["FunctionGraph", Variable], float]] = None
# dict from `(FunctionGraph, Variable)` to float runtime
#
apply_callcount = None apply_callcount: Optional[Dict[Union["FunctionGraph", Variable], int]] = None
# dict from `(FunctionGraph, Variable)` to number of executions
#
apply_cimpl = None apply_cimpl: Optional[Dict[Apply, bool]] = None
# dict from node -> bool (1 if c, 0 if py) # dict from node -> bool (1 if c, 0 if py)
# #
message = None message: Optional[str] = None
# pretty string to print in summary, to identify this output # pretty string to print in summary, to identify this output
# #
variable_shape: Dict = {} variable_shape: Dict[Variable, Any] = {}
# Variable -> shapes # Variable -> shapes
# #
variable_strides: Dict = {} variable_strides: Dict[Variable, Any] = {}
# Variable -> strides # Variable -> strides
# #
variable_offset: Dict = {} variable_offset: Dict[Variable, Any] = {}
# Variable -> offset # Variable -> offset
# #
optimizer_time = 0.0 optimizer_time: float = 0.0
# time spent optimizing graph (FunctionMaker.__init__) # time spent optimizing graph (FunctionMaker.__init__)
validate_time = 0.0 validate_time: float = 0.0
# time spent in fgraph.validate # time spent in fgraph.validate
# This is a subset of optimizer_time that is dominated by toposort() # This is a subset of optimizer_time that is dominated by toposort()
# when the destorymap feature is included. # when the destorymap feature is included.
linker_time = 0.0 linker_time: float = 0.0
# time spent linking graph (FunctionMaker.create) # time spent linking graph (FunctionMaker.create)
import_time = 0.0 import_time: float = 0.0
# time spent in importing compiled python module. # time spent in importing compiled python module.
linker_node_make_thunks = 0.0 linker_node_make_thunks: float = 0.0
linker_make_thunk_time: Dict = {} linker_make_thunk_time: Dict = {}
line_width = config.profiling__output_line_width line_width = config.profiling__output_line_width
nb_nodes = -1 nb_nodes: int = -1
# The number of nodes in the graph. We need the information separately in # The number of nodes in the graph. We need the information separately in
# case we print the profile when the function wasn't executed, or if there # case we print the profile when the function wasn't executed, or if there
# is a lazy operation in the graph. # is a lazy operation in the graph.
......
...@@ -230,6 +230,7 @@ class VM(ABC): ...@@ -230,6 +230,7 @@ class VM(ABC):
self.call_counts = [0] * len(nodes) self.call_counts = [0] * len(nodes)
self.call_times = [0] * len(nodes) self.call_times = [0] * len(nodes)
self.time_thunks = False self.time_thunks = False
self.storage_map: Optional[StorageMapType] = None
@abstractmethod @abstractmethod
def __call__(self): def __call__(self):
......
"""Pretty-printing (pprint()), the 'Print' Op, debugprint() and pydotprint(). """Functions for printing Aesara graphs."""
They all allow different way to print a graph or the result of an Op
in a graph(Print Op)
"""
import hashlib import hashlib
import logging import logging
import os import os
import sys import sys
import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from copy import copy from copy import copy
from functools import reduce, singledispatch from functools import reduce, singledispatch
from io import IOBase, StringIO from io import StringIO
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Sequence, TextIO, Tuple, Union
import numpy as np import numpy as np
from typing_extensions import Literal
from aesara.compile import Function, SharedVariable from aesara.compile import Function, SharedVariable
from aesara.compile.io import In, Out from aesara.compile.io import In, Out
...@@ -27,6 +25,8 @@ from aesara.graph.op import HasInnerGraph, Op, StorageMapType ...@@ -27,6 +25,8 @@ from aesara.graph.op import HasInnerGraph, Op, StorageMapType
from aesara.graph.utils import Scratchpad from aesara.graph.utils import Scratchpad
IDTypesType = Literal["id", "int", "CHAR", "auto", ""]
pydot_imported = False pydot_imported = False
pydot_imported_msg = "" pydot_imported_msg = ""
try: try:
...@@ -105,61 +105,60 @@ def op_debug_information(op: Op, node: Apply) -> Dict[Apply, Dict[Variable, str] ...@@ -105,61 +105,60 @@ def op_debug_information(op: Op, node: Apply) -> Dict[Apply, Dict[Variable, str]
def debugprint( def debugprint(
obj: Union[ obj: Union[
Union[Variable, Apply, Function], List[Union[Variable, Apply, Function]] Union[Variable, Apply, Function, FunctionGraph],
Sequence[Union[Variable, Apply, Function, FunctionGraph]],
], ],
depth: int = -1, depth: int = -1,
print_type: bool = False, print_type: bool = False,
file: Optional[Union[str, IOBase]] = None, file: Optional[Union[Literal["str"], TextIO]] = None,
ids: str = "CHAR", id_type: IDTypesType = "CHAR",
stop_on_name: bool = False, stop_on_name: bool = False,
done: Optional[Dict[Apply, str]] = None, done: Optional[Dict[Union[Literal["output"], Variable, Apply], str]] = None,
print_storage: bool = False, print_storage: bool = False,
used_ids: Optional[Dict[Variable, str]] = None, used_ids: Optional[Dict[Union[Literal["output"], Variable, Apply], str]] = None,
print_op_info: bool = False, print_op_info: bool = False,
print_destroy_map: bool = False, print_destroy_map: bool = False,
print_view_map: bool = False, print_view_map: bool = False,
print_fgraph_inputs: bool = False, print_fgraph_inputs: bool = False,
) -> Union[str, IOBase]: ids: Optional[IDTypesType] = None,
r"""Print a computation graph as text to stdout or a file. ) -> Union[str, TextIO]:
r"""Print a graph as text.
Each line printed represents a Variable in the graph. Each line printed represents a `Variable` in a graph.
The indentation of lines corresponds to its depth in the symbolic graph. The indentation of lines corresponds to its depth in the symbolic graph.
The first part of the text identifies whether it is an input The first part of the text identifies whether it is an input or the output
(if a name or type is printed) or the output of some Apply (in which case of some `Apply` node.
the Op is printed). The second part of the text is an identifier of the `Variable`.
The second part of the text is an identifier of the Variable.
If print_type is True, we add a part containing the type of the Variable If a `Variable` is encountered multiple times in the depth-first search,
it is only printed recursively the first time. Later, just the `Variable`
If a Variable is encountered multiple times in the depth-first search,
it is only printed recursively the first time. Later, just the Variable
identifier is printed. identifier is printed.
If an Apply has multiple outputs, then a '.N' suffix will be appended If an `Apply` node has multiple outputs, then a ``.N`` suffix will be appended
to the Apply's identifier, to indicate which output a line corresponds to. to the `Apply` node's identifier, indicating to which output a line corresponds.
Parameters Parameters
---------- ----------
obj obj
The `Variable`, `Apply`, or `Function` instance to print (or a list The object(s) to be printed.
thereof).
depth depth
Print graph to this depth (``-1`` for unlimited). Print graph to this depth (``-1`` for unlimited).
print_type print_type
Whether to print the type of printed objects If ``True``, print the `Type`\s of each `Variable` in the graph.
file file
When `file` is extends `IOBase`, print to this file; when `file` is When `file` extends `TextIO`, print to it; when `file` is
equal to ``"str"``, return a string; when `file` is ``None``, print to equal to ``"str"``, return a string; when `file` is ``None``, print to
stdout. `sys.stdout`.
ids ids
Determines the type of identifier used for variables. Determines the type of identifier used for `Variable`\s:
- ``"id"``: print the python id value, - ``"id"``: print the python id value,
- ``"int"``: print integer character, - ``"int"``: print integer character,
- ``"CHAR"``: print capital character, - ``"CHAR"``: print capital character,
- ``"auto"``: print the ``auto_name`` value, - ``"auto"``: print the `Variable.auto_name` values,
- ``""``: don't print an identifier. - ``""``: don't print an identifier.
stop_on_name stop_on_name
When ``True``, if a node in the graph has a name, we don't print anything When ``True``, if a node in the graph has a name, we don't print
below it. anything below it.
done done
A ``dict`` where we store the ids of printed nodes. A ``dict`` where we store the ids of printed nodes.
Useful to have multiple call to `debugprint` share the same ids. Useful to have multiple call to `debugprint` share the same ids.
...@@ -188,12 +187,20 @@ def debugprint( ...@@ -188,12 +187,20 @@ def debugprint(
raise Exception("depth parameter must be an int") raise Exception("depth parameter must be an int")
if file == "str": if file == "str":
_file = StringIO() _file: Union[TextIO, StringIO] = StringIO()
elif file is None: elif file is None:
_file = sys.stdout _file = sys.stdout
else: else:
_file = file _file = file
if ids is not None:
warnings.warn(
"`ids` is deprecated; use `id_type` instead.",
DeprecationWarning,
stacklevel=2,
)
id_type = ids
if done is None: if done is None:
done = dict() done = dict()
...@@ -203,8 +210,8 @@ def debugprint( ...@@ -203,8 +210,8 @@ def debugprint(
inputs_to_print = [] inputs_to_print = []
outputs_to_print = [] outputs_to_print = []
profile_list: List[Optional[Any]] = [] profile_list: List[Optional[Any]] = []
order: List[Optional[List[Apply]]] = [] # Toposort topo_orders: List[Optional[List[Apply]]] = []
smap: List[Optional[StorageMapType]] = [] # storage_map storage_maps: List[Optional[StorageMapType]] = []
if isinstance(obj, (list, tuple, set)): if isinstance(obj, (list, tuple, set)):
lobj = obj lobj = obj
...@@ -215,43 +222,48 @@ def debugprint( ...@@ -215,43 +222,48 @@ def debugprint(
if isinstance(obj, Variable): if isinstance(obj, Variable):
outputs_to_print.append(obj) outputs_to_print.append(obj)
profile_list.append(None) profile_list.append(None)
smap.append(None) storage_maps.append(None)
order.append(None) topo_orders.append(None)
elif isinstance(obj, Apply): elif isinstance(obj, Apply):
outputs_to_print.extend(obj.outputs) outputs_to_print.extend(obj.outputs)
profile_list.extend([None for item in obj.outputs]) profile_list.extend([None for item in obj.outputs])
smap.extend([None for item in obj.outputs]) storage_maps.extend([None for item in obj.outputs])
order.extend([None for item in obj.outputs]) topo_orders.extend([None for item in obj.outputs])
elif isinstance(obj, Function): elif isinstance(obj, Function):
if print_fgraph_inputs: if print_fgraph_inputs:
inputs_to_print.extend(obj.maker.fgraph.inputs) inputs_to_print.extend(obj.maker.fgraph.inputs)
outputs_to_print.extend(obj.maker.fgraph.outputs) outputs_to_print.extend(obj.maker.fgraph.outputs)
profile_list.extend([obj.profile for item in obj.maker.fgraph.outputs]) profile_list.extend([obj.profile for item in obj.maker.fgraph.outputs])
if print_storage: if print_storage:
smap.extend([obj.vm.storage_map for item in obj.maker.fgraph.outputs]) storage_maps.extend(
[obj.vm.storage_map for item in obj.maker.fgraph.outputs]
)
else: else:
smap.extend([None for item in obj.maker.fgraph.outputs]) storage_maps.extend([None for item in obj.maker.fgraph.outputs])
topo = obj.maker.fgraph.toposort() topo = obj.maker.fgraph.toposort()
order.extend([topo for item in obj.maker.fgraph.outputs]) topo_orders.extend([topo for item in obj.maker.fgraph.outputs])
elif isinstance(obj, FunctionGraph): elif isinstance(obj, FunctionGraph):
if print_fgraph_inputs: if print_fgraph_inputs:
inputs_to_print.extend(obj.inputs) inputs_to_print.extend(obj.inputs)
outputs_to_print.extend(obj.outputs) outputs_to_print.extend(obj.outputs)
profile_list.extend([getattr(obj, "profile", None) for item in obj.outputs]) profile_list.extend([getattr(obj, "profile", None) for item in obj.outputs])
smap.extend([getattr(obj, "storage_map", None) for item in obj.outputs]) storage_maps.extend(
[getattr(obj, "storage_map", None) for item in obj.outputs]
)
topo = obj.toposort() topo = obj.toposort()
order.extend([topo for item in obj.outputs]) topo_orders.extend([topo for item in obj.outputs])
elif isinstance(obj, (int, float, np.ndarray)): elif isinstance(obj, (int, float, np.ndarray)):
print(obj, file=_file) print(obj, file=_file)
elif isinstance(obj, (In, Out)): elif isinstance(obj, (In, Out)):
outputs_to_print.append(obj.variable) outputs_to_print.append(obj.variable)
profile_list.append(None) profile_list.append(None)
smap.append(None) storage_maps.append(None)
order.append(None) topo_orders.append(None)
else: else:
raise TypeError(f"debugprint cannot print an object type {type(obj)}") raise TypeError(f"debugprint cannot print an object type {type(obj)}")
inner_graph_ops = [] inner_graph_vars: List[Variable] = []
if any(p for p in profile_list if p is not None and p.fct_callcount > 0): if any(p for p in profile_list if p is not None and p.fct_callcount > 0):
print( print(
""" """
...@@ -274,82 +286,84 @@ N.B.: ...@@ -274,82 +286,84 @@ N.B.:
file=_file, file=_file,
) )
op_information = {} op_information: Dict[Apply, Dict[Variable, str]] = {}
for r in inputs_to_print: for var in inputs_to_print:
_debugprint( _debugprint(
r, var,
prefix="-", prefix="-",
depth=depth, depth=depth,
done=done, done=done,
print_type=print_type, print_type=print_type,
file=_file, file=_file,
ids=ids, id_type=id_type,
inner_graph_ops=inner_graph_ops, inner_graph_ops=inner_graph_vars,
stop_on_name=stop_on_name, stop_on_name=stop_on_name,
used_ids=used_ids, used_ids=used_ids,
op_information=op_information, op_information=op_information,
parent_node=r.owner, parent_node=var.owner,
print_op_info=print_op_info, print_op_info=print_op_info,
print_destroy_map=print_destroy_map, print_destroy_map=print_destroy_map,
print_view_map=print_view_map, print_view_map=print_view_map,
) )
for r, p, s, o in zip(outputs_to_print, profile_list, smap, order): for var, profile, storage_map, topo_order in zip(
outputs_to_print, profile_list, storage_maps, topo_orders
):
if hasattr(r.owner, "op"): if hasattr(var.owner, "op"):
if isinstance(r.owner.op, HasInnerGraph) and r not in inner_graph_ops: if isinstance(var.owner.op, HasInnerGraph) and var not in inner_graph_vars:
inner_graph_ops.append(r) inner_graph_vars.append(var)
if print_op_info: if print_op_info:
op_information.update(op_debug_information(r.owner.op, r.owner)) op_information.update(op_debug_information(var.owner.op, var.owner))
_debugprint( _debugprint(
r, var,
depth=depth, depth=depth,
done=done, done=done,
print_type=print_type, print_type=print_type,
file=_file, file=_file,
order=o, topo_order=topo_order,
ids=ids, id_type=id_type,
inner_graph_ops=inner_graph_ops, inner_graph_ops=inner_graph_vars,
stop_on_name=stop_on_name, stop_on_name=stop_on_name,
profile=p, profile=profile,
smap=s, storage_map=storage_map,
used_ids=used_ids, used_ids=used_ids,
op_information=op_information, op_information=op_information,
parent_node=r.owner, parent_node=var.owner,
print_op_info=print_op_info, print_op_info=print_op_info,
print_destroy_map=print_destroy_map, print_destroy_map=print_destroy_map,
print_view_map=print_view_map, print_view_map=print_view_map,
) )
if len(inner_graph_ops) > 0: if len(inner_graph_vars) > 0:
print("", file=_file) print("", file=_file)
new_prefix = " >" new_prefix = " >"
new_prefix_child = " >" new_prefix_child = " >"
print("Inner graphs:", file=_file) print("Inner graphs:", file=_file)
for s in inner_graph_ops: for ig_var in inner_graph_vars:
# This is a work-around to maintain backward compatibility # This is a work-around to maintain backward compatibility
# (e.g. to only print inner graphs that have been compiled through # (e.g. to only print inner graphs that have been compiled through
# a call to `Op.prepare_node`) # a call to `Op.prepare_node`)
inner_fn = getattr(s.owner.op, "_fn", None) inner_fn = getattr(ig_var.owner.op, "_fn", None)
if inner_fn: if inner_fn:
# If the op was compiled, print the optimized version. # If the op was compiled, print the optimized version.
inner_inputs = inner_fn.maker.fgraph.inputs inner_inputs = inner_fn.maker.fgraph.inputs
inner_outputs = inner_fn.maker.fgraph.outputs inner_outputs = inner_fn.maker.fgraph.outputs
else: else:
inner_inputs = s.owner.op.inner_inputs inner_inputs = ig_var.owner.op.inner_inputs
inner_outputs = s.owner.op.inner_outputs inner_outputs = ig_var.owner.op.inner_outputs
outer_inputs = s.owner.inputs outer_inputs = ig_var.owner.inputs
if hasattr(s.owner.op, "get_oinp_iinp_iout_oout_mappings"): if hasattr(ig_var.owner.op, "get_oinp_iinp_iout_oout_mappings"):
inner_to_outer_inputs = { inner_to_outer_inputs = {
inner_inputs[i]: outer_inputs[o] inner_inputs[i]: outer_inputs[o]
for i, o in s.owner.op.get_oinp_iinp_iout_oout_mappings()[ for i, o in ig_var.owner.op.get_oinp_iinp_iout_oout_mappings()[
"outer_inp_from_inner_inp" "outer_inp_from_inner_inp"
].items() ].items()
} }
...@@ -357,23 +371,25 @@ N.B.: ...@@ -357,23 +371,25 @@ N.B.:
inner_to_outer_inputs = None inner_to_outer_inputs = None
if print_op_info: if print_op_info:
op_information.update(op_debug_information(s.owner.op, s.owner)) op_information.update(
op_debug_information(ig_var.owner.op, ig_var.owner)
)
print("", file=_file) print("", file=_file)
_debugprint( _debugprint(
s, ig_var,
depth=depth, depth=depth,
done=done, done=done,
print_type=print_type, print_type=print_type,
file=_file, file=_file,
ids=ids, id_type=id_type,
inner_graph_ops=inner_graph_ops, inner_graph_ops=inner_graph_vars,
stop_on_name=stop_on_name, stop_on_name=stop_on_name,
inner_to_outer_inputs=inner_to_outer_inputs, inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids, used_ids=used_ids,
op_information=op_information, op_information=op_information,
parent_node=s.owner, parent_node=ig_var.owner,
print_op_info=print_op_info, print_op_info=print_op_info,
print_destroy_map=print_destroy_map, print_destroy_map=print_destroy_map,
print_view_map=print_view_map, print_view_map=print_view_map,
...@@ -382,23 +398,23 @@ N.B.: ...@@ -382,23 +398,23 @@ N.B.:
if print_fgraph_inputs: if print_fgraph_inputs:
for inp in inner_inputs: for inp in inner_inputs:
_debugprint( _debugprint(
r=inp, inp,
prefix="-", prefix="-",
depth=depth, depth=depth,
done=done, done=done,
print_type=print_type, print_type=print_type,
file=_file, file=_file,
ids=ids, id_type=id_type,
stop_on_name=stop_on_name, stop_on_name=stop_on_name,
inner_graph_ops=inner_graph_ops, inner_graph_ops=inner_graph_vars,
inner_to_outer_inputs=inner_to_outer_inputs, inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids, used_ids=used_ids,
op_information=op_information, op_information=op_information,
parent_node=s.owner, parent_node=ig_var.owner,
print_op_info=print_op_info, print_op_info=print_op_info,
print_destroy_map=print_destroy_map, print_destroy_map=print_destroy_map,
print_view_map=print_view_map, print_view_map=print_view_map,
inner_graph_node=s.owner, inner_graph_node=ig_var.owner,
) )
inner_to_outer_inputs = None inner_to_outer_inputs = None
...@@ -406,34 +422,35 @@ N.B.: ...@@ -406,34 +422,35 @@ N.B.:
if ( if (
isinstance(getattr(out.owner, "op", None), HasInnerGraph) isinstance(getattr(out.owner, "op", None), HasInnerGraph)
and out not in inner_graph_ops and out not in inner_graph_vars
): ):
inner_graph_ops.append(out) inner_graph_vars.append(out)
_debugprint( _debugprint(
r=out, out,
prefix=new_prefix, prefix=new_prefix,
depth=depth, depth=depth,
done=done, done=done,
print_type=print_type, print_type=print_type,
file=_file, file=_file,
ids=ids, id_type=id_type,
stop_on_name=stop_on_name, stop_on_name=stop_on_name,
prefix_child=new_prefix_child, prefix_child=new_prefix_child,
inner_graph_ops=inner_graph_ops, inner_graph_ops=inner_graph_vars,
inner_to_outer_inputs=inner_to_outer_inputs, inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids, used_ids=used_ids,
op_information=op_information, op_information=op_information,
parent_node=s.owner, parent_node=ig_var.owner,
print_op_info=print_op_info, print_op_info=print_op_info,
print_destroy_map=print_destroy_map, print_destroy_map=print_destroy_map,
print_view_map=print_view_map, print_view_map=print_view_map,
inner_graph_node=s.owner, inner_graph_node=ig_var.owner,
) )
if file is _file: if file is _file:
return file return file
elif file == "str": elif file == "str":
assert isinstance(_file, StringIO)
return _file.getvalue() return _file.getvalue()
else: else:
_file.flush() _file.flush()
...@@ -441,96 +458,87 @@ N.B.: ...@@ -441,96 +458,87 @@ N.B.:
def _debugprint( def _debugprint(
r: Variable, var: Variable,
prefix: str = "", prefix: str = "",
depth: int = -1, depth: int = -1,
done: Optional[Dict[Apply, str]] = None, done: Optional[Dict[Union[Literal["output"], Variable, Apply], str]] = None,
print_type: bool = False, print_type: bool = False,
file: IOBase = sys.stdout, file: TextIO = sys.stdout,
print_destroy_map: bool = False, print_destroy_map: bool = False,
print_view_map: bool = False, print_view_map: bool = False,
order: Optional[List[Variable]] = None, topo_order: Optional[Sequence[Apply]] = None,
ids: str = "CHAR", id_type: IDTypesType = "CHAR",
stop_on_name: bool = False, stop_on_name: bool = False,
prefix_child: Optional[str] = None, prefix_child: Optional[str] = None,
inner_graph_ops: Optional[List[Variable]] = None, inner_graph_ops: Optional[List[Variable]] = None,
profile: Optional[ProfileStats] = None, profile: Optional[ProfileStats] = None,
inner_to_outer_inputs: Optional[Dict[Variable, Variable]] = None, inner_to_outer_inputs: Optional[Dict[Variable, Variable]] = None,
smap: Optional[StorageMapType] = None, storage_map: Optional[StorageMapType] = None,
used_ids: Optional[Dict[Variable, str]] = None, used_ids: Optional[Dict[Union[Literal["output"], Variable, Apply], str]] = None,
op_information: Optional[Dict[Apply, Dict[Variable, str]]] = None, op_information: Optional[Dict[Apply, Dict[Variable, str]]] = None,
parent_node: Optional[Apply] = None, parent_node: Optional[Apply] = None,
print_op_info: bool = False, print_op_info: bool = False,
inner_graph_node: Optional[Apply] = None, inner_graph_node: Optional[Apply] = None,
) -> IOBase: ) -> TextIO:
r"""Print the graph leading to `r`. r"""Print the graph represented by `var`.
Parameters Parameters
---------- ----------
r var
A `Variable` instance. A `Variable` instance.
prefix prefix
Prefix to each line (typically some number of spaces). Prefix to each line (typically some number of spaces).
depth depth
Print graph to this depth (``-1`` for unlimited). Print graph to this depth (``-1`` for unlimited).
done done
A ``dict`` of `Apply` instances that have already been printed and See `debugprint`.
their associated printed ids.
Internal. Used to pass information when recursing.
print_type print_type
Whether to print the `Variable`'s type. See `debugprint`.
file file
File-like object to which to print. File-like object to which to print.
print_destroy_map print_destroy_map
Whether to print `Op` ``destroy_map``\s. Whether to print the `Variable`'s type.
print_view_map print_view_map
Whether to print `Op` ``view_map``\s. Whether to print `Op` ``destroy_map``\s.
order topo_order
If not empty will print the index in the toposort. If not empty will print the index in the toposort.
ids id_type
Determines the type of identifier used for variables. See `debugprint`.
- ``"id"``: print the python id value,
- ``"int"``: print integer character,
- ``"CHAR"``: print capital character,
- ``"auto"``: print the ``auto_name`` value,
- ``""``: don't print an identifier.
stop_on_name stop_on_name
When ``True``, if a node in the graph has a name, we don't print anything Whether to print `Op` ``view_map``\s.
below it.
inner_graph_ops inner_graph_ops
A list of `Op`\s with inner graphs. A list of `Op`\s with inner graphs.
inner_to_outer_inputs inner_to_outer_inputs
A dictionary mapping an `Op`'s inner-inputs to its outer-inputs. A dictionary mapping an `Op`'s inner-inputs to its outer-inputs.
smap storage_map
``None`` or the ``storage_map`` when printing an Aesara function. ``None`` or the storage map (e.g. when printing an Aesara function).
used_ids used_ids
A map between nodes and their printed ids. See `debugprint`.
It wasn't always printed, but at least a reference to it was printed.
Internal. Used to pass information when recursing.
op_information op_information
Extra `Op`-level information to be added to variable print-outs. Extra `Op`-level information to be added to variable print-outs.
parent_node parent_node
The parent node of `r`. The parent node of `var`.
print_op_info print_op_info
Print extra information provided by the relevant `Op`\s. For example, See `debugprint`.
print the tap information for `Scan` inputs and outputs.
inner_graph_node inner_graph_node
The inner-graph node in which `r` is contained. The inner-graph node in which `var` is contained.
""" """
if depth == 0: if depth == 0:
return file return file
if order is None: if topo_order is None:
order = [] topo_order = []
if done is None: if done is None:
done = dict() _done = dict()
else:
_done = done
if inner_graph_ops is None: if inner_graph_ops is None:
inner_graph_ops = [] inner_graph_ops = []
if print_type: if print_type:
type_str = f" <{r.type}>" type_str = f" <{var.type}>"
else: else:
type_str = "" type_str = ""
...@@ -538,92 +546,100 @@ def _debugprint( ...@@ -538,92 +546,100 @@ def _debugprint(
prefix_child = prefix prefix_child = prefix
if used_ids is None: if used_ids is None:
used_ids = dict() _used_ids = dict()
else:
_used_ids = used_ids
if op_information is None: if op_information is None:
op_information = {} op_information = {}
def get_id_str(obj, get_printed=True) -> str: def get_id_str(
obj: Union[Literal["output"], Apply, Variable], get_printed: bool = True
) -> str:
id_str: str = "" id_str: str = ""
if obj in used_ids: if obj in _used_ids:
id_str = used_ids[obj] id_str = _used_ids[obj]
elif obj == "output": elif obj == "output":
id_str = "output" id_str = "output"
elif ids == "id": elif id_type == "id":
id_str = f"[id {id(r)}]" id_str = f"[id {id(var)}]"
elif ids == "int": elif id_type == "int":
id_str = f"[id {len(used_ids)}]" id_str = f"[id {len(_used_ids)}]"
elif ids == "CHAR": elif id_type == "CHAR":
id_str = f"[id {char_from_number(len(used_ids))}]" id_str = f"[id {char_from_number(len(_used_ids))}]"
elif ids == "auto": elif id_type == "auto":
id_str = f"[id {r.auto_name}]" id_str = f"[id {var.auto_name}]"
elif ids == "": elif id_type == "":
id_str = "" id_str = ""
if get_printed: if get_printed:
done[obj] = id_str _done[obj] = id_str
used_ids[obj] = id_str _used_ids[obj] = id_str
return id_str return id_str
if hasattr(r.owner, "op"): if var.owner:
# This variable is the output of a computation, so just print out the # This variable is the output of a computation, so just print out the
# `Apply` node # `Apply` node
a = r.owner node = var.owner
r_name = getattr(r, "name", "") var_name = getattr(var, "name", "")
if r_name is None: if var_name is None:
r_name = "" var_name = ""
if r_name: if var_name:
r_name = f" '{r_name}'" var_name = f" '{var_name}'"
if print_destroy_map and r.owner.op.destroy_map: if print_destroy_map and node.op.destroy_map:
destroy_map_str = f" d={r.owner.op.destroy_map}" destroy_map_str = f" d={node.op.destroy_map}"
else: else:
destroy_map_str = "" destroy_map_str = ""
if print_view_map and r.owner.op.view_map: if print_view_map and node.op.view_map:
view_map_str = f" v={r.owner.op.view_map}" view_map_str = f" v={node.op.view_map}"
else: else:
view_map_str = "" view_map_str = ""
if order: if topo_order:
o = f" {order.index(r.owner)}" o = f" {topo_order.index(node)}"
else: else:
o = "" o = ""
already_done = a in done already_done = node in _done
id_str = get_id_str(a) id_str = get_id_str(node)
if len(a.outputs) == 1: if len(node.outputs) == 1:
idx = "" output_idx = ""
else: else:
idx = f".{a.outputs.index(r)}" output_idx = f".{node.outputs.index(var)}"
if id_str: if id_str:
id_str = f" {id_str}" id_str = f" {id_str}"
if smap and a.outputs[0] in smap: if storage_map and node.outputs[0] in storage_map:
data = f" {smap[a.outputs[0]]}" data = f" {storage_map[node.outputs[0]]}"
else: else:
data = "" data = ""
var_output = f"{prefix}{a.op}{idx}{id_str}{type_str}{r_name}{destroy_map_str}{view_map_str}{o}{data}" var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}"
if print_op_info and r.owner not in op_information: if print_op_info and node not in op_information:
op_information.update(op_debug_information(r.owner.op, r.owner)) op_information.update(op_debug_information(node.op, node))
node_info = op_information.get(parent_node) or op_information.get(r.owner) node_info = (
if node_info and r in node_info: parent_node and op_information.get(parent_node)
var_output = f"{var_output} ({node_info[r]})" ) or op_information.get(node)
if node_info and var in node_info:
var_output = f"{var_output} ({node_info[var]})"
if profile is None or a not in profile.apply_time: if profile is None:
print(var_output, file=file) print(var_output, file=file)
else: elif profile.apply_time and node not in profile.apply_time:
op_time = profile.apply_time[a] print(var_output, file=file)
elif profile.apply_time and node in profile.apply_time:
op_time = profile.apply_time[node]
op_time_percent = (op_time / profile.fct_call_time) * 100 op_time_percent = (op_time / profile.fct_call_time) * 100
tot_time_dict = profile.compute_total_times() tot_time_dict = profile.compute_total_times()
tot_time = tot_time_dict[a] tot_time = tot_time_dict[node]
tot_time_percent = (tot_time_dict[a] / profile.fct_call_time) * 100 tot_time_percent = (tot_time_dict[node] / profile.fct_call_time) * 100
print( print(
"%s --> %8.2es %4.1f%% %8.2es %4.1f%%" "%s --> %8.2es %4.1f%% %8.2es %4.1f%%"
...@@ -638,40 +654,40 @@ def _debugprint( ...@@ -638,40 +654,40 @@ def _debugprint(
) )
if not already_done and ( if not already_done and (
not stop_on_name or not (hasattr(r, "name") and r.name is not None) not stop_on_name or not (hasattr(var, "name") and var.name is not None)
): ):
new_prefix = prefix_child + " |" new_prefix = prefix_child + " |"
new_prefix_child = prefix_child + " |" new_prefix_child = prefix_child + " |"
for idx, i in enumerate(a.inputs): for in_idx, in_var in enumerate(node.inputs):
if idx == len(a.inputs) - 1: if in_idx == len(node.inputs) - 1:
new_prefix_child = prefix_child + " " new_prefix_child = prefix_child + " "
if hasattr(i, "owner") and hasattr(i.owner, "op"): if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"):
if ( if (
isinstance(i.owner.op, HasInnerGraph) isinstance(in_var.owner.op, HasInnerGraph)
and i not in inner_graph_ops and in_var not in inner_graph_ops
): ):
inner_graph_ops.append(i) inner_graph_ops.append(in_var)
_debugprint( _debugprint(
i, in_var,
new_prefix, new_prefix,
depth=depth - 1, depth=depth - 1,
done=done, done=_done,
print_type=print_type, print_type=print_type,
file=file, file=file,
order=order, topo_order=topo_order,
ids=ids, id_type=id_type,
stop_on_name=stop_on_name, stop_on_name=stop_on_name,
prefix_child=new_prefix_child, prefix_child=new_prefix_child,
inner_graph_ops=inner_graph_ops, inner_graph_ops=inner_graph_ops,
profile=profile, profile=profile,
inner_to_outer_inputs=inner_to_outer_inputs, inner_to_outer_inputs=inner_to_outer_inputs,
smap=smap, storage_map=storage_map,
used_ids=used_ids, used_ids=_used_ids,
op_information=op_information, op_information=op_information,
parent_node=a, parent_node=node,
print_op_info=print_op_info, print_op_info=print_op_info,
print_destroy_map=print_destroy_map, print_destroy_map=print_destroy_map,
print_view_map=print_view_map, print_view_map=print_view_map,
...@@ -679,38 +695,38 @@ def _debugprint( ...@@ -679,38 +695,38 @@ def _debugprint(
) )
else: else:
id_str = get_id_str(r) id_str = get_id_str(var)
if id_str: if id_str:
id_str = f" {id_str}" id_str = f" {id_str}"
if smap and r in smap: if storage_map and var in storage_map:
data = f" {smap[r]}" data = f" {storage_map[var]}"
else: else:
data = "" data = ""
var_output = f"{prefix}{r}{id_str}{type_str}{data}" var_output = f"{prefix}{var}{id_str}{type_str}{data}"
if print_op_info and r.owner and r.owner not in op_information: if print_op_info and var.owner and var.owner not in op_information:
op_information.update(op_debug_information(r.owner.op, r.owner)) op_information.update(op_debug_information(var.owner.op, var.owner))
if inner_to_outer_inputs is not None and r in inner_to_outer_inputs: if inner_to_outer_inputs is not None and var in inner_to_outer_inputs:
outer_r = inner_to_outer_inputs[r] outer_var = inner_to_outer_inputs[var]
if outer_r.owner: if outer_var.owner:
outer_id_str = get_id_str(outer_r.owner) outer_id_str = get_id_str(outer_var.owner)
else: else:
outer_id_str = get_id_str(outer_r) outer_id_str = get_id_str(outer_var)
var_output = f"{var_output} -> {outer_id_str}" var_output = f"{var_output} -> {outer_id_str}"
# TODO: This entire approach will only print `Op` info for two levels # TODO: This entire approach will only print `Op` info for two levels
# of nesting. # of nesting.
for node in dict.fromkeys([inner_graph_node, parent_node, r.owner]): for node in dict.fromkeys([inner_graph_node, parent_node, var.owner]):
node_info = op_information.get(node) node_info = op_information.get(node)
if node_info and r in node_info: if node_info and var in node_info:
var_output = f"{var_output} ({node_info[r]})" var_output = f"{var_output} ({node_info[var]})"
print(var_output, file=file) print(var_output, file=file)
...@@ -1827,18 +1843,20 @@ def hex_digest(x): ...@@ -1827,18 +1843,20 @@ def hex_digest(x):
def get_node_by_id( def get_node_by_id(
graphs: Iterable[Variable], target_var_id: str, ids: str = "CHAR" graphs: Union[Variable, Sequence[Variable], Function, FunctionGraph],
) -> Optional[Union[Variable, Apply]]: target_var_id: str,
id_types: IDTypesType = "CHAR",
) -> Optional[Union[Literal["output"], Variable, Apply]]:
r"""Get `Apply` nodes or `Variable`\s in a graph using their `debugprint` IDs. r"""Get `Apply` nodes or `Variable`\s in a graph using their `debugprint` IDs.
Parameters Parameters
---------- ----------
graphs: graphs
The graph, or graphs, to search. The graph, or graphs, to search.
target_var_id: target_var_id
The name to search for. The name to search for.
ids: id_types
The ID scheme to use (see `debugprint.`). The ID scheme to use (see `debugprint`).
Returns Returns
------- -------
...@@ -1847,12 +1865,9 @@ def get_node_by_id( ...@@ -1847,12 +1865,9 @@ def get_node_by_id(
""" """
from aesara.printing import debugprint from aesara.printing import debugprint
if isinstance(graphs, Variable): used_ids: Dict[Union[Literal["output"], Variable, Apply], str] = {}
graphs = (graphs,)
used_ids: Dict[Variable, str] = {}
_ = debugprint(graphs, file="str", used_ids=used_ids, ids=ids) _ = debugprint(graphs, file="str", used_ids=used_ids, id_type=id_types)
id_to_node = {v: k for k, v in used_ids.items()} id_to_node = {v: k for k, v in used_ids.items()}
......
...@@ -227,10 +227,6 @@ check_untyped_defs = False ...@@ -227,10 +227,6 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.printing]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.tensor.nnet.conv3d2d] [mypy-aesara.tensor.nnet.conv3d2d]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
......
...@@ -133,9 +133,8 @@ def test_debugprint(): ...@@ -133,9 +133,8 @@ def test_debugprint():
s = StringIO() s = StringIO()
debugprint(G, file=s) debugprint(G, file=s)
# test ids=int
s = StringIO() s = StringIO()
debugprint(G, file=s, ids="int") debugprint(G, file=s, id_type="int")
s = s.getvalue() s = s.getvalue()
# The additional white space are needed! # The additional white space are needed!
reference = ( reference = (
...@@ -155,9 +154,8 @@ def test_debugprint(): ...@@ -155,9 +154,8 @@ def test_debugprint():
assert s == reference assert s == reference
# test ids=CHAR
s = StringIO() s = StringIO()
debugprint(G, file=s, ids="CHAR") debugprint(G, file=s, id_type="CHAR")
s = s.getvalue() s = s.getvalue()
# The additional white space are needed! # The additional white space are needed!
reference = ( reference = (
...@@ -177,9 +175,8 @@ def test_debugprint(): ...@@ -177,9 +175,8 @@ def test_debugprint():
assert s == reference assert s == reference
# test ids=CHAR, stop_on_name=True
s = StringIO() s = StringIO()
debugprint(G, file=s, ids="CHAR", stop_on_name=True) debugprint(G, file=s, id_type="CHAR", stop_on_name=True)
s = s.getvalue() s = s.getvalue()
# The additional white space are needed! # The additional white space are needed!
reference = ( reference = (
...@@ -197,9 +194,8 @@ def test_debugprint(): ...@@ -197,9 +194,8 @@ def test_debugprint():
assert s == reference assert s == reference
# test ids=
s = StringIO() s = StringIO()
debugprint(G, file=s, ids="") debugprint(G, file=s, id_type="")
s = s.getvalue() s = s.getvalue()
# The additional white space are needed! # The additional white space are needed!
reference = ( reference = (
...@@ -221,7 +217,7 @@ def test_debugprint(): ...@@ -221,7 +217,7 @@ def test_debugprint():
# test print_storage=True # test print_storage=True
s = StringIO() s = StringIO()
debugprint(g, file=s, ids="", print_storage=True) debugprint(g, file=s, id_type="", print_storage=True)
s = s.getvalue() s = s.getvalue()
reference = ( reference = (
"\n".join( "\n".join(
...@@ -246,7 +242,7 @@ def test_debugprint(): ...@@ -246,7 +242,7 @@ def test_debugprint():
debugprint( debugprint(
aesara.function([A, B, D, J], A + (B.dot(J) - D), mode="FAST_RUN"), aesara.function([A, B, D, J], A + (B.dot(J) - D), mode="FAST_RUN"),
file=s, file=s,
ids="", id_type="",
print_destroy_map=True, print_destroy_map=True,
print_view_map=True, print_view_map=True,
) )
...@@ -270,7 +266,7 @@ def test_debugprint(): ...@@ -270,7 +266,7 @@ def test_debugprint():
] ]
def test_debugprint_ids(): def test_debugprint_id_type():
a_at = dvector() a_at = dvector()
b_at = dmatrix() b_at = dmatrix()
...@@ -278,7 +274,7 @@ def test_debugprint_ids(): ...@@ -278,7 +274,7 @@ def test_debugprint_ids():
e_at = d_at + a_at e_at = d_at + a_at
s = StringIO() s = StringIO()
debugprint(e_at, ids="auto", file=s) debugprint(e_at, id_type="auto", file=s)
s = s.getvalue() s = s.getvalue()
exp_res = f"""Elemwise{{add,no_inplace}} [id {e_at.auto_name}] exp_res = f"""Elemwise{{add,no_inplace}} [id {e_at.auto_name}]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论