提交 4fd52e4e authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Create opt_util `get_clients_at_depth`

上级 7db04c9c
import copy import copy
from typing import Sequence, Union from typing import Generator, Sequence, Union
import aesara import aesara
from aesara.graph.basic import Variable, equal_computations, graph_inputs, vars_between from aesara.graph.basic import (
Apply,
Variable,
equal_computations,
graph_inputs,
vars_between,
)
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import OptimizationQuery
...@@ -198,3 +204,17 @@ def is_same_graph(var1, var2, givens=None): ...@@ -198,3 +204,17 @@ def is_same_graph(var1, var2, givens=None):
rval2 = equal_computations(xs=[var1], ys=[var2], in_xs=in_xs, in_ys=in_ys) rval2 = equal_computations(xs=[var1], ys=[var2], in_xs=in_xs, in_ys=in_ys)
assert rval2 == rval1 assert rval2 == rval1
return rval1 return rval1
def get_clients_at_depth(
fgraph: FunctionGraph, node: Apply, depth: int
) -> Generator[Apply, None, None]:
"""Yields node clients at given depth."""
for node in node.outputs:
if depth > 0:
for out_node, _ in fgraph.clients[node]:
if out_node == "output":
continue
yield from get_clients_at_depth(fgraph, out_node, depth - 1)
else:
yield node.owner
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论