提交 29153adc authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Create a function that gets variables in a graph by name

上级 d4696e6b
......@@ -1648,3 +1648,40 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
return False
return True
def get_var_by_name(
graphs: Iterable[Variable], target_var_id: str, ids: str = "CHAR"
) -> Tuple[Variable]:
r"""Get variables in a graph using their names.
Parameters
----------
graphs:
The graph, or graphs, to search.
target_var_id:
The name to match against either ``Variable.name`` or
``Variable.auto_name``.
Returns
-------
A ``tuple`` containing all the `Variable`\s that match `target_var_id`.
"""
from aesara.graph.op import HasInnerGraph
def expand(r):
if r.owner:
res = r.owner.inputs
if isinstance(r.owner.op, HasInnerGraph):
res.extend(r.owner.op.inner_outputs)
return res
results = ()
for var in walk(graphs, expand, False):
if target_var_id == var.name or target_var_id == var.auto_name:
results += (var,)
return results
......@@ -15,6 +15,7 @@ from aesara.graph.basic import (
clone,
equal_computations,
general_toposort,
get_var_by_name,
graph_inputs,
io_toposort,
is_in_ancestors,
......@@ -23,7 +24,7 @@ from aesara.graph.basic import (
vars_between,
walk,
)
from aesara.graph.op import Op
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.type import Type
from aesara.tensor.math import max_and_argmax
from aesara.tensor.type import TensorType, iscalars, matrix, scalars
......@@ -70,6 +71,36 @@ class MyOp(Op):
MyOp = MyOp()
class MyInnerGraphOp(Op, HasInnerGraph):
__props__ = ()
def __init__(self, inner_inputs, inner_outputs):
self._inner_inputs = inner_inputs
self._inner_outputs = inner_outputs
def make_node(self, *inputs):
for input in inputs:
assert isinstance(input, Variable)
assert isinstance(input.type, MyType)
outputs = [MyVariable(sum(input.type.thingy for input in inputs))]
return Apply(self, list(inputs), outputs)
def perform(self, *args, **kwargs):
raise NotImplementedError("No Python implementation available.")
@property
def fn(self):
raise NotImplementedError("No Python implementation available.")
@property
def inner_inputs(self):
return self._inner_inputs
@property
def inner_outputs(self):
return self._inner_outputs
class X:
def leaf_formatter(self, leaf):
return str(leaf.type)
......@@ -472,3 +503,37 @@ def test_io_connection_pattern():
@pytest.mark.xfail(reason="Not implemented")
def test_view_roots():
raise AssertionError()
def test_get_var_by_name():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2)
o1.name = "o1"
# Inner graph
igo_in_1 = MyVariable(4)
igo_in_2 = MyVariable(5)
igo_out_1 = MyOp(igo_in_1, igo_in_2)
igo_out_1.name = "igo1"
igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1])
o2 = igo(r3, o1)
o2.name = "o1"
res = get_var_by_name([o1, o2], "blah")
assert res == ()
res = get_var_by_name([o1, o2], "o1")
assert set(res) == {o1, o2}
(res,) = get_var_by_name([o1, o2], o1.auto_name)
assert res == o1
(res,) = get_var_by_name([o1, o2], "igo1")
assert res == igo_out_1
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论