提交 dc1e3b9c authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Copy on write in FusionOptimizer

上级 d5d298a5
...@@ -2,6 +2,7 @@ import abc ...@@ -2,6 +2,7 @@ import abc
import itertools import itertools
import operator import operator
import sys import sys
import typing
from collections import defaultdict, deque from collections import defaultdict, deque
from collections.abc import Generator, Sequence from collections.abc import Generator, Sequence
from functools import cache, reduce from functools import cache, reduce
...@@ -522,6 +523,43 @@ def elemwise_max_operands_fct(node) -> int: ...@@ -522,6 +523,43 @@ def elemwise_max_operands_fct(node) -> int:
return 1024 return 1024
class CopyOnWriteDictOfSets:
    """Read-mostly view over a ``dict`` of ``set``s with copy-on-write mutation.

    Lookups are served straight from the wrapped dict ``d`` until a key is
    mutated; the first mutation of a key lazily copies that key's set into
    ``d_copy``, and all subsequent reads and writes of that key use the
    private copy. The original mapping ``d`` is never modified.
    """

    __slots__ = ("d", "d_copy")

    def __init__(self, d: dict[typing.Any, set]):
        # Original mapping; treated as read-only by this wrapper.
        self.d = d
        # Per-key private copies, created lazily on first mutation.
        self.d_copy: dict[typing.Any, set] = {}

    def __getitem__(self, key):
        """Return the (possibly privately copied) set for ``key``.

        Raises ``KeyError`` if ``key`` is in neither mapping.
        """
        try:
            return self.d_copy[key]
        except KeyError:
            return self.d[key]

    def get(self, key, default=frozenset()):
        """Like ``__getitem__`` but return ``default`` when ``key`` is absent."""
        try:
            return self.d_copy[key]
        except KeyError:
            try:
                return self.d[key]
            except KeyError:
                return default

    def _copy_for_write(self, key):
        """Return the private copy of ``key``'s set, creating it on first use.

        Raises ``KeyError`` if ``key`` is in neither mapping.
        """
        try:
            return self.d_copy[key]
        except KeyError:
            copied_value = self.d_copy[key] = self.d[key].copy()
            return copied_value

    def remove_from_key(self, key, value):
        """Remove ``value`` from the set at ``key``, copy-on-write.

        Raises ``KeyError`` if ``value`` is not present. Note: the copy step
        and the ``set.remove`` are deliberately kept separate — a single
        ``try/except KeyError`` around both would mistake "value missing from
        an existing private copy" for "no private copy yet" and silently
        reset the copy, discarding earlier edits for that key.
        """
        self._copy_for_write(key).remove(value)

    def add_to_key(self, key, value):
        """Add ``value`` to the set at ``key``, copy-on-write."""
        self._copy_for_write(key).add(value)
class FusionOptimizer(GraphRewriter): class FusionOptimizer(GraphRewriter):
"""Graph optimizer that fuses consecutive Elemwise operations.""" """Graph optimizer that fuses consecutive Elemwise operations."""
...@@ -644,15 +682,10 @@ class FusionOptimizer(GraphRewriter): ...@@ -644,15 +682,10 @@ class FusionOptimizer(GraphRewriter):
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
unfuseable_clients_subgraph: set[Variable] = set() unfuseable_clients_subgraph: set[Variable] = set()
# Shallow cloning of maps so that they can be manipulated in place # If we need to manipulate the maps in place, we'll do a shallow copy later
fuseable_clients_clone: FUSEABLE_MAPPING = defaultdict(set) # For now we query on the original ones
fuseable_clients_clone.update( fuseable_clients_clone = CopyOnWriteDictOfSets(fuseable_clients)
{k: v.copy() for k, v in fuseable_clients.items()} unfuseable_clients_clone = CopyOnWriteDictOfSets(unfuseable_clients)
)
unfuseable_clients_clone: UNFUSEABLE_MAPPING = defaultdict(set)
unfuseable_clients_clone.update(
{k: v.copy() for k, v in unfuseable_clients.items()}
)
# We now try to expand as much as possible towards the potentially # We now try to expand as much as possible towards the potentially
# fuseable clients and ancestors to detect the largest possible # fuseable clients and ancestors to detect the largest possible
...@@ -682,7 +715,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -682,7 +715,7 @@ class FusionOptimizer(GraphRewriter):
required_unfuseable_inputs = [ required_unfuseable_inputs = [
inp inp
for inp in next_node.inputs for inp in next_node.inputs
if next_node in unfuseable_clients_clone.get(inp, ()) if next_node in unfuseable_clients_clone.get(inp)
] ]
new_required_unfuseable_inputs = [ new_required_unfuseable_inputs = [
inp inp
...@@ -705,7 +738,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -705,7 +738,7 @@ class FusionOptimizer(GraphRewriter):
if not must_backtrack: if not must_backtrack:
implied_unfuseable_clients = { implied_unfuseable_clients = {
c c
for client in unfuseable_clients_clone.get(next_out, ()) for client in unfuseable_clients_clone.get(next_out)
if not isinstance(client.op, Output) if not isinstance(client.op, Output)
for c in client.outputs for c in client.outputs
} }
...@@ -726,13 +759,15 @@ class FusionOptimizer(GraphRewriter): ...@@ -726,13 +759,15 @@ class FusionOptimizer(GraphRewriter):
if must_backtrack: if must_backtrack:
for inp in next_node.inputs: for inp in next_node.inputs:
if ( if inp.owner in visited_nodes:
inp.owner in visited_nodes if next_node not in fuseable_clients_clone[inp]:
# next_node could have the same input repeated # This can happen when next node has repeated inputs
and next_node in fuseable_clients_clone[inp] continue
): fuseable_clients_clone.remove_from_key(
fuseable_clients_clone[inp].remove(next_node) inp, next_node
unfuseable_clients_clone[inp].add(next_node) )
unfuseable_clients_clone.add_to_key(inp, next_node)
# This input must become an output of the subgraph, # This input must become an output of the subgraph,
# because it can't be merged with next_node. # because it can't be merged with next_node.
# We will revisit it to make sure this is safe. # We will revisit it to make sure this is safe.
...@@ -741,8 +776,13 @@ class FusionOptimizer(GraphRewriter): ...@@ -741,8 +776,13 @@ class FusionOptimizer(GraphRewriter):
# need to convert to tuple not to change set size during iteration # need to convert to tuple not to change set size during iteration
for client in tuple(fuseable_clients_clone[next_out]): for client in tuple(fuseable_clients_clone[next_out]):
if client in visited_nodes: if client in visited_nodes:
fuseable_clients_clone[next_out].remove(client) fuseable_clients_clone.remove_from_key(
unfuseable_clients_clone[next_out].add(client) next_out, client
)
unfuseable_clients_clone.add_to_key(
next_out, client
)
# next_out must become an input of the subgraph. # next_out must become an input of the subgraph.
# We will revisit any of its clients currently # We will revisit any of its clients currently
# in the subgraph to make sure this is safe. # in the subgraph to make sure this is safe.
...@@ -785,7 +825,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -785,7 +825,7 @@ class FusionOptimizer(GraphRewriter):
sorted( sorted(
( (
node node
for node in fuseable_clients_clone.get(next_out, ()) for node in fuseable_clients_clone.get(next_out)
if node not in visited_nodes if node not in visited_nodes
), ),
key=toposort_index.get, # type: ignore[arg-type] key=toposort_index.get, # type: ignore[arg-type]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论