提交 2fee841e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add an option to print graph inputs in debugprint

上级 6965fdc1
......@@ -118,6 +118,7 @@ def debugprint(
print_op_info: bool = False,
print_destroy_map: bool = False,
print_view_map: bool = False,
print_fgraph_inputs: bool = False,
) -> Union[str, IOBase]:
r"""Print a computation graph as text to stdout or a file.
......@@ -175,6 +176,8 @@ def debugprint(
Whether to print the `destroy_map`\s of printed objects
print_view_map
Whether to print the `view_map`\s of printed objects
print_fgraph_inputs
Print the inputs of `FunctionGraph`\s.
Returns
-------
......@@ -197,7 +200,8 @@ def debugprint(
if used_ids is None:
used_ids = dict()
results_to_print = []
inputs_to_print = []
outputs_to_print = []
profile_list: List[Optional[Any]] = []
order: List[Optional[List[Apply]]] = [] # Toposort
smap: List[Optional[StorageMapType]] = [] # storage_map
......@@ -209,17 +213,19 @@ def debugprint(
for obj in lobj:
if isinstance(obj, Variable):
results_to_print.append(obj)
outputs_to_print.append(obj)
profile_list.append(None)
smap.append(None)
order.append(None)
elif isinstance(obj, Apply):
results_to_print.extend(obj.outputs)
outputs_to_print.extend(obj.outputs)
profile_list.extend([None for item in obj.outputs])
smap.extend([None for item in obj.outputs])
order.extend([None for item in obj.outputs])
elif isinstance(obj, Function):
results_to_print.extend(obj.maker.fgraph.outputs)
if print_fgraph_inputs:
inputs_to_print.extend(obj.maker.fgraph.inputs)
outputs_to_print.extend(obj.maker.fgraph.outputs)
profile_list.extend([obj.profile for item in obj.maker.fgraph.outputs])
if print_storage:
smap.extend([obj.vm.storage_map for item in obj.maker.fgraph.outputs])
......@@ -228,7 +234,9 @@ def debugprint(
topo = obj.maker.fgraph.toposort()
order.extend([topo for item in obj.maker.fgraph.outputs])
elif isinstance(obj, FunctionGraph):
results_to_print.extend(obj.outputs)
if print_fgraph_inputs:
inputs_to_print.extend(obj.inputs)
outputs_to_print.extend(obj.outputs)
profile_list.extend([getattr(obj, "profile", None) for item in obj.outputs])
smap.extend([getattr(obj, "storage_map", None) for item in obj.outputs])
topo = obj.toposort()
......@@ -236,7 +244,7 @@ def debugprint(
elif isinstance(obj, (int, float, np.ndarray)):
print(obj, file=_file)
elif isinstance(obj, (In, Out)):
results_to_print.append(obj.variable)
outputs_to_print.append(obj.variable)
profile_list.append(None)
smap.append(None)
order.append(None)
......@@ -268,7 +276,26 @@ N.B.:
op_information = {}
for r, p, s, o in zip(results_to_print, profile_list, smap, order):
for r in inputs_to_print:
_debugprint(
r,
prefix="-",
depth=depth,
done=done,
print_type=print_type,
file=_file,
ids=ids,
inner_graph_ops=inner_graph_ops,
stop_on_name=stop_on_name,
used_ids=used_ids,
op_information=op_information,
parent_node=r.owner,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
)
for r, p, s, o in zip(outputs_to_print, profile_list, smap, order):
if hasattr(r.owner, "op"):
if isinstance(r.owner.op, HasInnerGraph) and r not in inner_graph_ops:
......@@ -352,16 +379,39 @@ N.B.:
print_view_map=print_view_map,
)
for idx, i in enumerate(inner_outputs):
if print_fgraph_inputs:
for inp in inner_inputs:
_debugprint(
r=inp,
prefix="-",
depth=depth,
done=done,
print_type=print_type,
file=_file,
ids=ids,
stop_on_name=stop_on_name,
inner_graph_ops=inner_graph_ops,
inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids,
op_information=op_information,
parent_node=s.owner,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
inner_graph_node=s.owner,
)
inner_to_outer_inputs = None
for out in inner_outputs:
if (
isinstance(getattr(i.owner, "op", None), HasInnerGraph)
and i not in inner_graph_ops
isinstance(getattr(out.owner, "op", None), HasInnerGraph)
and out not in inner_graph_ops
):
inner_graph_ops.append(i)
inner_graph_ops.append(out)
_debugprint(
r=i,
r=out,
prefix=new_prefix,
depth=depth,
done=done,
......@@ -655,14 +705,13 @@ def _debugprint(
var_output = f"{var_output} -> {outer_id_str}"
node_info = op_information.get(inner_graph_node)
# TODO: This entire approach will only print `Op` info for two levels
# of nesting.
for node in dict.fromkeys([inner_graph_node, parent_node, r.owner]):
node_info = op_information.get(node)
if node_info and r in node_info:
var_output = f"{var_output} ({node_info[r]})"
node_info = op_information.get(parent_node) or op_information.get(r.owner)
if node_info and r in node_info:
var_output = f"{var_output} ({node_info[r]})"
print(var_output, file=file)
return file
......
......@@ -4,6 +4,7 @@ import pytest
import aesara
import aesara.tensor as at
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.printing import debugprint, pydot_imported, pydotprint
from aesara.tensor.type import dvector, iscalar, scalar, vector
......@@ -184,15 +185,13 @@ def test_debugprint_nitsot():
@config.change_flags(floatX="float64")
def test_debugprint_nested_scans():
coefficients = dvector("coefficients")
max_coefficients_supported = 10
c = dvector("c")
n = 10
k = iscalar("k")
A = dvector("A")
# compute A**k
def compute_A_k(A, k):
# Symbolic description of the result
result, updates = aesara.scan(
fn=lambda prior_result, A: prior_result * A,
outputs_info=at.ones_like(A),
......@@ -204,18 +203,13 @@ def test_debugprint_nested_scans():
return A_k
# Generate the components of the polynomial
components, updates = aesara.scan(
fn=lambda coefficient, power, some_A, some_k: coefficient
* (compute_A_k(some_A, some_k) ** power),
fn=lambda c, power, some_A, some_k: c * (compute_A_k(some_A, some_k) ** power),
outputs_info=None,
sequences=[coefficients, at.arange(max_coefficients_supported)],
sequences=[c, at.arange(n)],
non_sequences=[A, k],
)
# Sum them up
polynomial = components.sum()
final_result = polynomial
final_result = components.sum()
output_str = debugprint(final_result, file="str", print_op_info=True)
lines = output_str.split("\n")
......@@ -225,8 +219,8 @@ def test_debugprint_nested_scans():
|Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0)
| |Subtensor{int64} [id D]
| | |Shape [id E]
| | | |Subtensor{int64::} [id F] 'coefficients[0:]'
| | | |coefficients [id G]
| | | |Subtensor{int64::} [id F] 'c[0:]'
| | | |c [id G]
| | | |ScalarConstant{0} [id H]
| | |ScalarConstant{0} [id I]
| |Subtensor{int64} [id J]
......@@ -239,7 +233,7 @@ def test_debugprint_nested_scans():
| | |ScalarConstant{0} [id Q]
| |ScalarConstant{0} [id R]
|Subtensor{:int64:} [id S] (outer_in_seqs-0)
| |Subtensor{int64::} [id F] 'coefficients[0:]'
| |Subtensor{int64::} [id F] 'c[0:]'
| |ScalarFromTensor [id T]
| |Elemwise{scalar_minimum,no_inplace} [id C]
|Subtensor{:int64:} [id U] (outer_in_seqs-1)
......@@ -295,6 +289,97 @@ def test_debugprint_nested_scans():
for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip()
fg = FunctionGraph([c, k, A], [final_result])
output_str = debugprint(
fg, file="str", print_op_info=True, print_fgraph_inputs=True
)
lines = output_str.split("\n")
expected_output = """-c [id A]
-k [id B]
-A [id C]
Sum{acc_dtype=float64} [id D] 13
|for{cpu,scan_fn} [id E] 12 (outer_out_nit_sot-0)
|Elemwise{scalar_minimum,no_inplace} [id F] 7 (outer_in_nit_sot-0)
| |Subtensor{int64} [id G] 6
| | |Shape [id H] 5
| | | |Subtensor{int64::} [id I] 'c[0:]' 4
| | | |c [id A]
| | | |ScalarConstant{0} [id J]
| | |ScalarConstant{0} [id K]
| |Subtensor{int64} [id L] 3
| |Shape [id M] 2
| | |Subtensor{int64::} [id N] 1
| | |ARange{dtype='int64'} [id O] 0
| | | |TensorConstant{0} [id P]
| | | |TensorConstant{10} [id Q]
| | | |TensorConstant{1} [id R]
| | |ScalarConstant{0} [id S]
| |ScalarConstant{0} [id T]
|Subtensor{:int64:} [id U] 11 (outer_in_seqs-0)
| |Subtensor{int64::} [id I] 'c[0:]' 4
| |ScalarFromTensor [id V] 10
| |Elemwise{scalar_minimum,no_inplace} [id F] 7
|Subtensor{:int64:} [id W] 9 (outer_in_seqs-1)
| |Subtensor{int64::} [id N] 1
| |ScalarFromTensor [id X] 8
| |Elemwise{scalar_minimum,no_inplace} [id F] 7
|Elemwise{scalar_minimum,no_inplace} [id F] 7 (outer_in_nit_sot-0)
|A [id C] (outer_in_non_seqs-0)
|k [id B] (outer_in_non_seqs-1)
Inner graphs:
for{cpu,scan_fn} [id E] (outer_out_nit_sot-0)
-*0-<TensorType(float64, ())> [id Y] -> [id U] (inner_in_seqs-0)
-*1-<TensorType(int64, ())> [id Z] -> [id W] (inner_in_seqs-1)
-*2-<TensorType(float64, (None,))> [id BA] -> [id C] (inner_in_non_seqs-0)
-*3-<TensorType(int32, ())> [id BB] -> [id B] (inner_in_non_seqs-1)
>Elemwise{mul,no_inplace} [id BC] (inner_out_nit_sot-0)
> |InplaceDimShuffle{x} [id BD]
> | |*0-<TensorType(float64, ())> [id Y] (inner_in_seqs-0)
> |Elemwise{pow,no_inplace} [id BE]
> |Subtensor{int64} [id BF]
> | |Subtensor{int64::} [id BG]
> | | |for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0)
> | | | |*3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1) (n_steps)
> | | | |IncSubtensor{Set;:int64:} [id BI] (outer_in_sit_sot-0)
> | | | | |AllocEmpty{dtype='float64'} [id BJ]
> | | | | | |Elemwise{add,no_inplace} [id BK]
> | | | | | | |*3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1)
> | | | | | | |Subtensor{int64} [id BL]
> | | | | | | |Shape [id BM]
> | | | | | | | |Rebroadcast{(0, False)} [id BN]
> | | | | | | | |InplaceDimShuffle{x,0} [id BO]
> | | | | | | | |Elemwise{second,no_inplace} [id BP]
> | | | | | | | |*2-<TensorType(float64, (None,))> [id BA] (inner_in_non_seqs-0)
> | | | | | | | |InplaceDimShuffle{x} [id BQ]
> | | | | | | | |TensorConstant{1.0} [id BR]
> | | | | | | |ScalarConstant{0} [id BS]
> | | | | | |Subtensor{int64} [id BT]
> | | | | | |Shape [id BU]
> | | | | | | |Rebroadcast{(0, False)} [id BN]
> | | | | | |ScalarConstant{1} [id BV]
> | | | | |Rebroadcast{(0, False)} [id BN]
> | | | | |ScalarFromTensor [id BW]
> | | | | |Subtensor{int64} [id BL]
> | | | |*2-<TensorType(float64, (None,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
> | | |ScalarConstant{1} [id BX]
> | |ScalarConstant{-1} [id BY]
> |InplaceDimShuffle{x} [id BZ]
> |*1-<TensorType(int64, ())> [id Z] (inner_in_seqs-1)
for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0)
-*0-<TensorType(float64, (None,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
-*1-<TensorType(float64, (None,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
>Elemwise{mul,no_inplace} [id CC] (inner_out_sit_sot-0)
> |*0-<TensorType(float64, (None,))> [id CA] (inner_in_sit_sot-0)
> |*1-<TensorType(float64, (None,))> [id CB] (inner_in_non_seqs-0)"""
for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip()
@config.change_flags(floatX="float64")
def test_debugprint_mitsot():
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论