提交 b4c1bbd5 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add Scan input/output type annotations to debugprint output

上级 c7b416e1
差异被折叠。
...@@ -76,6 +76,7 @@ from aesara.graph.utils import MissingInputError ...@@ -76,6 +76,7 @@ from aesara.graph.utils import MissingInputError
from aesara.link.c.basic import CLinker from aesara.link.c.basic import CLinker
from aesara.link.c.exceptions import MissingGXX from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import raise_with_op from aesara.link.utils import raise_with_op
from aesara.printing import op_debug_information
from aesara.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new from aesara.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new
from aesara.tensor.basic import as_tensor_variable from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.math import minimum from aesara.tensor.math import minimum
...@@ -3303,3 +3304,44 @@ def profile_printer( ...@@ -3303,3 +3304,44 @@ def profile_printer(
), ),
file=file, file=file,
) )
@op_debug_information.register(Scan) # type: ignore
def _op_debug_information_Scan(op, node):
from typing import Sequence
from aesara.scan.utils import ScanArgs
extra_information = {}
inner_fn = getattr(op, "_fn", None)
if inner_fn:
inner_inputs = inner_fn.maker.fgraph.inputs
inner_outputs = inner_fn.maker.fgraph.outputs
else:
inner_inputs = op.inputs
inner_outputs = op.outputs
scan_args = ScanArgs(
node.inputs,
node.outputs,
inner_inputs,
inner_outputs,
node.op.info,
node.op.as_while,
clone=False,
)
for field_name in scan_args.field_names:
field_vars = getattr(scan_args, field_name)
if isinstance(field_vars, Sequence):
for i, var in enumerate(field_vars):
if isinstance(var, Sequence):
for j, sub_var in enumerate(var):
extra_information[sub_var] = f"{field_name}-{i}-{j}"
else:
extra_information[var] = f"{field_name}-{i}"
else:
extra_information[field_vars] = field_name
return {node: extra_information}
...@@ -544,17 +544,17 @@ def test_debugprint(): ...@@ -544,17 +544,17 @@ def test_debugprint():
output_str = debugprint(out, file="str") output_str = debugprint(out, file="str")
lines = output_str.split("\n") lines = output_str.split("\n")
exp_res = """OpFromGraph{inline=False} [id A] '' exp_res = """OpFromGraph{inline=False} [id A]
|x [id B] |x [id B]
|y [id C] |y [id C]
|z [id D] |z [id D]
Inner graphs: Inner graphs:
OpFromGraph{inline=False} [id A] '' OpFromGraph{inline=False} [id A]
>Elemwise{add,no_inplace} [id E] '' >Elemwise{add,no_inplace} [id E]
> |x [id F] > |x [id F]
> |Elemwise{mul,no_inplace} [id G] '' > |Elemwise{mul,no_inplace} [id G]
> |y [id H] > |y [id H]
> |z [id I] > |z [id I]
""" """
......
...@@ -141,11 +141,11 @@ def test_debugprint(): ...@@ -141,11 +141,11 @@ def test_debugprint():
reference = ( reference = (
"\n".join( "\n".join(
[ [
"Elemwise{add,no_inplace} [id 0] '' ", "Elemwise{add,no_inplace} [id 0]",
" |Elemwise{add,no_inplace} [id 1] 'C' ", " |Elemwise{add,no_inplace} [id 1] 'C'",
" | |A [id 2]", " | |A [id 2]",
" | |B [id 3]", " | |B [id 3]",
" |Elemwise{add,no_inplace} [id 4] '' ", " |Elemwise{add,no_inplace} [id 4]",
" |D [id 5]", " |D [id 5]",
" |E [id 6]", " |E [id 6]",
] ]
...@@ -167,11 +167,11 @@ def test_debugprint(): ...@@ -167,11 +167,11 @@ def test_debugprint():
reference = ( reference = (
"\n".join( "\n".join(
[ [
"Elemwise{add,no_inplace} [id A] '' ", "Elemwise{add,no_inplace} [id A]",
" |Elemwise{add,no_inplace} [id B] 'C' ", " |Elemwise{add,no_inplace} [id B] 'C'",
" | |A [id C]", " | |A [id C]",
" | |B [id D]", " | |B [id D]",
" |Elemwise{add,no_inplace} [id E] '' ", " |Elemwise{add,no_inplace} [id E]",
" |D [id F]", " |D [id F]",
" |E [id G]", " |E [id G]",
] ]
...@@ -193,9 +193,9 @@ def test_debugprint(): ...@@ -193,9 +193,9 @@ def test_debugprint():
reference = ( reference = (
"\n".join( "\n".join(
[ [
"Elemwise{add,no_inplace} [id A] '' ", "Elemwise{add,no_inplace} [id A]",
" |Elemwise{add,no_inplace} [id B] 'C' ", " |Elemwise{add,no_inplace} [id B] 'C'",
" |Elemwise{add,no_inplace} [id C] '' ", " |Elemwise{add,no_inplace} [id C]",
" |D [id D]", " |D [id D]",
" |E [id E]", " |E [id E]",
] ]
...@@ -217,13 +217,13 @@ def test_debugprint(): ...@@ -217,13 +217,13 @@ def test_debugprint():
reference = ( reference = (
"\n".join( "\n".join(
[ [
"Elemwise{add,no_inplace} '' ", "Elemwise{add,no_inplace}",
" |Elemwise{add,no_inplace} 'C' ", " |Elemwise{add,no_inplace} 'C'",
" | |A ", " | |A",
" | |B ", " | |B",
" |Elemwise{add,no_inplace} '' ", " |Elemwise{add,no_inplace}",
" |D ", " |D",
" |E ", " |E",
] ]
) )
+ "\n" + "\n"
...@@ -238,15 +238,14 @@ def test_debugprint(): ...@@ -238,15 +238,14 @@ def test_debugprint():
s = StringIO() s = StringIO()
debugprint(g, file=s, ids="", print_storage=True) debugprint(g, file=s, ids="", print_storage=True)
s = s.getvalue() s = s.getvalue()
# The additional white space are needed!
reference = ( reference = (
"\n".join( "\n".join(
[ [
"Elemwise{add,no_inplace} '' 0 [None]", "Elemwise{add,no_inplace} 0 [None]",
" |A [None]", " |A [None]",
" |B [None]", " |B [None]",
" |D [None]", " |D [None]",
" |E [None]", " |E [None]",
] ]
) )
+ "\n" + "\n"
...@@ -269,8 +268,8 @@ def test_debugprint_ids(): ...@@ -269,8 +268,8 @@ def test_debugprint_ids():
debugprint(e_at, ids="auto", file=s) debugprint(e_at, ids="auto", file=s)
s = s.getvalue() s = s.getvalue()
exp_res = f"""Elemwise{{add,no_inplace}} [id {e_at.auto_name}] '' exp_res = f"""Elemwise{{add,no_inplace}} [id {e_at.auto_name}]
|dot [id {d_at.auto_name}] '' |dot [id {d_at.auto_name}]
| |<TensorType(float64, (None, None))> [id {b_at.auto_name}] | |<TensorType(float64, (None, None))> [id {b_at.auto_name}]
| |<TensorType(float64, (None,))> [id {a_at.auto_name}] | |<TensorType(float64, (None,))> [id {a_at.auto_name}]
|<TensorType(float64, (None,))> [id {a_at.auto_name}] |<TensorType(float64, (None,))> [id {a_at.auto_name}]
...@@ -306,13 +305,13 @@ def test_debugprint_inner_graph(): ...@@ -306,13 +305,13 @@ def test_debugprint_inner_graph():
output_str = debugprint(out, file="str") output_str = debugprint(out, file="str")
lines = output_str.split("\n") lines = output_str.split("\n")
exp_res = """MyInnerGraphOp [id A] '' exp_res = """MyInnerGraphOp [id A]
|3 [id B] |3 [id B]
|4 [id C] |4 [id C]
Inner graphs: Inner graphs:
MyInnerGraphOp [id A] '' MyInnerGraphOp [id A]
>op2 [id D] 'igo1' >op2 [id D] 'igo1'
> |4 [id E] > |4 [id E]
> |5 [id F] > |5 [id F]
...@@ -330,17 +329,17 @@ MyInnerGraphOp [id A] '' ...@@ -330,17 +329,17 @@ MyInnerGraphOp [id A] ''
output_str = debugprint(out_2, file="str") output_str = debugprint(out_2, file="str")
lines = output_str.split("\n") lines = output_str.split("\n")
exp_res = """MyInnerGraphOp [id A] '' exp_res = """MyInnerGraphOp [id A]
|5 [id B] |5 [id B]
Inner graphs: Inner graphs:
MyInnerGraphOp [id A] '' MyInnerGraphOp [id A]
>MyInnerGraphOp [id C] '' >MyInnerGraphOp [id C]
> |3 [id D] > |3 [id D]
> |4 [id E] > |4 [id E]
MyInnerGraphOp [id C] '' MyInnerGraphOp [id C]
>op2 [id F] 'igo1' >op2 [id F] 'igo1'
> |4 [id G] > |4 [id G]
> |5 [id H] > |5 [id H]
...@@ -371,13 +370,13 @@ def test_get_var_by_id(): ...@@ -371,13 +370,13 @@ def test_get_var_by_id():
# op1 [id A] 'o1' # op1 [id A] 'o1'
# |1 [id B] # |1 [id B]
# |2 [id C] # |2 [id C]
# MyInnerGraphOp [id D] '' # MyInnerGraphOp [id D]
# |3 [id E] # |3 [id E]
# |op1 [id A] 'o1' # |op1 [id A] 'o1'
# #
# Inner graphs: # Inner graphs:
# #
# MyInnerGraphOp [id D] '' # MyInnerGraphOp [id D]
# >op2 [id F] 'igo1' # >op2 [id F] 'igo1'
# > |4 [id G] # > |4 [id G]
# > |5 [id H] # > |5 [id H]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论