Unverified 提交 b12dc30a authored 作者: Aarsh Wankar's avatar Aarsh Wankar 提交者: GitHub

Add `print_shape` and `print_memory_map` option to `debugprint` (#1236)

* Added print_shape option to debugprint and simplify __str__ logic in TensorType * Add print_memory_map option to debugprint to enable destroy and view maps
上级 4c27eb9b
...@@ -89,6 +89,7 @@ def debugprint( ...@@ -89,6 +89,7 @@ def debugprint(
| Sequence[Variable | Apply | Function | FunctionGraph], | Sequence[Variable | Apply | Function | FunctionGraph],
depth: int = -1, depth: int = -1,
print_type: bool = False, print_type: bool = False,
print_shape: bool = False,
file: Literal["str"] | TextIO | None = None, file: Literal["str"] | TextIO | None = None,
id_type: IDTypesType = "CHAR", id_type: IDTypesType = "CHAR",
stop_on_name: bool = False, stop_on_name: bool = False,
...@@ -98,6 +99,7 @@ def debugprint( ...@@ -98,6 +99,7 @@ def debugprint(
print_op_info: bool = False, print_op_info: bool = False,
print_destroy_map: bool = False, print_destroy_map: bool = False,
print_view_map: bool = False, print_view_map: bool = False,
print_memory_map: bool = False,
print_fgraph_inputs: bool = False, print_fgraph_inputs: bool = False,
) -> str | TextIO: ) -> str | TextIO:
r"""Print a graph as text. r"""Print a graph as text.
...@@ -123,6 +125,8 @@ def debugprint( ...@@ -123,6 +125,8 @@ def debugprint(
Print graph to this depth (``-1`` for unlimited). Print graph to this depth (``-1`` for unlimited).
print_type print_type
If ``True``, print the `Type`\s of each `Variable` in the graph. If ``True``, print the `Type`\s of each `Variable` in the graph.
print_shape
If ``True``, print the shape of each `Variable` in the graph.
file file
When `file` extends `TextIO`, print to it; when `file` is When `file` extends `TextIO`, print to it; when `file` is
equal to ``"str"``, return a string; when `file` is ``None``, print to equal to ``"str"``, return a string; when `file` is ``None``, print to
...@@ -153,6 +157,8 @@ def debugprint( ...@@ -153,6 +157,8 @@ def debugprint(
Whether to print the `destroy_map`\s of printed objects Whether to print the `destroy_map`\s of printed objects
print_view_map print_view_map
Whether to print the `view_map`\s of printed objects Whether to print the `view_map`\s of printed objects
print_memory_map
Whether to set both `print_destroy_map` and `print_view_map` to ``True``.
print_fgraph_inputs print_fgraph_inputs
Print the inputs of `FunctionGraph`\s. Print the inputs of `FunctionGraph`\s.
...@@ -177,6 +183,10 @@ def debugprint( ...@@ -177,6 +183,10 @@ def debugprint(
if used_ids is None: if used_ids is None:
used_ids = dict() used_ids = dict()
if print_memory_map:
print_destroy_map = True
print_view_map = True
inputs_to_print = [] inputs_to_print = []
outputs_to_print = [] outputs_to_print = []
profile_list: list[Any | None] = [] profile_list: list[Any | None] = []
...@@ -265,6 +275,7 @@ N.B.: ...@@ -265,6 +275,7 @@ N.B.:
depth=depth, depth=depth,
done=done, done=done,
print_type=print_type, print_type=print_type,
print_shape=print_shape,
file=_file, file=_file,
id_type=id_type, id_type=id_type,
inner_graph_ops=inner_graph_vars, inner_graph_ops=inner_graph_vars,
...@@ -295,6 +306,7 @@ N.B.: ...@@ -295,6 +306,7 @@ N.B.:
depth=depth, depth=depth,
done=done, done=done,
print_type=print_type, print_type=print_type,
print_shape=print_shape,
file=_file, file=_file,
topo_order=topo_order, topo_order=topo_order,
id_type=id_type, id_type=id_type,
...@@ -365,6 +377,7 @@ N.B.: ...@@ -365,6 +377,7 @@ N.B.:
depth=depth, depth=depth,
done=done, done=done,
print_type=print_type, print_type=print_type,
print_shape=print_shape,
file=_file, file=_file,
id_type=id_type, id_type=id_type,
inner_graph_ops=inner_graph_vars, inner_graph_ops=inner_graph_vars,
...@@ -387,6 +400,7 @@ N.B.: ...@@ -387,6 +400,7 @@ N.B.:
depth=depth, depth=depth,
done=done, done=done,
print_type=print_type, print_type=print_type,
print_shape=print_shape,
file=_file, file=_file,
id_type=id_type, id_type=id_type,
stop_on_name=stop_on_name, stop_on_name=stop_on_name,
...@@ -421,6 +435,7 @@ N.B.: ...@@ -421,6 +435,7 @@ N.B.:
depth=depth, depth=depth,
done=done, done=done,
print_type=print_type, print_type=print_type,
print_shape=print_shape,
file=_file, file=_file,
id_type=id_type, id_type=id_type,
stop_on_name=stop_on_name, stop_on_name=stop_on_name,
...@@ -452,6 +467,7 @@ def _debugprint( ...@@ -452,6 +467,7 @@ def _debugprint(
depth: int = -1, depth: int = -1,
done: dict[Literal["output"] | Variable | Apply, str] | None = None, done: dict[Literal["output"] | Variable | Apply, str] | None = None,
print_type: bool = False, print_type: bool = False,
print_shape: bool = False,
file: TextIO = sys.stdout, file: TextIO = sys.stdout,
print_destroy_map: bool = False, print_destroy_map: bool = False,
print_view_map: bool = False, print_view_map: bool = False,
...@@ -484,6 +500,8 @@ def _debugprint( ...@@ -484,6 +500,8 @@ def _debugprint(
See `debugprint`. See `debugprint`.
print_type print_type
See `debugprint`. See `debugprint`.
print_shape
See `debugprint`.
file file
File-like object to which to print. File-like object to which to print.
print_destroy_map print_destroy_map
...@@ -532,6 +550,11 @@ def _debugprint( ...@@ -532,6 +550,11 @@ def _debugprint(
else: else:
type_str = "" type_str = ""
if print_shape and hasattr(var.type, "shape"):
shape_str = f" shape={str(var.type.shape).replace('None', '?')}"
else:
shape_str = ""
if prefix_child is None: if prefix_child is None:
prefix_child = prefix prefix_child = prefix
...@@ -612,7 +635,7 @@ def _debugprint( ...@@ -612,7 +635,7 @@ def _debugprint(
if is_inner_graph_header: if is_inner_graph_header:
var_output = f"{prefix}{node.op}{id_str}{destroy_map_str}{view_map_str}{o}" var_output = f"{prefix}{node.op}{id_str}{destroy_map_str}{view_map_str}{o}"
else: else:
var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}" var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{shape_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}"
if print_op_info and node not in op_information: if print_op_info and node not in op_information:
op_information.update(op_debug_information(node.op, node)) op_information.update(op_debug_information(node.op, node))
...@@ -662,6 +685,7 @@ def _debugprint( ...@@ -662,6 +685,7 @@ def _debugprint(
depth=depth - 1, depth=depth - 1,
done=_done, done=_done,
print_type=print_type, print_type=print_type,
print_shape=print_shape,
file=file, file=file,
topo_order=topo_order, topo_order=topo_order,
id_type=id_type, id_type=id_type,
...@@ -692,7 +716,7 @@ def _debugprint( ...@@ -692,7 +716,7 @@ def _debugprint(
else: else:
data = "" data = ""
var_output = f"{prefix}{var}{id_str}{type_str}{data}" var_output = f"{prefix}{var}{id_str}{type_str}{shape_str}{data}"
if print_op_info and var.owner and var.owner not in op_information: if print_op_info and var.owner and var.owner not in op_information:
op_information.update(op_debug_information(var.owner.op, var.owner)) op_information.update(op_debug_information(var.owner.op, var.owner))
......
...@@ -399,22 +399,13 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -399,22 +399,13 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
else: else:
shape = self.shape shape = self.shape
len_shape = len(shape) len_shape = len(shape)
formatted_shape = str(shape).replace("None", "?")
def shape_str(s):
if s is None:
return "?"
else:
return str(s)
formatted_shape = ", ".join(shape_str(s) for s in shape)
if len_shape == 1:
formatted_shape += ","
if len_shape > 2: if len_shape > 2:
name = f"Tensor{len_shape}" name = f"Tensor{len_shape}"
else: else:
name = ("Scalar", "Vector", "Matrix")[len_shape] name = ("Scalar", "Vector", "Matrix")[len_shape]
return f"{name}({self.dtype}, shape=({formatted_shape}))" return f"{name}({self.dtype}, shape={formatted_shape})"
def __repr__(self): def __repr__(self):
return f"TensorType({self.dtype}, shape={self.shape})" return f"TensorType({self.dtype}, shape={self.shape})"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论