提交 0f9f93c8 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a function that gets variables in a graph based on debugprint IDs

上级 724561b2
......@@ -11,7 +11,7 @@ import sys
from copy import copy
from functools import reduce
from io import IOBase, StringIO
from typing import Dict, List, Optional, Union
from typing import Dict, Iterable, List, Optional, Union
import numpy as np
......@@ -552,10 +552,11 @@ def _debugprint(
id_str = get_id_str(r)
outer_r = inner_to_outer_inputs[r]
if hasattr(outer_r.owner, "op"):
if outer_r.owner:
outer_id_str = get_id_str(outer_r.owner)
else:
outer_id_str = get_id_str(outer_r)
print(
f"{prefix}{r} {id_str}{type_str} -> {outer_id_str}",
file=file,
......@@ -1643,3 +1644,38 @@ def hex_digest(x):
rval = rval + "|strides=[" + ",".join(str(stride) for stride in x.strides) + "]"
rval = rval + "|shape=[" + ",".join(str(s) for s in x.shape) + "]"
return rval
def get_node_by_id(
graphs: Iterable[Variable], target_var_id: str, ids: str = "CHAR"
) -> Optional[Union[Variable, Apply]]:
r"""Get `Apply` nodes or `Variable`\s in a graph using their `debugprint` IDs.
Parameters
----------
graphs:
The graph, or graphs, to search.
target_var_id:
The name to search for.
ids:
The ID scheme to use (see `debugprint.`).
Returns
-------
The `Apply`/`Variable` matching `target_var_id` or ``None``.
"""
from aesara.printing import debugprint
if isinstance(graphs, Variable):
graphs = (graphs,)
used_ids = dict()
_ = debugprint(graphs, file="str", used_ids=used_ids, ids=ids)
id_to_node = {v: k for k, v in used_ids.items()}
id_str = f"[id {target_var_id}]"
return id_to_node.get(id_str, None)
......@@ -9,6 +9,7 @@ import pytest
import aesara
from aesara.printing import (
debugprint,
get_node_by_id,
min_informative_str,
pp,
pydot_imported,
......@@ -344,3 +345,48 @@ MyInnerGraphOp [id C] ''
for exp_line, res_line in zip(exp_res.split("\n"), lines):
assert exp_line.strip() == res_line.strip()
def test_get_var_by_id():
r1, r2 = MyVariable("v1"), MyVariable("v2")
o1 = MyOp("op1")(r1, r2)
o1.name = "o1"
# Inner graph
igo_in_1 = MyVariable("v4")
igo_in_2 = MyVariable("v5")
igo_out_1 = MyOp("op2")(igo_in_1, igo_in_2)
igo_out_1.name = "igo1"
igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1])
r3 = MyVariable("v3")
o2 = igo(r3, o1)
# import aesara; aesara.dprint([o1, o2])
# op1 [id A] 'o1'
# |1 [id B]
# |2 [id C]
# MyInnerGraphOp [id D] ''
# |3 [id E]
# |op1 [id A] 'o1'
#
# Inner graphs:
#
# MyInnerGraphOp [id D] ''
# >op2 [id F] 'igo1'
# > |4 [id G]
# > |5 [id H]
res = get_node_by_id(o1, "blah")
assert res is None
res = get_node_by_id([o1, o2], "C")
assert res == r2
res = get_node_by_id([o1, o2], "F")
assert res == igo_out_1.owner
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论