提交 bb40791b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix the type hints in aesara.printing

上级 9d360389
......@@ -16,16 +16,20 @@ import sys
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import numpy as np
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
from aesara.graph.basic import Apply, Constant, Variable
from aesara.link.utils import get_destroy_dependencies
if TYPE_CHECKING:
from aesara.graph.fg import FunctionGraph
@contextmanager
def extended_open(filename, mode="r"):
if filename == "<stdout>":
......@@ -39,13 +43,13 @@ def extended_open(filename, mode="r"):
logger = logging.getLogger("aesara.compile.profiling")
aesara_imported_time = time.time()
total_fct_exec_time = 0.0
total_graph_opt_time = 0.0
total_time_linker = 0.0
aesara_imported_time: float = time.time()
total_fct_exec_time: float = 0.0
total_graph_opt_time: float = 0.0
total_time_linker: float = 0.0
_atexit_print_list: List = []
_atexit_registered = False
_atexit_print_list: List["ProfileStats"] = []
_atexit_registered: bool = False
def _atexit_print_fn():
......@@ -180,7 +184,6 @@ def register_profiler_printer(fct):
class ProfileStats:
"""
Object to store runtime and memory profiling information for all of
Aesara's operations: compilation, optimization, execution.
......@@ -215,72 +218,68 @@ class ProfileStats:
#
show_sum: bool = True
compile_time = 0.0
compile_time: float = 0.0
# Total time spent in body of orig_function,
# dominated by graph optimization and compilation of C
#
fct_call_time = 0.0
fct_call_time: float = 0.0
# The total time spent in Function.__call__
#
fct_callcount = 0
fct_callcount: int = 0
# Number of calls to Function.__call__
#
vm_call_time = 0.0
vm_call_time: float = 0.0
# Total time spent in Function.vm.__call__
#
apply_time = None
# dict from `(FunctionGraph, Variable)` to float runtime
#
apply_time: Optional[Dict[Union["FunctionGraph", Variable], float]] = None
apply_callcount = None
# dict from `(FunctionGraph, Variable)` to number of executions
#
apply_callcount: Optional[Dict[Union["FunctionGraph", Variable], int]] = None
apply_cimpl = None
apply_cimpl: Optional[Dict[Apply, bool]] = None
# dict from node -> bool (1 if c, 0 if py)
#
message = None
message: Optional[str] = None
# pretty string to print in summary, to identify this output
#
variable_shape: Dict = {}
variable_shape: Dict[Variable, Any] = {}
# Variable -> shapes
#
variable_strides: Dict = {}
variable_strides: Dict[Variable, Any] = {}
# Variable -> strides
#
variable_offset: Dict = {}
variable_offset: Dict[Variable, Any] = {}
# Variable -> offset
#
optimizer_time = 0.0
optimizer_time: float = 0.0
# time spent optimizing graph (FunctionMaker.__init__)
validate_time = 0.0
validate_time: float = 0.0
# time spent in fgraph.validate
# This is a subset of optimizer_time that is dominated by toposort()
# when the destorymap feature is included.
linker_time = 0.0
linker_time: float = 0.0
# time spent linking graph (FunctionMaker.create)
import_time = 0.0
import_time: float = 0.0
# time spent in importing compiled python module.
linker_node_make_thunks = 0.0
linker_node_make_thunks: float = 0.0
linker_make_thunk_time: Dict = {}
line_width = config.profiling__output_line_width
nb_nodes = -1
nb_nodes: int = -1
# The number of nodes in the graph. We need the information separately in
# case we print the profile when the function wasn't executed, or if there
# is a lazy operation in the graph.
......
......@@ -230,6 +230,7 @@ class VM(ABC):
self.call_counts = [0] * len(nodes)
self.call_times = [0] * len(nodes)
self.time_thunks = False
self.storage_map: Optional[StorageMapType] = None
@abstractmethod
def __call__(self):
......
"""Pretty-printing (pprint()), the 'Print' Op, debugprint() and pydotprint().
They all allow different way to print a graph or the result of an Op
in a graph(Print Op)
"""
"""Functions for printing Aesara graphs."""
import hashlib
import logging
import os
import sys
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from copy import copy
from functools import reduce, singledispatch
from io import IOBase, StringIO
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from io import StringIO
from typing import Any, Callable, Dict, List, Optional, Sequence, TextIO, Tuple, Union
import numpy as np
from typing_extensions import Literal
from aesara.compile import Function, SharedVariable
from aesara.compile.io import In, Out
......@@ -27,6 +25,8 @@ from aesara.graph.op import HasInnerGraph, Op, StorageMapType
from aesara.graph.utils import Scratchpad
IDTypesType = Literal["id", "int", "CHAR", "auto", ""]
pydot_imported = False
pydot_imported_msg = ""
try:
......@@ -105,61 +105,60 @@ def op_debug_information(op: Op, node: Apply) -> Dict[Apply, Dict[Variable, str]
def debugprint(
obj: Union[
Union[Variable, Apply, Function], List[Union[Variable, Apply, Function]]
Union[Variable, Apply, Function, FunctionGraph],
Sequence[Union[Variable, Apply, Function, FunctionGraph]],
],
depth: int = -1,
print_type: bool = False,
file: Optional[Union[str, IOBase]] = None,
ids: str = "CHAR",
file: Optional[Union[Literal["str"], TextIO]] = None,
id_type: IDTypesType = "CHAR",
stop_on_name: bool = False,
done: Optional[Dict[Apply, str]] = None,
done: Optional[Dict[Union[Literal["output"], Variable, Apply], str]] = None,
print_storage: bool = False,
used_ids: Optional[Dict[Variable, str]] = None,
used_ids: Optional[Dict[Union[Literal["output"], Variable, Apply], str]] = None,
print_op_info: bool = False,
print_destroy_map: bool = False,
print_view_map: bool = False,
print_fgraph_inputs: bool = False,
) -> Union[str, IOBase]:
r"""Print a computation graph as text to stdout or a file.
ids: Optional[IDTypesType] = None,
) -> Union[str, TextIO]:
r"""Print a graph as text.
Each line printed represents a Variable in the graph.
Each line printed represents a `Variable` in a graph.
The indentation of lines corresponds to its depth in the symbolic graph.
The first part of the text identifies whether it is an input
(if a name or type is printed) or the output of some Apply (in which case
the Op is printed).
The second part of the text is an identifier of the Variable.
If print_type is True, we add a part containing the type of the Variable
If a Variable is encountered multiple times in the depth-first search,
it is only printed recursively the first time. Later, just the Variable
The first part of the text identifies whether it is an input or the output
of some `Apply` node.
The second part of the text is an identifier of the `Variable`.
If a `Variable` is encountered multiple times in the depth-first search,
it is only printed recursively the first time. Later, just the `Variable`
identifier is printed.
If an Apply has multiple outputs, then a '.N' suffix will be appended
to the Apply's identifier, to indicate which output a line corresponds to.
If an `Apply` node has multiple outputs, then a ``.N`` suffix will be appended
to the `Apply` node's identifier, indicating to which output a line corresponds.
Parameters
----------
obj
The `Variable`, `Apply`, or `Function` instance to print (or a list
thereof).
The object(s) to be printed.
depth
Print graph to this depth (``-1`` for unlimited).
print_type
Whether to print the type of printed objects
If ``True``, print the `Type`\s of each `Variable` in the graph.
file
When `file` is extends `IOBase`, print to this file; when `file` is
When `file` extends `TextIO`, print to it; when `file` is
equal to ``"str"``, return a string; when `file` is ``None``, print to
stdout.
`sys.stdout`.
ids
Determines the type of identifier used for variables.
Determines the type of identifier used for `Variable`\s:
- ``"id"``: print the python id value,
- ``"int"``: print integer character,
- ``"CHAR"``: print capital character,
- ``"auto"``: print the ``auto_name`` value,
- ``"auto"``: print the `Variable.auto_name` values,
- ``""``: don't print an identifier.
stop_on_name
When ``True``, if a node in the graph has a name, we don't print anything
below it.
When ``True``, if a node in the graph has a name, we don't print
anything below it.
done
A ``dict`` where we store the ids of printed nodes.
Useful to have multiple call to `debugprint` share the same ids.
......@@ -188,12 +187,20 @@ def debugprint(
raise Exception("depth parameter must be an int")
if file == "str":
_file = StringIO()
_file: Union[TextIO, StringIO] = StringIO()
elif file is None:
_file = sys.stdout
else:
_file = file
if ids is not None:
warnings.warn(
"`ids` is deprecated; use `id_type` instead.",
DeprecationWarning,
stacklevel=2,
)
id_type = ids
if done is None:
done = dict()
......@@ -203,8 +210,8 @@ def debugprint(
inputs_to_print = []
outputs_to_print = []
profile_list: List[Optional[Any]] = []
order: List[Optional[List[Apply]]] = [] # Toposort
smap: List[Optional[StorageMapType]] = [] # storage_map
topo_orders: List[Optional[List[Apply]]] = []
storage_maps: List[Optional[StorageMapType]] = []
if isinstance(obj, (list, tuple, set)):
lobj = obj
......@@ -215,43 +222,48 @@ def debugprint(
if isinstance(obj, Variable):
outputs_to_print.append(obj)
profile_list.append(None)
smap.append(None)
order.append(None)
storage_maps.append(None)
topo_orders.append(None)
elif isinstance(obj, Apply):
outputs_to_print.extend(obj.outputs)
profile_list.extend([None for item in obj.outputs])
smap.extend([None for item in obj.outputs])
order.extend([None for item in obj.outputs])
storage_maps.extend([None for item in obj.outputs])
topo_orders.extend([None for item in obj.outputs])
elif isinstance(obj, Function):
if print_fgraph_inputs:
inputs_to_print.extend(obj.maker.fgraph.inputs)
outputs_to_print.extend(obj.maker.fgraph.outputs)
profile_list.extend([obj.profile for item in obj.maker.fgraph.outputs])
if print_storage:
smap.extend([obj.vm.storage_map for item in obj.maker.fgraph.outputs])
storage_maps.extend(
[obj.vm.storage_map for item in obj.maker.fgraph.outputs]
)
else:
smap.extend([None for item in obj.maker.fgraph.outputs])
storage_maps.extend([None for item in obj.maker.fgraph.outputs])
topo = obj.maker.fgraph.toposort()
order.extend([topo for item in obj.maker.fgraph.outputs])
topo_orders.extend([topo for item in obj.maker.fgraph.outputs])
elif isinstance(obj, FunctionGraph):
if print_fgraph_inputs:
inputs_to_print.extend(obj.inputs)
outputs_to_print.extend(obj.outputs)
profile_list.extend([getattr(obj, "profile", None) for item in obj.outputs])
smap.extend([getattr(obj, "storage_map", None) for item in obj.outputs])
storage_maps.extend(
[getattr(obj, "storage_map", None) for item in obj.outputs]
)
topo = obj.toposort()
order.extend([topo for item in obj.outputs])
topo_orders.extend([topo for item in obj.outputs])
elif isinstance(obj, (int, float, np.ndarray)):
print(obj, file=_file)
elif isinstance(obj, (In, Out)):
outputs_to_print.append(obj.variable)
profile_list.append(None)
smap.append(None)
order.append(None)
storage_maps.append(None)
topo_orders.append(None)
else:
raise TypeError(f"debugprint cannot print an object type {type(obj)}")
inner_graph_ops = []
inner_graph_vars: List[Variable] = []
if any(p for p in profile_list if p is not None and p.fct_callcount > 0):
print(
"""
......@@ -274,82 +286,84 @@ N.B.:
file=_file,
)
op_information = {}
op_information: Dict[Apply, Dict[Variable, str]] = {}
for r in inputs_to_print:
for var in inputs_to_print:
_debugprint(
r,
var,
prefix="-",
depth=depth,
done=done,
print_type=print_type,
file=_file,
ids=ids,
inner_graph_ops=inner_graph_ops,
id_type=id_type,
inner_graph_ops=inner_graph_vars,
stop_on_name=stop_on_name,
used_ids=used_ids,
op_information=op_information,
parent_node=r.owner,
parent_node=var.owner,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
)
for r, p, s, o in zip(outputs_to_print, profile_list, smap, order):
for var, profile, storage_map, topo_order in zip(
outputs_to_print, profile_list, storage_maps, topo_orders
):
if hasattr(r.owner, "op"):
if isinstance(r.owner.op, HasInnerGraph) and r not in inner_graph_ops:
inner_graph_ops.append(r)
if hasattr(var.owner, "op"):
if isinstance(var.owner.op, HasInnerGraph) and var not in inner_graph_vars:
inner_graph_vars.append(var)
if print_op_info:
op_information.update(op_debug_information(r.owner.op, r.owner))
op_information.update(op_debug_information(var.owner.op, var.owner))
_debugprint(
r,
var,
depth=depth,
done=done,
print_type=print_type,
file=_file,
order=o,
ids=ids,
inner_graph_ops=inner_graph_ops,
topo_order=topo_order,
id_type=id_type,
inner_graph_ops=inner_graph_vars,
stop_on_name=stop_on_name,
profile=p,
smap=s,
profile=profile,
storage_map=storage_map,
used_ids=used_ids,
op_information=op_information,
parent_node=r.owner,
parent_node=var.owner,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
)
if len(inner_graph_ops) > 0:
if len(inner_graph_vars) > 0:
print("", file=_file)
new_prefix = " >"
new_prefix_child = " >"
print("Inner graphs:", file=_file)
for s in inner_graph_ops:
for ig_var in inner_graph_vars:
# This is a work-around to maintain backward compatibility
# (e.g. to only print inner graphs that have been compiled through
# a call to `Op.prepare_node`)
inner_fn = getattr(s.owner.op, "_fn", None)
inner_fn = getattr(ig_var.owner.op, "_fn", None)
if inner_fn:
# If the op was compiled, print the optimized version.
inner_inputs = inner_fn.maker.fgraph.inputs
inner_outputs = inner_fn.maker.fgraph.outputs
else:
inner_inputs = s.owner.op.inner_inputs
inner_outputs = s.owner.op.inner_outputs
inner_inputs = ig_var.owner.op.inner_inputs
inner_outputs = ig_var.owner.op.inner_outputs
outer_inputs = s.owner.inputs
outer_inputs = ig_var.owner.inputs
if hasattr(s.owner.op, "get_oinp_iinp_iout_oout_mappings"):
if hasattr(ig_var.owner.op, "get_oinp_iinp_iout_oout_mappings"):
inner_to_outer_inputs = {
inner_inputs[i]: outer_inputs[o]
for i, o in s.owner.op.get_oinp_iinp_iout_oout_mappings()[
for i, o in ig_var.owner.op.get_oinp_iinp_iout_oout_mappings()[
"outer_inp_from_inner_inp"
].items()
}
......@@ -357,23 +371,25 @@ N.B.:
inner_to_outer_inputs = None
if print_op_info:
op_information.update(op_debug_information(s.owner.op, s.owner))
op_information.update(
op_debug_information(ig_var.owner.op, ig_var.owner)
)
print("", file=_file)
_debugprint(
s,
ig_var,
depth=depth,
done=done,
print_type=print_type,
file=_file,
ids=ids,
inner_graph_ops=inner_graph_ops,
id_type=id_type,
inner_graph_ops=inner_graph_vars,
stop_on_name=stop_on_name,
inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids,
op_information=op_information,
parent_node=s.owner,
parent_node=ig_var.owner,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
......@@ -382,23 +398,23 @@ N.B.:
if print_fgraph_inputs:
for inp in inner_inputs:
_debugprint(
r=inp,
inp,
prefix="-",
depth=depth,
done=done,
print_type=print_type,
file=_file,
ids=ids,
id_type=id_type,
stop_on_name=stop_on_name,
inner_graph_ops=inner_graph_ops,
inner_graph_ops=inner_graph_vars,
inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids,
op_information=op_information,
parent_node=s.owner,
parent_node=ig_var.owner,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
inner_graph_node=s.owner,
inner_graph_node=ig_var.owner,
)
inner_to_outer_inputs = None
......@@ -406,34 +422,35 @@ N.B.:
if (
isinstance(getattr(out.owner, "op", None), HasInnerGraph)
and out not in inner_graph_ops
and out not in inner_graph_vars
):
inner_graph_ops.append(out)
inner_graph_vars.append(out)
_debugprint(
r=out,
out,
prefix=new_prefix,
depth=depth,
done=done,
print_type=print_type,
file=_file,
ids=ids,
id_type=id_type,
stop_on_name=stop_on_name,
prefix_child=new_prefix_child,
inner_graph_ops=inner_graph_ops,
inner_graph_ops=inner_graph_vars,
inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids,
op_information=op_information,
parent_node=s.owner,
parent_node=ig_var.owner,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
inner_graph_node=s.owner,
inner_graph_node=ig_var.owner,
)
if file is _file:
return file
elif file == "str":
assert isinstance(_file, StringIO)
return _file.getvalue()
else:
_file.flush()
......@@ -441,96 +458,87 @@ N.B.:
def _debugprint(
r: Variable,
var: Variable,
prefix: str = "",
depth: int = -1,
done: Optional[Dict[Apply, str]] = None,
done: Optional[Dict[Union[Literal["output"], Variable, Apply], str]] = None,
print_type: bool = False,
file: IOBase = sys.stdout,
file: TextIO = sys.stdout,
print_destroy_map: bool = False,
print_view_map: bool = False,
order: Optional[List[Variable]] = None,
ids: str = "CHAR",
topo_order: Optional[Sequence[Apply]] = None,
id_type: IDTypesType = "CHAR",
stop_on_name: bool = False,
prefix_child: Optional[str] = None,
inner_graph_ops: Optional[List[Variable]] = None,
profile: Optional[ProfileStats] = None,
inner_to_outer_inputs: Optional[Dict[Variable, Variable]] = None,
smap: Optional[StorageMapType] = None,
used_ids: Optional[Dict[Variable, str]] = None,
storage_map: Optional[StorageMapType] = None,
used_ids: Optional[Dict[Union[Literal["output"], Variable, Apply], str]] = None,
op_information: Optional[Dict[Apply, Dict[Variable, str]]] = None,
parent_node: Optional[Apply] = None,
print_op_info: bool = False,
inner_graph_node: Optional[Apply] = None,
) -> IOBase:
r"""Print the graph leading to `r`.
) -> TextIO:
r"""Print the graph represented by `var`.
Parameters
----------
r
var
A `Variable` instance.
prefix
Prefix to each line (typically some number of spaces).
depth
Print graph to this depth (``-1`` for unlimited).
done
A ``dict`` of `Apply` instances that have already been printed and
their associated printed ids.
Internal. Used to pass information when recursing.
See `debugprint`.
print_type
Whether to print the `Variable`'s type.
See `debugprint`.
file
File-like object to which to print.
print_destroy_map
Whether to print `Op` ``destroy_map``\s.
Whether to print the `Variable`'s type.
print_view_map
Whether to print `Op` ``view_map``\s.
order
Whether to print `Op` ``destroy_map``\s.
topo_order
If not empty will print the index in the toposort.
ids
Determines the type of identifier used for variables.
- ``"id"``: print the python id value,
- ``"int"``: print integer character,
- ``"CHAR"``: print capital character,
- ``"auto"``: print the ``auto_name`` value,
- ``""``: don't print an identifier.
id_type
See `debugprint`.
stop_on_name
When ``True``, if a node in the graph has a name, we don't print anything
below it.
Whether to print `Op` ``view_map``\s.
inner_graph_ops
A list of `Op`\s with inner graphs.
inner_to_outer_inputs
A dictionary mapping an `Op`'s inner-inputs to its outer-inputs.
smap
``None`` or the ``storage_map`` when printing an Aesara function.
storage_map
``None`` or the storage map (e.g. when printing an Aesara function).
used_ids
A map between nodes and their printed ids.
It wasn't always printed, but at least a reference to it was printed.
Internal. Used to pass information when recursing.
See `debugprint`.
op_information
Extra `Op`-level information to be added to variable print-outs.
parent_node
The parent node of `r`.
The parent node of `var`.
print_op_info
Print extra information provided by the relevant `Op`\s. For example,
print the tap information for `Scan` inputs and outputs.
See `debugprint`.
inner_graph_node
The inner-graph node in which `r` is contained.
The inner-graph node in which `var` is contained.
"""
if depth == 0:
return file
if order is None:
order = []
if topo_order is None:
topo_order = []
if done is None:
done = dict()
_done = dict()
else:
_done = done
if inner_graph_ops is None:
inner_graph_ops = []
if print_type:
type_str = f" <{r.type}>"
type_str = f" <{var.type}>"
else:
type_str = ""
......@@ -538,92 +546,100 @@ def _debugprint(
prefix_child = prefix
if used_ids is None:
used_ids = dict()
_used_ids = dict()
else:
_used_ids = used_ids
if op_information is None:
op_information = {}
def get_id_str(obj, get_printed=True) -> str:
def get_id_str(
obj: Union[Literal["output"], Apply, Variable], get_printed: bool = True
) -> str:
id_str: str = ""
if obj in used_ids:
id_str = used_ids[obj]
if obj in _used_ids:
id_str = _used_ids[obj]
elif obj == "output":
id_str = "output"
elif ids == "id":
id_str = f"[id {id(r)}]"
elif ids == "int":
id_str = f"[id {len(used_ids)}]"
elif ids == "CHAR":
id_str = f"[id {char_from_number(len(used_ids))}]"
elif ids == "auto":
id_str = f"[id {r.auto_name}]"
elif ids == "":
elif id_type == "id":
id_str = f"[id {id(var)}]"
elif id_type == "int":
id_str = f"[id {len(_used_ids)}]"
elif id_type == "CHAR":
id_str = f"[id {char_from_number(len(_used_ids))}]"
elif id_type == "auto":
id_str = f"[id {var.auto_name}]"
elif id_type == "":
id_str = ""
if get_printed:
done[obj] = id_str
used_ids[obj] = id_str
_done[obj] = id_str
_used_ids[obj] = id_str
return id_str
if hasattr(r.owner, "op"):
if var.owner:
# This variable is the output of a computation, so just print out the
# `Apply` node
a = r.owner
node = var.owner
r_name = getattr(r, "name", "")
var_name = getattr(var, "name", "")
if r_name is None:
r_name = ""
if r_name:
r_name = f" '{r_name}'"
if var_name is None:
var_name = ""
if var_name:
var_name = f" '{var_name}'"
if print_destroy_map and r.owner.op.destroy_map:
destroy_map_str = f" d={r.owner.op.destroy_map}"
if print_destroy_map and node.op.destroy_map:
destroy_map_str = f" d={node.op.destroy_map}"
else:
destroy_map_str = ""
if print_view_map and r.owner.op.view_map:
view_map_str = f" v={r.owner.op.view_map}"
if print_view_map and node.op.view_map:
view_map_str = f" v={node.op.view_map}"
else:
view_map_str = ""
if order:
o = f" {order.index(r.owner)}"
if topo_order:
o = f" {topo_order.index(node)}"
else:
o = ""
already_done = a in done
id_str = get_id_str(a)
already_done = node in _done
id_str = get_id_str(node)
if len(a.outputs) == 1:
idx = ""
if len(node.outputs) == 1:
output_idx = ""
else:
idx = f".{a.outputs.index(r)}"
output_idx = f".{node.outputs.index(var)}"
if id_str:
id_str = f" {id_str}"
if smap and a.outputs[0] in smap:
data = f" {smap[a.outputs[0]]}"
if storage_map and node.outputs[0] in storage_map:
data = f" {storage_map[node.outputs[0]]}"
else:
data = ""
var_output = f"{prefix}{a.op}{idx}{id_str}{type_str}{r_name}{destroy_map_str}{view_map_str}{o}{data}"
var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}"
if print_op_info and r.owner not in op_information:
op_information.update(op_debug_information(r.owner.op, r.owner))
if print_op_info and node not in op_information:
op_information.update(op_debug_information(node.op, node))
node_info = op_information.get(parent_node) or op_information.get(r.owner)
if node_info and r in node_info:
var_output = f"{var_output} ({node_info[r]})"
node_info = (
parent_node and op_information.get(parent_node)
) or op_information.get(node)
if node_info and var in node_info:
var_output = f"{var_output} ({node_info[var]})"
if profile is None or a not in profile.apply_time:
if profile is None:
print(var_output, file=file)
else:
op_time = profile.apply_time[a]
elif profile.apply_time and node not in profile.apply_time:
print(var_output, file=file)
elif profile.apply_time and node in profile.apply_time:
op_time = profile.apply_time[node]
op_time_percent = (op_time / profile.fct_call_time) * 100
tot_time_dict = profile.compute_total_times()
tot_time = tot_time_dict[a]
tot_time_percent = (tot_time_dict[a] / profile.fct_call_time) * 100
tot_time = tot_time_dict[node]
tot_time_percent = (tot_time_dict[node] / profile.fct_call_time) * 100
print(
"%s --> %8.2es %4.1f%% %8.2es %4.1f%%"
......@@ -638,40 +654,40 @@ def _debugprint(
)
if not already_done and (
not stop_on_name or not (hasattr(r, "name") and r.name is not None)
not stop_on_name or not (hasattr(var, "name") and var.name is not None)
):
new_prefix = prefix_child + " |"
new_prefix_child = prefix_child + " |"
for idx, i in enumerate(a.inputs):
if idx == len(a.inputs) - 1:
for in_idx, in_var in enumerate(node.inputs):
if in_idx == len(node.inputs) - 1:
new_prefix_child = prefix_child + " "
if hasattr(i, "owner") and hasattr(i.owner, "op"):
if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"):
if (
isinstance(i.owner.op, HasInnerGraph)
and i not in inner_graph_ops
isinstance(in_var.owner.op, HasInnerGraph)
and in_var not in inner_graph_ops
):
inner_graph_ops.append(i)
inner_graph_ops.append(in_var)
_debugprint(
i,
in_var,
new_prefix,
depth=depth - 1,
done=done,
done=_done,
print_type=print_type,
file=file,
order=order,
ids=ids,
topo_order=topo_order,
id_type=id_type,
stop_on_name=stop_on_name,
prefix_child=new_prefix_child,
inner_graph_ops=inner_graph_ops,
profile=profile,
inner_to_outer_inputs=inner_to_outer_inputs,
smap=smap,
used_ids=used_ids,
storage_map=storage_map,
used_ids=_used_ids,
op_information=op_information,
parent_node=a,
parent_node=node,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
......@@ -679,38 +695,38 @@ def _debugprint(
)
else:
id_str = get_id_str(r)
id_str = get_id_str(var)
if id_str:
id_str = f" {id_str}"
if smap and r in smap:
data = f" {smap[r]}"
if storage_map and var in storage_map:
data = f" {storage_map[var]}"
else:
data = ""
var_output = f"{prefix}{r}{id_str}{type_str}{data}"
var_output = f"{prefix}{var}{id_str}{type_str}{data}"
if print_op_info and r.owner and r.owner not in op_information:
op_information.update(op_debug_information(r.owner.op, r.owner))
if print_op_info and var.owner and var.owner not in op_information:
op_information.update(op_debug_information(var.owner.op, var.owner))
if inner_to_outer_inputs is not None and r in inner_to_outer_inputs:
if inner_to_outer_inputs is not None and var in inner_to_outer_inputs:
outer_r = inner_to_outer_inputs[r]
outer_var = inner_to_outer_inputs[var]
if outer_r.owner:
outer_id_str = get_id_str(outer_r.owner)
if outer_var.owner:
outer_id_str = get_id_str(outer_var.owner)
else:
outer_id_str = get_id_str(outer_r)
outer_id_str = get_id_str(outer_var)
var_output = f"{var_output} -> {outer_id_str}"
# TODO: This entire approach will only print `Op` info for two levels
# of nesting.
for node in dict.fromkeys([inner_graph_node, parent_node, r.owner]):
for node in dict.fromkeys([inner_graph_node, parent_node, var.owner]):
node_info = op_information.get(node)
if node_info and r in node_info:
var_output = f"{var_output} ({node_info[r]})"
if node_info and var in node_info:
var_output = f"{var_output} ({node_info[var]})"
print(var_output, file=file)
......@@ -1827,18 +1843,20 @@ def hex_digest(x):
def get_node_by_id(
graphs: Iterable[Variable], target_var_id: str, ids: str = "CHAR"
) -> Optional[Union[Variable, Apply]]:
graphs: Union[Variable, Sequence[Variable], Function, FunctionGraph],
target_var_id: str,
id_types: IDTypesType = "CHAR",
) -> Optional[Union[Literal["output"], Variable, Apply]]:
r"""Get `Apply` nodes or `Variable`\s in a graph using their `debugprint` IDs.
Parameters
----------
graphs:
graphs
The graph, or graphs, to search.
target_var_id:
target_var_id
The name to search for.
ids:
The ID scheme to use (see `debugprint.`).
id_types
The ID scheme to use (see `debugprint`).
Returns
-------
......@@ -1847,12 +1865,9 @@ def get_node_by_id(
"""
from aesara.printing import debugprint
if isinstance(graphs, Variable):
graphs = (graphs,)
used_ids: Dict[Variable, str] = {}
used_ids: Dict[Union[Literal["output"], Variable, Apply], str] = {}
_ = debugprint(graphs, file="str", used_ids=used_ids, ids=ids)
_ = debugprint(graphs, file="str", used_ids=used_ids, id_type=id_types)
id_to_node = {v: k for k, v in used_ids.items()}
......
......@@ -227,10 +227,6 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.printing]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.tensor.nnet.conv3d2d]
ignore_errors = True
check_untyped_defs = False
......
......@@ -133,9 +133,8 @@ def test_debugprint():
s = StringIO()
debugprint(G, file=s)
# test ids=int
s = StringIO()
debugprint(G, file=s, ids="int")
debugprint(G, file=s, id_type="int")
s = s.getvalue()
# The additional white space are needed!
reference = (
......@@ -155,9 +154,8 @@ def test_debugprint():
assert s == reference
# test ids=CHAR
s = StringIO()
debugprint(G, file=s, ids="CHAR")
debugprint(G, file=s, id_type="CHAR")
s = s.getvalue()
# The additional white space are needed!
reference = (
......@@ -177,9 +175,8 @@ def test_debugprint():
assert s == reference
# test ids=CHAR, stop_on_name=True
s = StringIO()
debugprint(G, file=s, ids="CHAR", stop_on_name=True)
debugprint(G, file=s, id_type="CHAR", stop_on_name=True)
s = s.getvalue()
# The additional white space are needed!
reference = (
......@@ -197,9 +194,8 @@ def test_debugprint():
assert s == reference
# test ids=
s = StringIO()
debugprint(G, file=s, ids="")
debugprint(G, file=s, id_type="")
s = s.getvalue()
# The additional white space are needed!
reference = (
......@@ -221,7 +217,7 @@ def test_debugprint():
# test print_storage=True
s = StringIO()
debugprint(g, file=s, ids="", print_storage=True)
debugprint(g, file=s, id_type="", print_storage=True)
s = s.getvalue()
reference = (
"\n".join(
......@@ -246,7 +242,7 @@ def test_debugprint():
debugprint(
aesara.function([A, B, D, J], A + (B.dot(J) - D), mode="FAST_RUN"),
file=s,
ids="",
id_type="",
print_destroy_map=True,
print_view_map=True,
)
......@@ -270,7 +266,7 @@ def test_debugprint():
]
def test_debugprint_ids():
def test_debugprint_id_type():
a_at = dvector()
b_at = dmatrix()
......@@ -278,7 +274,7 @@ def test_debugprint_ids():
e_at = d_at + a_at
s = StringIO()
debugprint(e_at, ids="auto", file=s)
debugprint(e_at, id_type="auto", file=s)
s = s.getvalue()
exp_res = f"""Elemwise{{add,no_inplace}} [id {e_at.auto_name}]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论