提交 b4c1bbd5 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add Scan input/output type annotations to debugprint output

上级 c7b416e1
差异被折叠。
......@@ -76,6 +76,7 @@ from aesara.graph.utils import MissingInputError
from aesara.link.c.basic import CLinker
from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import raise_with_op
from aesara.printing import op_debug_information
from aesara.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.math import minimum
......@@ -3303,3 +3304,44 @@ def profile_printer(
),
file=file,
)
@op_debug_information.register(Scan) # type: ignore
def _op_debug_information_Scan(op, node):
from typing import Sequence
from aesara.scan.utils import ScanArgs
extra_information = {}
inner_fn = getattr(op, "_fn", None)
if inner_fn:
inner_inputs = inner_fn.maker.fgraph.inputs
inner_outputs = inner_fn.maker.fgraph.outputs
else:
inner_inputs = op.inputs
inner_outputs = op.outputs
scan_args = ScanArgs(
node.inputs,
node.outputs,
inner_inputs,
inner_outputs,
node.op.info,
node.op.as_while,
clone=False,
)
for field_name in scan_args.field_names:
field_vars = getattr(scan_args, field_name)
if isinstance(field_vars, Sequence):
for i, var in enumerate(field_vars):
if isinstance(var, Sequence):
for j, sub_var in enumerate(var):
extra_information[sub_var] = f"{field_name}-{i}-{j}"
else:
extra_information[var] = f"{field_name}-{i}"
else:
extra_information[field_vars] = field_name
return {node: extra_information}
......@@ -544,17 +544,17 @@ def test_debugprint():
output_str = debugprint(out, file="str")
lines = output_str.split("\n")
exp_res = """OpFromGraph{inline=False} [id A] ''
exp_res = """OpFromGraph{inline=False} [id A]
|x [id B]
|y [id C]
|z [id D]
Inner graphs:
OpFromGraph{inline=False} [id A] ''
>Elemwise{add,no_inplace} [id E] ''
OpFromGraph{inline=False} [id A]
>Elemwise{add,no_inplace} [id E]
> |x [id F]
> |Elemwise{mul,no_inplace} [id G] ''
> |Elemwise{mul,no_inplace} [id G]
> |y [id H]
> |z [id I]
"""
......
......@@ -141,11 +141,11 @@ def test_debugprint():
reference = (
"\n".join(
[
"Elemwise{add,no_inplace} [id 0] '' ",
" |Elemwise{add,no_inplace} [id 1] 'C' ",
"Elemwise{add,no_inplace} [id 0]",
" |Elemwise{add,no_inplace} [id 1] 'C'",
" | |A [id 2]",
" | |B [id 3]",
" |Elemwise{add,no_inplace} [id 4] '' ",
" |Elemwise{add,no_inplace} [id 4]",
" |D [id 5]",
" |E [id 6]",
]
......@@ -167,11 +167,11 @@ def test_debugprint():
reference = (
"\n".join(
[
"Elemwise{add,no_inplace} [id A] '' ",
" |Elemwise{add,no_inplace} [id B] 'C' ",
"Elemwise{add,no_inplace} [id A]",
" |Elemwise{add,no_inplace} [id B] 'C'",
" | |A [id C]",
" | |B [id D]",
" |Elemwise{add,no_inplace} [id E] '' ",
" |Elemwise{add,no_inplace} [id E]",
" |D [id F]",
" |E [id G]",
]
......@@ -193,9 +193,9 @@ def test_debugprint():
reference = (
"\n".join(
[
"Elemwise{add,no_inplace} [id A] '' ",
" |Elemwise{add,no_inplace} [id B] 'C' ",
" |Elemwise{add,no_inplace} [id C] '' ",
"Elemwise{add,no_inplace} [id A]",
" |Elemwise{add,no_inplace} [id B] 'C'",
" |Elemwise{add,no_inplace} [id C]",
" |D [id D]",
" |E [id E]",
]
......@@ -217,13 +217,13 @@ def test_debugprint():
reference = (
"\n".join(
[
"Elemwise{add,no_inplace} '' ",
" |Elemwise{add,no_inplace} 'C' ",
" | |A ",
" | |B ",
" |Elemwise{add,no_inplace} '' ",
" |D ",
" |E ",
"Elemwise{add,no_inplace}",
" |Elemwise{add,no_inplace} 'C'",
" | |A",
" | |B",
" |Elemwise{add,no_inplace}",
" |D",
" |E",
]
)
+ "\n"
......@@ -238,15 +238,14 @@ def test_debugprint():
s = StringIO()
debugprint(g, file=s, ids="", print_storage=True)
s = s.getvalue()
# The additional white space are needed!
reference = (
"\n".join(
[
"Elemwise{add,no_inplace} '' 0 [None]",
" |A [None]",
" |B [None]",
" |D [None]",
" |E [None]",
"Elemwise{add,no_inplace} 0 [None]",
" |A [None]",
" |B [None]",
" |D [None]",
" |E [None]",
]
)
+ "\n"
......@@ -269,8 +268,8 @@ def test_debugprint_ids():
debugprint(e_at, ids="auto", file=s)
s = s.getvalue()
exp_res = f"""Elemwise{{add,no_inplace}} [id {e_at.auto_name}] ''
|dot [id {d_at.auto_name}] ''
exp_res = f"""Elemwise{{add,no_inplace}} [id {e_at.auto_name}]
|dot [id {d_at.auto_name}]
| |<TensorType(float64, (None, None))> [id {b_at.auto_name}]
| |<TensorType(float64, (None,))> [id {a_at.auto_name}]
|<TensorType(float64, (None,))> [id {a_at.auto_name}]
......@@ -306,13 +305,13 @@ def test_debugprint_inner_graph():
output_str = debugprint(out, file="str")
lines = output_str.split("\n")
exp_res = """MyInnerGraphOp [id A] ''
exp_res = """MyInnerGraphOp [id A]
|3 [id B]
|4 [id C]
Inner graphs:
MyInnerGraphOp [id A] ''
MyInnerGraphOp [id A]
>op2 [id D] 'igo1'
> |4 [id E]
> |5 [id F]
......@@ -330,17 +329,17 @@ MyInnerGraphOp [id A] ''
output_str = debugprint(out_2, file="str")
lines = output_str.split("\n")
exp_res = """MyInnerGraphOp [id A] ''
exp_res = """MyInnerGraphOp [id A]
|5 [id B]
Inner graphs:
MyInnerGraphOp [id A] ''
>MyInnerGraphOp [id C] ''
MyInnerGraphOp [id A]
>MyInnerGraphOp [id C]
> |3 [id D]
> |4 [id E]
MyInnerGraphOp [id C] ''
MyInnerGraphOp [id C]
>op2 [id F] 'igo1'
> |4 [id G]
> |5 [id H]
......@@ -371,13 +370,13 @@ def test_get_var_by_id():
# op1 [id A] 'o1'
# |1 [id B]
# |2 [id C]
# MyInnerGraphOp [id D] ''
# MyInnerGraphOp [id D]
# |3 [id E]
# |op1 [id A] 'o1'
#
# Inner graphs:
#
# MyInnerGraphOp [id D] ''
# MyInnerGraphOp [id D]
# >op2 [id F] 'igo1'
# > |4 [id G]
# > |5 [id H]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论