提交 71618f60 作者: ricardoV94 提交者: Ricardo Vieira

Do not recompute toposort in every iteration of FusionOptimizer

Recomputing the toposort is unnecessary because the search never expands into the newly created nodes.
上级 9ef575b7
......@@ -625,10 +625,10 @@ class FusionOptimizer(GraphRewriter):
def find_fuseable_subgraph(
*,
fg: FunctionGraph,
visited_nodes: set[Apply],
fuseable_clients: FUSEABLE_MAPPING,
unfuseable_clients: UNFUSEABLE_MAPPING,
toposort_index: dict[Apply, int],
) -> tuple[list[Variable], list[Variable]]:
KT = TypeVar("KT")
VT = TypeVar("VT", list, set)
......@@ -648,8 +648,7 @@ class FusionOptimizer(GraphRewriter):
for a in ancestors(variables, blockers=stop_search_at)
)
toposort = fg.toposort()
for starting_node in toposort:
for starting_node in toposort_index:
if starting_node in visited_nodes:
continue
......@@ -791,7 +790,7 @@ class FusionOptimizer(GraphRewriter):
and inp.owner not in visited_nodes
)
),
key=lambda inp: toposort.index(inp.owner),
key=lambda inp: toposort_index[inp.owner],
reverse=True,
):
fuseable_nodes_to_visit.appendleft(inp.owner)
......@@ -803,7 +802,7 @@ class FusionOptimizer(GraphRewriter):
for node in fuseable_clients_temp.get(next_out, ())
if node not in visited_nodes
),
key=lambda node: toposort.index(node),
key=lambda node: toposort_index[node],
):
fuseable_nodes_to_visit.append(next_node)
......@@ -877,20 +876,22 @@ class FusionOptimizer(GraphRewriter):
# client (those that don't fit into 1))
fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg)
visited_nodes: set[Apply] = set()
toposort_index = {node: i for i, node in enumerate(fgraph.toposort())}
while True:
starting_nodes = fg.apply_nodes.copy()
try:
subgraph_inputs, subgraph_outputs = find_fuseable_subgraph(
fg=fg,
visited_nodes=visited_nodes,
fuseable_clients=fuseable_clients,
unfuseable_clients=unfuseable_clients,
toposort_index=toposort_index,
)
except ValueError:
return
else:
# The caller is now expected to update fg in place,
# by replacing the subgraph with a Composite Op
starting_nodes = fg.apply_nodes.copy()
yield subgraph_inputs, subgraph_outputs
# This is where we avoid repeated work by using a stateful
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论