提交 33743b89 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move aesara.debugmode.debugprint to aesara.printing

上级 5335e729
...@@ -36,6 +36,7 @@ from aesara.graph.op import COp, Op, ops_with_inner_function ...@@ -36,6 +36,7 @@ from aesara.graph.op import COp, Op, ops_with_inner_function
from aesara.graph.utils import MethodNotDefined from aesara.graph.utils import MethodNotDefined
from aesara.link.basic import Container, LocalLinker from aesara.link.basic import Container, LocalLinker
from aesara.link.utils import map_storage, raise_with_op from aesara.link.utils import map_storage, raise_with_op
from aesara.printing import _debugprint
from aesara.utils import NoDuplicateOptWarningFilter, difference, get_unbound_function from aesara.utils import NoDuplicateOptWarningFilter, difference, get_unbound_function
...@@ -326,7 +327,7 @@ class InvalidValueError(DebugModeError): ...@@ -326,7 +327,7 @@ class InvalidValueError(DebugModeError):
client_node = self.client_node client_node = self.client_node
hint = self.hint hint = self.hint
specific_hint = self.specific_hint specific_hint = self.specific_hint
context = debugprint(r, prefix=" ", depth=12, file=StringIO()).getvalue() context = _debugprint(r, prefix=" ", depth=12, file=StringIO()).getvalue()
return f"""InvalidValueError return f"""InvalidValueError
type(variable) = {type_r} type(variable) = {type_r}
variable = {r} variable = {r}
...@@ -409,273 +410,6 @@ def str_diagnostic(expected, value, rtol, atol): ...@@ -409,273 +410,6 @@ def str_diagnostic(expected, value, rtol, atol):
return sio.getvalue() return sio.getvalue()
def char_from_number(number):
"""
Converts number to string by rendering it in base 26 using
capital letters as digits.
"""
base = 26
rval = ""
if number == 0:
rval = "A"
while number != 0:
remainder = number % base
new_char = chr(ord("A") + remainder)
rval = new_char + rval
number //= base
return rval
def debugprint(
r,
prefix="",
depth=-1,
done=None,
print_type=False,
file=sys.stdout,
print_destroy_map=False,
print_view_map=False,
order=None,
ids="CHAR",
stop_on_name=False,
prefix_child=None,
scan_ops=None,
profile=None,
scan_inner_to_outer_inputs=None,
smap=None,
used_ids=None,
):
"""
Print the graph leading to `r` to given depth.
Parameters
----------
r
Variable instance.
prefix
Prefix to each line (typically some number of spaces).
depth
Maximum recursion depth (Default -1 for unlimited).
done
Internal. Used to pass information when recursing.
Dict of Apply instances that have already been printed and their
associated printed ids.
print_type
Whether to print the Variable type after the other infos.
file
File-like object to which to print.
print_destroy_map
Whether to print the op destroy_map after other info.
print_view_map
Whether to print the op view_map after other info.
order
If not empty will print the index in the toposort.
ids
How do we print the identifier of the variable :
id - print the python id value,
int - print integer character,
CHAR - print capital character,
"" - don't print an identifier.
stop_on_name
When True, if a node in the graph has a name, we don't print anything
below it.
scan_ops
Scan ops in the graph will be added inside this list for later printing
purposes.
scan_inner_to_outer_inputs
A dictionary mapping a scan ops inner function inputs to the scan op
inputs (outer inputs) for printing purposes.
smap
None or the storage_map when printing an Aesara function.
used_ids
Internal. Used to pass information when recursing.
It is a dict from obj to the id used for it.
It wasn't always printed, but at least a reference to it was printed.
"""
if depth == 0:
return
if order is None:
order = []
if done is None:
done = dict()
if scan_ops is None:
scan_ops = []
if print_type:
type_str = f" <{r.type}>"
else:
type_str = ""
if prefix_child is None:
prefix_child = prefix
if used_ids is None:
used_ids = dict()
def get_id_str(obj, get_printed=True) -> str:
id_str: str = ""
if obj in used_ids:
id_str = used_ids[obj]
elif obj == "output":
id_str = "output"
elif ids == "id":
id_str = f"[id {id(r)}]"
elif ids == "int":
id_str = f"[id {len(used_ids)}]"
elif ids == "CHAR":
id_str = f"[id {char_from_number(len(used_ids))}]"
elif ids == "":
id_str = ""
if get_printed:
done[obj] = id_str
used_ids[obj] = id_str
return id_str
if hasattr(r.owner, "op"):
# this variable is the output of computation,
# so just print out the apply
a = r.owner
r_name = getattr(r, "name", "")
# normally if the name isn't set, it'll be None, so
# r_name is None here
if r_name is None:
r_name = ""
if print_destroy_map:
destroy_map_str = str(r.owner.op.destroy_map)
else:
destroy_map_str = ""
if print_view_map:
view_map_str = str(r.owner.op.view_map)
else:
view_map_str = ""
if destroy_map_str and destroy_map_str != "{}":
destroy_map_str = "d=" + destroy_map_str
if view_map_str and view_map_str != "{}":
view_map_str = "v=" + view_map_str
o = ""
if order:
o = str(order.index(r.owner))
already_printed = a in done # get_id_str put it in the dict
id_str = get_id_str(a)
if len(a.outputs) == 1:
idx = ""
else:
idx = f".{a.outputs.index(r)}"
data = ""
if smap:
data = " " + str(smap.get(a.outputs[0], ""))
clients = ""
if profile is None or a not in profile.apply_time:
print(
f"{prefix}{a.op}{idx} {id_str}{type_str} '{r_name}' {destroy_map_str} {view_map_str} {o}{data}{clients}",
file=file,
)
else:
op_time = profile.apply_time[a]
op_time_percent = (op_time / profile.fct_call_time) * 100
tot_time_dict = profile.compute_total_times()
tot_time = tot_time_dict[a]
tot_time_percent = (tot_time_dict[a] / profile.fct_call_time) * 100
if len(a.outputs) == 1:
idx = ""
else:
idx = f".{a.outputs.index(r)}"
print(
"%s%s%s %s%s '%s' %s %s %s%s%s --> "
"%8.2es %4.1f%% %8.2es %4.1f%%"
% (
prefix,
a.op,
idx,
id_str,
type_str,
r_name,
destroy_map_str,
view_map_str,
o,
data,
clients,
op_time,
op_time_percent,
tot_time,
tot_time_percent,
),
file=file,
)
if not already_printed:
if not stop_on_name or not (hasattr(r, "name") and r.name is not None):
new_prefix = prefix_child + " |"
new_prefix_child = prefix_child + " |"
for idx, i in enumerate(a.inputs):
if idx == len(a.inputs) - 1:
new_prefix_child = prefix_child + " "
if hasattr(i, "owner") and hasattr(i.owner, "op"):
from aesara.scan.op import Scan
if isinstance(i.owner.op, Scan):
scan_ops.append(i)
debugprint(
i,
new_prefix,
depth=depth - 1,
done=done,
print_type=print_type,
file=file,
order=order,
ids=ids,
stop_on_name=stop_on_name,
prefix_child=new_prefix_child,
scan_ops=scan_ops,
profile=profile,
scan_inner_to_outer_inputs=scan_inner_to_outer_inputs,
smap=smap,
used_ids=used_ids,
)
else:
if scan_inner_to_outer_inputs is not None and r in scan_inner_to_outer_inputs:
id_str = get_id_str(r)
outer_r = scan_inner_to_outer_inputs[r]
if hasattr(outer_r.owner, "op"):
outer_id_str = get_id_str(outer_r.owner)
else:
outer_id_str = get_id_str(outer_r)
print(
f"{prefix}{r} {id_str}{type_str} -> {outer_id_str}",
file=file,
)
else:
# this is an input variable
data = ""
if smap:
data = " " + str(smap.get(r, ""))
id_str = get_id_str(r)
print(f"{prefix}{r} {id_str}{type_str}{data}", file=file)
return file
def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False): def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
""" """
Create a FunctionGraph for debugging. Create a FunctionGraph for debugging.
...@@ -1656,7 +1390,7 @@ class _VariableEquivalenceTracker: ...@@ -1656,7 +1390,7 @@ class _VariableEquivalenceTracker:
( (
reason, reason,
r, r,
debugprint( _debugprint(
r, r,
prefix=" ", prefix=" ",
depth=6, depth=6,
...@@ -1665,7 +1399,7 @@ class _VariableEquivalenceTracker: ...@@ -1665,7 +1399,7 @@ class _VariableEquivalenceTracker:
print_type=True, print_type=True,
used_ids=used_ids, used_ids=used_ids,
).getvalue(), ).getvalue(),
debugprint( _debugprint(
new_r, new_r,
prefix=" ", prefix=" ",
depth=6, depth=6,
......
...@@ -110,7 +110,7 @@ class BadOptimization(Exception): ...@@ -110,7 +110,7 @@ class BadOptimization(Exception):
used_ids = dict() used_ids = dict()
if isinstance(old_r, Variable): if isinstance(old_r, Variable):
self.old_graph = aesara.compile.debugmode.debugprint( self.old_graph = aesara.printing._debugprint(
old_r, old_r,
prefix=" ", prefix=" ",
depth=6, depth=6,
...@@ -123,7 +123,7 @@ class BadOptimization(Exception): ...@@ -123,7 +123,7 @@ class BadOptimization(Exception):
self.old_graph = None self.old_graph = None
if isinstance(new_r, Variable): if isinstance(new_r, Variable):
self.new_graph = aesara.compile.debugmode.debugprint( self.new_graph = aesara.printing._debugprint(
new_r, new_r,
prefix=" ", prefix=" ",
depth=6, depth=6,
......
...@@ -14,7 +14,7 @@ from io import StringIO ...@@ -14,7 +14,7 @@ from io import StringIO
import numpy as np import numpy as np
from aesara.compile import Function, SharedVariable, debugmode from aesara.compile import Function, SharedVariable
from aesara.compile.io import In, Out from aesara.compile.io import In, Out
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable, graph_inputs, io_toposort from aesara.graph.basic import Apply, Constant, Variable, graph_inputs, io_toposort
...@@ -60,6 +60,25 @@ _logger = logging.getLogger("aesara.printing") ...@@ -60,6 +60,25 @@ _logger = logging.getLogger("aesara.printing")
VALID_ASSOC = {"left", "right", "either"} VALID_ASSOC = {"left", "right", "either"}
def char_from_number(number):
"""Convert numbers to strings by rendering it in base 26 using capital letters as digits."""
base = 26
rval = ""
if number == 0:
rval = "A"
while number != 0:
remainder = number % base
new_char = chr(ord("A") + remainder)
rval = new_char + rval
number //= base
return rval
def debugprint( def debugprint(
obj, obj,
depth=-1, depth=-1,
...@@ -205,7 +224,7 @@ N.B.: ...@@ -205,7 +224,7 @@ N.B.:
if hasattr(r.owner, "op") and isinstance(r.owner.op, Scan): if hasattr(r.owner, "op") and isinstance(r.owner.op, Scan):
scan_ops.append(r) scan_ops.append(r)
debugmode.debugprint( _debugprint(
r, r,
depth=depth, depth=depth,
done=done, done=done,
...@@ -243,7 +262,7 @@ N.B.: ...@@ -243,7 +262,7 @@ N.B.:
} }
print("", file=_file) print("", file=_file)
debugmode.debugprint( _debugprint(
s, s,
depth=depth, depth=depth,
done=done, done=done,
...@@ -266,7 +285,7 @@ N.B.: ...@@ -266,7 +285,7 @@ N.B.:
if isinstance(i.owner.op, Scan): if isinstance(i.owner.op, Scan):
scan_ops.append(i) scan_ops.append(i)
debugmode.debugprint( _debugprint(
r=i, r=i,
prefix=new_prefix, prefix=new_prefix,
depth=depth, depth=depth,
...@@ -289,6 +308,249 @@ N.B.: ...@@ -289,6 +308,249 @@ N.B.:
_file.flush() _file.flush()
def _debugprint(
r,
prefix="",
depth=-1,
done=None,
print_type=False,
file=sys.stdout,
print_destroy_map=False,
print_view_map=False,
order=None,
ids="CHAR",
stop_on_name=False,
prefix_child=None,
scan_ops=None,
profile=None,
scan_inner_to_outer_inputs=None,
smap=None,
used_ids=None,
):
"""Print the graph leading to `r`.
Parameters
----------
r
Variable instance.
prefix
Prefix to each line (typically some number of spaces).
depth
Maximum recursion depth (Default -1 for unlimited).
done
Internal. Used to pass information when recursing.
Dict of Apply instances that have already been printed and their
associated printed ids.
print_type
Whether to print the Variable type after the other infos.
file
File-like object to which to print.
print_destroy_map
Whether to print the op destroy_map after other info.
print_view_map
Whether to print the op view_map after other info.
order
If not empty will print the index in the toposort.
ids
Determines the type of identifier used for variables.
- ``"id"``: print the python id value,
- ``"int"``: print integer character,
- ``"CHAR"``: print capital character,
- ``""``: don't print an identifier.
stop_on_name
When True, if a node in the graph has a name, we don't print anything
below it.
scan_ops
Scan ops in the graph will be added inside this list for later printing
purposes.
scan_inner_to_outer_inputs
A dictionary mapping a scan ops inner function inputs to the scan op
inputs (outer inputs) for printing purposes.
smap
None or the storage_map when printing an Aesara function.
used_ids
Internal. Used to pass information when recursing.
It is a dict from obj to the id used for it.
It wasn't always printed, but at least a reference to it was printed.
"""
if depth == 0:
return
if order is None:
order = []
if done is None:
done = dict()
if scan_ops is None:
scan_ops = []
if print_type:
type_str = f" <{r.type}>"
else:
type_str = ""
if prefix_child is None:
prefix_child = prefix
if used_ids is None:
used_ids = dict()
def get_id_str(obj, get_printed=True) -> str:
id_str: str = ""
if obj in used_ids:
id_str = used_ids[obj]
elif obj == "output":
id_str = "output"
elif ids == "id":
id_str = f"[id {id(r)}]"
elif ids == "int":
id_str = f"[id {len(used_ids)}]"
elif ids == "CHAR":
id_str = f"[id {char_from_number(len(used_ids))}]"
elif ids == "":
id_str = ""
if get_printed:
done[obj] = id_str
used_ids[obj] = id_str
return id_str
if hasattr(r.owner, "op"):
# this variable is the output of computation,
# so just print out the apply
a = r.owner
r_name = getattr(r, "name", "")
# normally if the name isn't set, it'll be None, so
# r_name is None here
if r_name is None:
r_name = ""
if print_destroy_map:
destroy_map_str = str(r.owner.op.destroy_map)
else:
destroy_map_str = ""
if print_view_map:
view_map_str = str(r.owner.op.view_map)
else:
view_map_str = ""
if destroy_map_str and destroy_map_str != "{}":
destroy_map_str = "d=" + destroy_map_str
if view_map_str and view_map_str != "{}":
view_map_str = "v=" + view_map_str
o = ""
if order:
o = str(order.index(r.owner))
already_printed = a in done # get_id_str put it in the dict
id_str = get_id_str(a)
if len(a.outputs) == 1:
idx = ""
else:
idx = f".{a.outputs.index(r)}"
data = ""
if smap:
data = " " + str(smap.get(a.outputs[0], ""))
clients = ""
if profile is None or a not in profile.apply_time:
print(
f"{prefix}{a.op}{idx} {id_str}{type_str} '{r_name}' {destroy_map_str} {view_map_str} {o}{data}{clients}",
file=file,
)
else:
op_time = profile.apply_time[a]
op_time_percent = (op_time / profile.fct_call_time) * 100
tot_time_dict = profile.compute_total_times()
tot_time = tot_time_dict[a]
tot_time_percent = (tot_time_dict[a] / profile.fct_call_time) * 100
if len(a.outputs) == 1:
idx = ""
else:
idx = f".{a.outputs.index(r)}"
print(
"%s%s%s %s%s '%s' %s %s %s%s%s --> "
"%8.2es %4.1f%% %8.2es %4.1f%%"
% (
prefix,
a.op,
idx,
id_str,
type_str,
r_name,
destroy_map_str,
view_map_str,
o,
data,
clients,
op_time,
op_time_percent,
tot_time,
tot_time_percent,
),
file=file,
)
if not already_printed:
if not stop_on_name or not (hasattr(r, "name") and r.name is not None):
new_prefix = prefix_child + " |"
new_prefix_child = prefix_child + " |"
for idx, i in enumerate(a.inputs):
if idx == len(a.inputs) - 1:
new_prefix_child = prefix_child + " "
if hasattr(i, "owner") and hasattr(i.owner, "op"):
from aesara.scan.op import Scan
if isinstance(i.owner.op, Scan):
scan_ops.append(i)
_debugprint(
i,
new_prefix,
depth=depth - 1,
done=done,
print_type=print_type,
file=file,
order=order,
ids=ids,
stop_on_name=stop_on_name,
prefix_child=new_prefix_child,
scan_ops=scan_ops,
profile=profile,
scan_inner_to_outer_inputs=scan_inner_to_outer_inputs,
smap=smap,
used_ids=used_ids,
)
else:
if scan_inner_to_outer_inputs is not None and r in scan_inner_to_outer_inputs:
id_str = get_id_str(r)
outer_r = scan_inner_to_outer_inputs[r]
if hasattr(outer_r.owner, "op"):
outer_id_str = get_id_str(outer_r.owner)
else:
outer_id_str = get_id_str(outer_r)
print(
f"{prefix}{r} {id_str}{type_str} -> {outer_id_str}",
file=file,
)
else:
# this is an input variable
data = ""
if smap:
data = " " + str(smap.get(r, ""))
id_str = get_id_str(r)
print(f"{prefix}{r} {id_str}{type_str}{data}", file=file)
return file
def _print_fn(op, xin): def _print_fn(op, xin):
for attr in op.attrs: for attr in op.attrs:
temp = getattr(xin, attr) temp = getattr(xin, attr)
...@@ -1176,7 +1438,7 @@ class _TagGenerator: ...@@ -1176,7 +1438,7 @@ class _TagGenerator:
self.cur_tag_number = 0 self.cur_tag_number = 0
def get_tag(self): def get_tag(self):
rval = debugmode.char_from_number(self.cur_tag_number) rval = char_from_number(self.cur_tag_number)
self.cur_tag_number += 1 self.cur_tag_number += 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论