Unverified 提交 b12dc30a authored 作者: Aarsh Wankar's avatar Aarsh Wankar 提交者: GitHub

Add `print_shape` and `print_memory_map` option to `debugprint` (#1236)

* Added print_shape option to debugprint and simplify __str__ logic in TensorType * Add print_memory_map option to debugprint to enable destroy and view maps
上级 4c27eb9b
......@@ -89,6 +89,7 @@ def debugprint(
| Sequence[Variable | Apply | Function | FunctionGraph],
depth: int = -1,
print_type: bool = False,
print_shape: bool = False,
file: Literal["str"] | TextIO | None = None,
id_type: IDTypesType = "CHAR",
stop_on_name: bool = False,
......@@ -98,6 +99,7 @@ def debugprint(
print_op_info: bool = False,
print_destroy_map: bool = False,
print_view_map: bool = False,
print_memory_map: bool = False,
print_fgraph_inputs: bool = False,
) -> str | TextIO:
r"""Print a graph as text.
......@@ -123,6 +125,8 @@ def debugprint(
Print graph to this depth (``-1`` for unlimited).
print_type
If ``True``, print the `Type`\s of each `Variable` in the graph.
print_shape
If ``True``, print the shape of each `Variable` in the graph.
file
When `file` extends `TextIO`, print to it; when `file` is
equal to ``"str"``, return a string; when `file` is ``None``, print to
......@@ -153,6 +157,8 @@ def debugprint(
Whether to print the `destroy_map`\s of printed objects
print_view_map
Whether to print the `view_map`\s of printed objects
print_memory_map
Whether to set both `print_destroy_map` and `print_view_map` to ``True``.
print_fgraph_inputs
Print the inputs of `FunctionGraph`\s.
......@@ -177,6 +183,10 @@ def debugprint(
if used_ids is None:
used_ids = dict()
if print_memory_map:
print_destroy_map = True
print_view_map = True
inputs_to_print = []
outputs_to_print = []
profile_list: list[Any | None] = []
......@@ -265,6 +275,7 @@ N.B.:
depth=depth,
done=done,
print_type=print_type,
print_shape=print_shape,
file=_file,
id_type=id_type,
inner_graph_ops=inner_graph_vars,
......@@ -295,6 +306,7 @@ N.B.:
depth=depth,
done=done,
print_type=print_type,
print_shape=print_shape,
file=_file,
topo_order=topo_order,
id_type=id_type,
......@@ -365,6 +377,7 @@ N.B.:
depth=depth,
done=done,
print_type=print_type,
print_shape=print_shape,
file=_file,
id_type=id_type,
inner_graph_ops=inner_graph_vars,
......@@ -387,6 +400,7 @@ N.B.:
depth=depth,
done=done,
print_type=print_type,
print_shape=print_shape,
file=_file,
id_type=id_type,
stop_on_name=stop_on_name,
......@@ -421,6 +435,7 @@ N.B.:
depth=depth,
done=done,
print_type=print_type,
print_shape=print_shape,
file=_file,
id_type=id_type,
stop_on_name=stop_on_name,
......@@ -452,6 +467,7 @@ def _debugprint(
depth: int = -1,
done: dict[Literal["output"] | Variable | Apply, str] | None = None,
print_type: bool = False,
print_shape: bool = False,
file: TextIO = sys.stdout,
print_destroy_map: bool = False,
print_view_map: bool = False,
......@@ -484,6 +500,8 @@ def _debugprint(
See `debugprint`.
print_type
See `debugprint`.
print_shape
See `debugprint`.
file
File-like object to which to print.
print_destroy_map
......@@ -532,6 +550,11 @@ def _debugprint(
else:
type_str = ""
if print_shape and hasattr(var.type, "shape"):
shape_str = f" shape={str(var.type.shape).replace('None', '?')}"
else:
shape_str = ""
if prefix_child is None:
prefix_child = prefix
......@@ -612,7 +635,7 @@ def _debugprint(
if is_inner_graph_header:
var_output = f"{prefix}{node.op}{id_str}{destroy_map_str}{view_map_str}{o}"
else:
var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}"
var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{shape_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}"
if print_op_info and node not in op_information:
op_information.update(op_debug_information(node.op, node))
......@@ -662,6 +685,7 @@ def _debugprint(
depth=depth - 1,
done=_done,
print_type=print_type,
print_shape=print_shape,
file=file,
topo_order=topo_order,
id_type=id_type,
......@@ -692,7 +716,7 @@ def _debugprint(
else:
data = ""
var_output = f"{prefix}{var}{id_str}{type_str}{data}"
var_output = f"{prefix}{var}{id_str}{type_str}{shape_str}{data}"
if print_op_info and var.owner and var.owner not in op_information:
op_information.update(op_debug_information(var.owner.op, var.owner))
......
......@@ -399,22 +399,13 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
else:
shape = self.shape
len_shape = len(shape)
def shape_str(s):
if s is None:
return "?"
else:
return str(s)
formatted_shape = ", ".join(shape_str(s) for s in shape)
if len_shape == 1:
formatted_shape += ","
formatted_shape = str(shape).replace("None", "?")
if len_shape > 2:
name = f"Tensor{len_shape}"
else:
name = ("Scalar", "Vector", "Matrix")[len_shape]
return f"{name}({self.dtype}, shape=({formatted_shape}))"
return f"{name}({self.dtype}, shape={formatted_shape})"
def __repr__(self):
return f"TensorType({self.dtype}, shape={self.shape})"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论