提交 0f9f93c8 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a function that gets variables in a graph based on debugprint IDs

上级 724561b2
...@@ -11,7 +11,7 @@ import sys ...@@ -11,7 +11,7 @@ import sys
from copy import copy from copy import copy
from functools import reduce from functools import reduce
from io import IOBase, StringIO from io import IOBase, StringIO
from typing import Dict, List, Optional, Union from typing import Dict, Iterable, List, Optional, Union
import numpy as np import numpy as np
...@@ -552,10 +552,11 @@ def _debugprint( ...@@ -552,10 +552,11 @@ def _debugprint(
id_str = get_id_str(r) id_str = get_id_str(r)
outer_r = inner_to_outer_inputs[r] outer_r = inner_to_outer_inputs[r]
if hasattr(outer_r.owner, "op"): if outer_r.owner:
outer_id_str = get_id_str(outer_r.owner) outer_id_str = get_id_str(outer_r.owner)
else: else:
outer_id_str = get_id_str(outer_r) outer_id_str = get_id_str(outer_r)
print( print(
f"{prefix}{r} {id_str}{type_str} -> {outer_id_str}", f"{prefix}{r} {id_str}{type_str} -> {outer_id_str}",
file=file, file=file,
...@@ -1643,3 +1644,38 @@ def hex_digest(x): ...@@ -1643,3 +1644,38 @@ def hex_digest(x):
rval = rval + "|strides=[" + ",".join(str(stride) for stride in x.strides) + "]" rval = rval + "|strides=[" + ",".join(str(stride) for stride in x.strides) + "]"
rval = rval + "|shape=[" + ",".join(str(s) for s in x.shape) + "]" rval = rval + "|shape=[" + ",".join(str(s) for s in x.shape) + "]"
return rval return rval
def get_node_by_id(
graphs: Iterable[Variable], target_var_id: str, ids: str = "CHAR"
) -> Optional[Union[Variable, Apply]]:
r"""Get `Apply` nodes or `Variable`\s in a graph using their `debugprint` IDs.
Parameters
----------
graphs:
The graph, or graphs, to search.
target_var_id:
The name to search for.
ids:
The ID scheme to use (see `debugprint.`).
Returns
-------
The `Apply`/`Variable` matching `target_var_id` or ``None``.
"""
from aesara.printing import debugprint
if isinstance(graphs, Variable):
graphs = (graphs,)
used_ids = dict()
_ = debugprint(graphs, file="str", used_ids=used_ids, ids=ids)
id_to_node = {v: k for k, v in used_ids.items()}
id_str = f"[id {target_var_id}]"
return id_to_node.get(id_str, None)
...@@ -9,6 +9,7 @@ import pytest ...@@ -9,6 +9,7 @@ import pytest
import aesara import aesara
from aesara.printing import ( from aesara.printing import (
debugprint, debugprint,
get_node_by_id,
min_informative_str, min_informative_str,
pp, pp,
pydot_imported, pydot_imported,
...@@ -344,3 +345,48 @@ MyInnerGraphOp [id C] '' ...@@ -344,3 +345,48 @@ MyInnerGraphOp [id C] ''
for exp_line, res_line in zip(exp_res.split("\n"), lines): for exp_line, res_line in zip(exp_res.split("\n"), lines):
assert exp_line.strip() == res_line.strip() assert exp_line.strip() == res_line.strip()
def test_get_var_by_id():
r1, r2 = MyVariable("v1"), MyVariable("v2")
o1 = MyOp("op1")(r1, r2)
o1.name = "o1"
# Inner graph
igo_in_1 = MyVariable("v4")
igo_in_2 = MyVariable("v5")
igo_out_1 = MyOp("op2")(igo_in_1, igo_in_2)
igo_out_1.name = "igo1"
igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1])
r3 = MyVariable("v3")
o2 = igo(r3, o1)
# import aesara; aesara.dprint([o1, o2])
# op1 [id A] 'o1'
# |1 [id B]
# |2 [id C]
# MyInnerGraphOp [id D] ''
# |3 [id E]
# |op1 [id A] 'o1'
#
# Inner graphs:
#
# MyInnerGraphOp [id D] ''
# >op2 [id F] 'igo1'
# > |4 [id G]
# > |5 [id H]
res = get_node_by_id(o1, "blah")
assert res is None
res = get_node_by_id([o1, o2], "C")
assert res == r2
res = get_node_by_id([o1, o2], "F")
assert res == igo_out_1.owner
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论