提交 dc1e3b9c 作者: ricardoV94 提交者: Ricardo Vieira

Copy on write in FusionOptimizer

上级 d5d298a5
......@@ -2,6 +2,7 @@ import abc
import itertools
import operator
import sys
import typing
from collections import defaultdict, deque
from collections.abc import Generator, Sequence
from functools import cache, reduce
......@@ -522,6 +523,43 @@ def elemwise_max_operands_fct(node) -> int:
return 1024
class CopyOnWriteDictOfSets:
    """Wrapper around a dict of sets that copies a set only when it is modified.

    Reads fall through to the wrapped dict ``d`` until a key is modified via
    :meth:`add_to_key` or :meth:`remove_from_key`. On the first modification
    the key's set is copied into ``d_copy``; every later read or write for
    that key uses the private copy. The sets stored in ``d`` are never
    mutated by this class.

    Note that ``__getitem__`` and ``get`` may hand back a set that is still
    shared with ``d``; callers must go through ``add_to_key`` /
    ``remove_from_key`` instead of mutating returned sets directly.
    """

    __slots__ = ("d", "d_copy")

    def __init__(self, d: dict[typing.Any, set]):
        # Wrapped mapping, used as the read-only source of truth.
        self.d = d
        # Private copies, populated lazily for keys that get modified.
        self.d_copy: dict[typing.Any, set] = {}

    def __getitem__(self, key):
        """Return the set for ``key``, preferring the private copy.

        Raises whatever ``self.d[key]`` raises when the key has no private
        copy and is missing from the wrapped mapping.
        """
        try:
            return self.d_copy[key]
        except KeyError:
            return self.d[key]

    def get(self, key, default=frozenset()):
        """Return the set for ``key``, or ``default`` when the key is unknown.

        The fallback uses ``dict.get`` so that a pure query never triggers
        ``__missing__`` on the wrapped mapping (e.g. if it is a defaultdict).
        """
        try:
            return self.d_copy[key]
        except KeyError:
            return self.d.get(key, default)

    def remove_from_key(self, key, value):
        """Remove ``value`` from the set of ``key``, copying the set if needed.

        Raises ``KeyError`` if ``value`` is not in the set (or if ``key`` is
        unknown to a plain-dict ``d``). Modifications already recorded for
        ``key`` are preserved in that case.
        """
        # Only the lookup is guarded: a KeyError raised by ``set.remove`` must
        # not be mistaken for "no private copy yet", otherwise the private copy
        # would be replaced by a fresh copy of the original set and earlier
        # modifications would be lost.
        try:
            own = self.d_copy[key]
        except KeyError:
            own = self.d_copy[key] = self.d[key].copy()
        own.remove(value)

    def add_to_key(self, key, value):
        """Add ``value`` to the set of ``key``, copying the set if needed.

        A key that is absent from both mappings starts from an empty set, so
        the wrapper can stand in for a ``defaultdict(set)`` clone.
        """
        try:
            own = self.d_copy[key]
        except KeyError:
            shared = self.d.get(key)
            own = self.d_copy[key] = set() if shared is None else shared.copy()
        own.add(value)
class FusionOptimizer(GraphRewriter):
"""Graph optimizer that fuses consecutive Elemwise operations."""
......@@ -644,15 +682,10 @@ class FusionOptimizer(GraphRewriter):
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
unfuseable_clients_subgraph: set[Variable] = set()
# Shallow cloning of maps so that they can be manipulated in place
fuseable_clients_clone: FUSEABLE_MAPPING = defaultdict(set)
fuseable_clients_clone.update(
{k: v.copy() for k, v in fuseable_clients.items()}
)
unfuseable_clients_clone: UNFUSEABLE_MAPPING = defaultdict(set)
unfuseable_clients_clone.update(
{k: v.copy() for k, v in unfuseable_clients.items()}
)
# If we need to manipulate the maps in place, we'll do a shallow copy later
# For now we query on the original ones
fuseable_clients_clone = CopyOnWriteDictOfSets(fuseable_clients)
unfuseable_clients_clone = CopyOnWriteDictOfSets(unfuseable_clients)
# We now try to expand as much as possible towards the potentially
# fuseable clients and ancestors to detect the largest possible
......@@ -682,7 +715,7 @@ class FusionOptimizer(GraphRewriter):
required_unfuseable_inputs = [
inp
for inp in next_node.inputs
if next_node in unfuseable_clients_clone.get(inp, ())
if next_node in unfuseable_clients_clone.get(inp)
]
new_required_unfuseable_inputs = [
inp
......@@ -705,7 +738,7 @@ class FusionOptimizer(GraphRewriter):
if not must_backtrack:
implied_unfuseable_clients = {
c
for client in unfuseable_clients_clone.get(next_out, ())
for client in unfuseable_clients_clone.get(next_out)
if not isinstance(client.op, Output)
for c in client.outputs
}
......@@ -726,13 +759,15 @@ class FusionOptimizer(GraphRewriter):
if must_backtrack:
for inp in next_node.inputs:
if (
inp.owner in visited_nodes
# next_node could have the same input repeated
and next_node in fuseable_clients_clone[inp]
):
fuseable_clients_clone[inp].remove(next_node)
unfuseable_clients_clone[inp].add(next_node)
if inp.owner in visited_nodes:
if next_node not in fuseable_clients_clone[inp]:
# This can happen when next node has repeated inputs
continue
fuseable_clients_clone.remove_from_key(
inp, next_node
)
unfuseable_clients_clone.add_to_key(inp, next_node)
# This input must become an output of the subgraph,
# because it can't be merged with next_node.
# We will revisit it to make sure this is safe.
......@@ -741,8 +776,13 @@ class FusionOptimizer(GraphRewriter):
# need to convert to tuple not to change set size during iteration
for client in tuple(fuseable_clients_clone[next_out]):
if client in visited_nodes:
fuseable_clients_clone[next_out].remove(client)
unfuseable_clients_clone[next_out].add(client)
fuseable_clients_clone.remove_from_key(
next_out, client
)
unfuseable_clients_clone.add_to_key(
next_out, client
)
# next_out must become an input of the subgraph.
# We will revisit any of its clients currently
# in the subgraph to make sure this is safe.
......@@ -785,7 +825,7 @@ class FusionOptimizer(GraphRewriter):
sorted(
(
node
for node in fuseable_clients_clone.get(next_out, ())
for node in fuseable_clients_clone.get(next_out)
if node not in visited_nodes
),
key=toposort_index.get, # type: ignore[arg-type]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论