提交 71618f60 作者: ricardoV94 提交者: Ricardo Vieira

Do not recompute toposort in every iteration of FusionOptimizer

It's not really needed, as we never expand on the new nodes.
上级 9ef575b7
...@@ -625,10 +625,10 @@ class FusionOptimizer(GraphRewriter): ...@@ -625,10 +625,10 @@ class FusionOptimizer(GraphRewriter):
def find_fuseable_subgraph( def find_fuseable_subgraph(
*, *,
fg: FunctionGraph,
visited_nodes: set[Apply], visited_nodes: set[Apply],
fuseable_clients: FUSEABLE_MAPPING, fuseable_clients: FUSEABLE_MAPPING,
unfuseable_clients: UNFUSEABLE_MAPPING, unfuseable_clients: UNFUSEABLE_MAPPING,
toposort_index: dict[Apply, int],
) -> tuple[list[Variable], list[Variable]]: ) -> tuple[list[Variable], list[Variable]]:
KT = TypeVar("KT") KT = TypeVar("KT")
VT = TypeVar("VT", list, set) VT = TypeVar("VT", list, set)
...@@ -648,8 +648,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -648,8 +648,7 @@ class FusionOptimizer(GraphRewriter):
for a in ancestors(variables, blockers=stop_search_at) for a in ancestors(variables, blockers=stop_search_at)
) )
toposort = fg.toposort() for starting_node in toposort_index:
for starting_node in toposort:
if starting_node in visited_nodes: if starting_node in visited_nodes:
continue continue
...@@ -791,7 +790,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -791,7 +790,7 @@ class FusionOptimizer(GraphRewriter):
and inp.owner not in visited_nodes and inp.owner not in visited_nodes
) )
), ),
key=lambda inp: toposort.index(inp.owner), key=lambda inp: toposort_index[inp.owner],
reverse=True, reverse=True,
): ):
fuseable_nodes_to_visit.appendleft(inp.owner) fuseable_nodes_to_visit.appendleft(inp.owner)
...@@ -803,7 +802,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -803,7 +802,7 @@ class FusionOptimizer(GraphRewriter):
for node in fuseable_clients_temp.get(next_out, ()) for node in fuseable_clients_temp.get(next_out, ())
if node not in visited_nodes if node not in visited_nodes
), ),
key=lambda node: toposort.index(node), key=lambda node: toposort_index[node],
): ):
fuseable_nodes_to_visit.append(next_node) fuseable_nodes_to_visit.append(next_node)
...@@ -877,20 +876,22 @@ class FusionOptimizer(GraphRewriter): ...@@ -877,20 +876,22 @@ class FusionOptimizer(GraphRewriter):
# client (those that don't fit into 1)) # client (those that don't fit into 1))
fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg) fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg)
visited_nodes: set[Apply] = set() visited_nodes: set[Apply] = set()
toposort_index = {node: i for i, node in enumerate(fgraph.toposort())}
while True: while True:
starting_nodes = fg.apply_nodes.copy()
try: try:
subgraph_inputs, subgraph_outputs = find_fuseable_subgraph( subgraph_inputs, subgraph_outputs = find_fuseable_subgraph(
fg=fg,
visited_nodes=visited_nodes, visited_nodes=visited_nodes,
fuseable_clients=fuseable_clients, fuseable_clients=fuseable_clients,
unfuseable_clients=unfuseable_clients, unfuseable_clients=unfuseable_clients,
toposort_index=toposort_index,
) )
except ValueError: except ValueError:
return return
else: else:
# The caller is now expected to update fg in place, # The caller is now expected to update fg in place,
# by replacing the subgraph with a Composite Op # by replacing the subgraph with a Composite Op
starting_nodes = fg.apply_nodes.copy()
yield subgraph_inputs, subgraph_outputs yield subgraph_inputs, subgraph_outputs
# This is where we avoid repeated work by using a stateful # This is where we avoid repeated work by using a stateful
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论