提交 2fee841e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add an option to print graph inputs in debugprint

上级 6965fdc1
...@@ -118,6 +118,7 @@ def debugprint( ...@@ -118,6 +118,7 @@ def debugprint(
print_op_info: bool = False, print_op_info: bool = False,
print_destroy_map: bool = False, print_destroy_map: bool = False,
print_view_map: bool = False, print_view_map: bool = False,
print_fgraph_inputs: bool = False,
) -> Union[str, IOBase]: ) -> Union[str, IOBase]:
r"""Print a computation graph as text to stdout or a file. r"""Print a computation graph as text to stdout or a file.
...@@ -175,6 +176,8 @@ def debugprint( ...@@ -175,6 +176,8 @@ def debugprint(
Whether to print the `destroy_map`\s of printed objects Whether to print the `destroy_map`\s of printed objects
print_view_map print_view_map
Whether to print the `view_map`\s of printed objects Whether to print the `view_map`\s of printed objects
print_fgraph_inputs
Print the inputs of `FunctionGraph`\s.
Returns Returns
------- -------
...@@ -197,7 +200,8 @@ def debugprint( ...@@ -197,7 +200,8 @@ def debugprint(
if used_ids is None: if used_ids is None:
used_ids = dict() used_ids = dict()
results_to_print = [] inputs_to_print = []
outputs_to_print = []
profile_list: List[Optional[Any]] = [] profile_list: List[Optional[Any]] = []
order: List[Optional[List[Apply]]] = [] # Toposort order: List[Optional[List[Apply]]] = [] # Toposort
smap: List[Optional[StorageMapType]] = [] # storage_map smap: List[Optional[StorageMapType]] = [] # storage_map
...@@ -209,17 +213,19 @@ def debugprint( ...@@ -209,17 +213,19 @@ def debugprint(
for obj in lobj: for obj in lobj:
if isinstance(obj, Variable): if isinstance(obj, Variable):
results_to_print.append(obj) outputs_to_print.append(obj)
profile_list.append(None) profile_list.append(None)
smap.append(None) smap.append(None)
order.append(None) order.append(None)
elif isinstance(obj, Apply): elif isinstance(obj, Apply):
results_to_print.extend(obj.outputs) outputs_to_print.extend(obj.outputs)
profile_list.extend([None for item in obj.outputs]) profile_list.extend([None for item in obj.outputs])
smap.extend([None for item in obj.outputs]) smap.extend([None for item in obj.outputs])
order.extend([None for item in obj.outputs]) order.extend([None for item in obj.outputs])
elif isinstance(obj, Function): elif isinstance(obj, Function):
results_to_print.extend(obj.maker.fgraph.outputs) if print_fgraph_inputs:
inputs_to_print.extend(obj.maker.fgraph.inputs)
outputs_to_print.extend(obj.maker.fgraph.outputs)
profile_list.extend([obj.profile for item in obj.maker.fgraph.outputs]) profile_list.extend([obj.profile for item in obj.maker.fgraph.outputs])
if print_storage: if print_storage:
smap.extend([obj.vm.storage_map for item in obj.maker.fgraph.outputs]) smap.extend([obj.vm.storage_map for item in obj.maker.fgraph.outputs])
...@@ -228,7 +234,9 @@ def debugprint( ...@@ -228,7 +234,9 @@ def debugprint(
topo = obj.maker.fgraph.toposort() topo = obj.maker.fgraph.toposort()
order.extend([topo for item in obj.maker.fgraph.outputs]) order.extend([topo for item in obj.maker.fgraph.outputs])
elif isinstance(obj, FunctionGraph): elif isinstance(obj, FunctionGraph):
results_to_print.extend(obj.outputs) if print_fgraph_inputs:
inputs_to_print.extend(obj.inputs)
outputs_to_print.extend(obj.outputs)
profile_list.extend([getattr(obj, "profile", None) for item in obj.outputs]) profile_list.extend([getattr(obj, "profile", None) for item in obj.outputs])
smap.extend([getattr(obj, "storage_map", None) for item in obj.outputs]) smap.extend([getattr(obj, "storage_map", None) for item in obj.outputs])
topo = obj.toposort() topo = obj.toposort()
...@@ -236,7 +244,7 @@ def debugprint( ...@@ -236,7 +244,7 @@ def debugprint(
elif isinstance(obj, (int, float, np.ndarray)): elif isinstance(obj, (int, float, np.ndarray)):
print(obj, file=_file) print(obj, file=_file)
elif isinstance(obj, (In, Out)): elif isinstance(obj, (In, Out)):
results_to_print.append(obj.variable) outputs_to_print.append(obj.variable)
profile_list.append(None) profile_list.append(None)
smap.append(None) smap.append(None)
order.append(None) order.append(None)
...@@ -268,7 +276,26 @@ N.B.: ...@@ -268,7 +276,26 @@ N.B.:
op_information = {} op_information = {}
for r, p, s, o in zip(results_to_print, profile_list, smap, order): for r in inputs_to_print:
_debugprint(
r,
prefix="-",
depth=depth,
done=done,
print_type=print_type,
file=_file,
ids=ids,
inner_graph_ops=inner_graph_ops,
stop_on_name=stop_on_name,
used_ids=used_ids,
op_information=op_information,
parent_node=r.owner,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
)
for r, p, s, o in zip(outputs_to_print, profile_list, smap, order):
if hasattr(r.owner, "op"): if hasattr(r.owner, "op"):
if isinstance(r.owner.op, HasInnerGraph) and r not in inner_graph_ops: if isinstance(r.owner.op, HasInnerGraph) and r not in inner_graph_ops:
...@@ -352,16 +379,39 @@ N.B.: ...@@ -352,16 +379,39 @@ N.B.:
print_view_map=print_view_map, print_view_map=print_view_map,
) )
for idx, i in enumerate(inner_outputs): if print_fgraph_inputs:
for inp in inner_inputs:
_debugprint(
r=inp,
prefix="-",
depth=depth,
done=done,
print_type=print_type,
file=_file,
ids=ids,
stop_on_name=stop_on_name,
inner_graph_ops=inner_graph_ops,
inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids,
op_information=op_information,
parent_node=s.owner,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
inner_graph_node=s.owner,
)
inner_to_outer_inputs = None
for out in inner_outputs:
if ( if (
isinstance(getattr(i.owner, "op", None), HasInnerGraph) isinstance(getattr(out.owner, "op", None), HasInnerGraph)
and i not in inner_graph_ops and out not in inner_graph_ops
): ):
inner_graph_ops.append(i) inner_graph_ops.append(out)
_debugprint( _debugprint(
r=i, r=out,
prefix=new_prefix, prefix=new_prefix,
depth=depth, depth=depth,
done=done, done=done,
...@@ -655,11 +705,10 @@ def _debugprint( ...@@ -655,11 +705,10 @@ def _debugprint(
var_output = f"{var_output} -> {outer_id_str}" var_output = f"{var_output} -> {outer_id_str}"
node_info = op_information.get(inner_graph_node) # TODO: This entire approach will only print `Op` info for two levels
if node_info and r in node_info: # of nesting.
var_output = f"{var_output} ({node_info[r]})" for node in dict.fromkeys([inner_graph_node, parent_node, r.owner]):
node_info = op_information.get(node)
node_info = op_information.get(parent_node) or op_information.get(r.owner)
if node_info and r in node_info: if node_info and r in node_info:
var_output = f"{var_output} ({node_info[r]})" var_output = f"{var_output} ({node_info[r]})"
......
...@@ -4,6 +4,7 @@ import pytest ...@@ -4,6 +4,7 @@ import pytest
import aesara import aesara
import aesara.tensor as at import aesara.tensor as at
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.printing import debugprint, pydot_imported, pydotprint from aesara.printing import debugprint, pydot_imported, pydotprint
from aesara.tensor.type import dvector, iscalar, scalar, vector from aesara.tensor.type import dvector, iscalar, scalar, vector
...@@ -184,15 +185,13 @@ def test_debugprint_nitsot(): ...@@ -184,15 +185,13 @@ def test_debugprint_nitsot():
@config.change_flags(floatX="float64") @config.change_flags(floatX="float64")
def test_debugprint_nested_scans(): def test_debugprint_nested_scans():
coefficients = dvector("coefficients") c = dvector("c")
max_coefficients_supported = 10 n = 10
k = iscalar("k") k = iscalar("k")
A = dvector("A") A = dvector("A")
# compute A**k
def compute_A_k(A, k): def compute_A_k(A, k):
# Symbolic description of the result
result, updates = aesara.scan( result, updates = aesara.scan(
fn=lambda prior_result, A: prior_result * A, fn=lambda prior_result, A: prior_result * A,
outputs_info=at.ones_like(A), outputs_info=at.ones_like(A),
...@@ -204,18 +203,13 @@ def test_debugprint_nested_scans(): ...@@ -204,18 +203,13 @@ def test_debugprint_nested_scans():
return A_k return A_k
# Generate the components of the polynomial
components, updates = aesara.scan( components, updates = aesara.scan(
fn=lambda coefficient, power, some_A, some_k: coefficient fn=lambda c, power, some_A, some_k: c * (compute_A_k(some_A, some_k) ** power),
* (compute_A_k(some_A, some_k) ** power),
outputs_info=None, outputs_info=None,
sequences=[coefficients, at.arange(max_coefficients_supported)], sequences=[c, at.arange(n)],
non_sequences=[A, k], non_sequences=[A, k],
) )
# Sum them up final_result = components.sum()
polynomial = components.sum()
final_result = polynomial
output_str = debugprint(final_result, file="str", print_op_info=True) output_str = debugprint(final_result, file="str", print_op_info=True)
lines = output_str.split("\n") lines = output_str.split("\n")
...@@ -225,8 +219,8 @@ def test_debugprint_nested_scans(): ...@@ -225,8 +219,8 @@ def test_debugprint_nested_scans():
|Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0) |Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0)
| |Subtensor{int64} [id D] | |Subtensor{int64} [id D]
| | |Shape [id E] | | |Shape [id E]
| | | |Subtensor{int64::} [id F] 'coefficients[0:]' | | | |Subtensor{int64::} [id F] 'c[0:]'
| | | |coefficients [id G] | | | |c [id G]
| | | |ScalarConstant{0} [id H] | | | |ScalarConstant{0} [id H]
| | |ScalarConstant{0} [id I] | | |ScalarConstant{0} [id I]
| |Subtensor{int64} [id J] | |Subtensor{int64} [id J]
...@@ -239,7 +233,7 @@ def test_debugprint_nested_scans(): ...@@ -239,7 +233,7 @@ def test_debugprint_nested_scans():
| | |ScalarConstant{0} [id Q] | | |ScalarConstant{0} [id Q]
| |ScalarConstant{0} [id R] | |ScalarConstant{0} [id R]
|Subtensor{:int64:} [id S] (outer_in_seqs-0) |Subtensor{:int64:} [id S] (outer_in_seqs-0)
| |Subtensor{int64::} [id F] 'coefficients[0:]' | |Subtensor{int64::} [id F] 'c[0:]'
| |ScalarFromTensor [id T] | |ScalarFromTensor [id T]
| |Elemwise{scalar_minimum,no_inplace} [id C] | |Elemwise{scalar_minimum,no_inplace} [id C]
|Subtensor{:int64:} [id U] (outer_in_seqs-1) |Subtensor{:int64:} [id U] (outer_in_seqs-1)
...@@ -295,6 +289,97 @@ def test_debugprint_nested_scans(): ...@@ -295,6 +289,97 @@ def test_debugprint_nested_scans():
for truth, out in zip(expected_output.split("\n"), lines): for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
fg = FunctionGraph([c, k, A], [final_result])
output_str = debugprint(
fg, file="str", print_op_info=True, print_fgraph_inputs=True
)
lines = output_str.split("\n")
expected_output = """-c [id A]
-k [id B]
-A [id C]
Sum{acc_dtype=float64} [id D] 13
|for{cpu,scan_fn} [id E] 12 (outer_out_nit_sot-0)
|Elemwise{scalar_minimum,no_inplace} [id F] 7 (outer_in_nit_sot-0)
| |Subtensor{int64} [id G] 6
| | |Shape [id H] 5
| | | |Subtensor{int64::} [id I] 'c[0:]' 4
| | | |c [id A]
| | | |ScalarConstant{0} [id J]
| | |ScalarConstant{0} [id K]
| |Subtensor{int64} [id L] 3
| |Shape [id M] 2
| | |Subtensor{int64::} [id N] 1
| | |ARange{dtype='int64'} [id O] 0
| | | |TensorConstant{0} [id P]
| | | |TensorConstant{10} [id Q]
| | | |TensorConstant{1} [id R]
| | |ScalarConstant{0} [id S]
| |ScalarConstant{0} [id T]
|Subtensor{:int64:} [id U] 11 (outer_in_seqs-0)
| |Subtensor{int64::} [id I] 'c[0:]' 4
| |ScalarFromTensor [id V] 10
| |Elemwise{scalar_minimum,no_inplace} [id F] 7
|Subtensor{:int64:} [id W] 9 (outer_in_seqs-1)
| |Subtensor{int64::} [id N] 1
| |ScalarFromTensor [id X] 8
| |Elemwise{scalar_minimum,no_inplace} [id F] 7
|Elemwise{scalar_minimum,no_inplace} [id F] 7 (outer_in_nit_sot-0)
|A [id C] (outer_in_non_seqs-0)
|k [id B] (outer_in_non_seqs-1)
Inner graphs:
for{cpu,scan_fn} [id E] (outer_out_nit_sot-0)
-*0-<TensorType(float64, ())> [id Y] -> [id U] (inner_in_seqs-0)
-*1-<TensorType(int64, ())> [id Z] -> [id W] (inner_in_seqs-1)
-*2-<TensorType(float64, (None,))> [id BA] -> [id C] (inner_in_non_seqs-0)
-*3-<TensorType(int32, ())> [id BB] -> [id B] (inner_in_non_seqs-1)
>Elemwise{mul,no_inplace} [id BC] (inner_out_nit_sot-0)
> |InplaceDimShuffle{x} [id BD]
> | |*0-<TensorType(float64, ())> [id Y] (inner_in_seqs-0)
> |Elemwise{pow,no_inplace} [id BE]
> |Subtensor{int64} [id BF]
> | |Subtensor{int64::} [id BG]
> | | |for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0)
> | | | |*3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1) (n_steps)
> | | | |IncSubtensor{Set;:int64:} [id BI] (outer_in_sit_sot-0)
> | | | | |AllocEmpty{dtype='float64'} [id BJ]
> | | | | | |Elemwise{add,no_inplace} [id BK]
> | | | | | | |*3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1)
> | | | | | | |Subtensor{int64} [id BL]
> | | | | | | |Shape [id BM]
> | | | | | | | |Rebroadcast{(0, False)} [id BN]
> | | | | | | | |InplaceDimShuffle{x,0} [id BO]
> | | | | | | | |Elemwise{second,no_inplace} [id BP]
> | | | | | | | |*2-<TensorType(float64, (None,))> [id BA] (inner_in_non_seqs-0)
> | | | | | | | |InplaceDimShuffle{x} [id BQ]
> | | | | | | | |TensorConstant{1.0} [id BR]
> | | | | | | |ScalarConstant{0} [id BS]
> | | | | | |Subtensor{int64} [id BT]
> | | | | | |Shape [id BU]
> | | | | | | |Rebroadcast{(0, False)} [id BN]
> | | | | | |ScalarConstant{1} [id BV]
> | | | | |Rebroadcast{(0, False)} [id BN]
> | | | | |ScalarFromTensor [id BW]
> | | | | |Subtensor{int64} [id BL]
> | | | |*2-<TensorType(float64, (None,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
> | | |ScalarConstant{1} [id BX]
> | |ScalarConstant{-1} [id BY]
> |InplaceDimShuffle{x} [id BZ]
> |*1-<TensorType(int64, ())> [id Z] (inner_in_seqs-1)
for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0)
-*0-<TensorType(float64, (None,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
-*1-<TensorType(float64, (None,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
>Elemwise{mul,no_inplace} [id CC] (inner_out_sit_sot-0)
> |*0-<TensorType(float64, (None,))> [id CA] (inner_in_sit_sot-0)
> |*1-<TensorType(float64, (None,))> [id CB] (inner_in_non_seqs-0)"""
for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip()
@config.change_flags(floatX="float64") @config.change_flags(floatX="float64")
def test_debugprint_mitsot(): def test_debugprint_mitsot():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论