提交 d5d298a5 作者: ricardoV94 提交者: Ricardo Vieira

Cleanup FusionOptimizer code

上级 71618f60
...@@ -5,7 +5,7 @@ import sys ...@@ -5,7 +5,7 @@ import sys
from collections import defaultdict, deque from collections import defaultdict, deque
from collections.abc import Generator, Sequence from collections.abc import Generator, Sequence
from functools import cache, reduce from functools import cache, reduce
from typing import TypeVar from typing import Literal
from warnings import warn from warnings import warn
import pytensor.scalar.basic as ps import pytensor.scalar.basic as ps
...@@ -555,8 +555,6 @@ class FusionOptimizer(GraphRewriter): ...@@ -555,8 +555,6 @@ class FusionOptimizer(GraphRewriter):
callbacks_before = fgraph.execute_callbacks_times.copy() callbacks_before = fgraph.execute_callbacks_times.copy()
callback_before = fgraph.execute_callbacks_time callback_before = fgraph.execute_callbacks_time
max_operands = elemwise_max_operands_fct(None)
def find_next_fuseable_subgraph( def find_next_fuseable_subgraph(
fg: FunctionGraph, fg: FunctionGraph,
) -> Generator[tuple[list[Variable], list[Variable]], None, None]: ) -> Generator[tuple[list[Variable], list[Variable]], None, None]:
...@@ -568,8 +566,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -568,8 +566,7 @@ class FusionOptimizer(GraphRewriter):
This generator assumes that such subgraph is replaced by a single This generator assumes that such subgraph is replaced by a single
Elemwise Composite before being accessed again in the next iteration. Elemwise Composite before being accessed again in the next iteration.
""" """
FUSEABLE_MAPPING = defaultdict[Variable, set[Apply]]
FUSEABLE_MAPPING = defaultdict[Variable, list[Apply]]
UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]] UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]]
def initialize_fuseable_mappings( def initialize_fuseable_mappings(
...@@ -591,35 +588,31 @@ class FusionOptimizer(GraphRewriter): ...@@ -591,35 +588,31 @@ class FusionOptimizer(GraphRewriter):
# to ensure the rewrite remains deterministic. # to ensure the rewrite remains deterministic.
# This is not a problem from unfuseable ones, as they can never # This is not a problem from unfuseable ones, as they can never
# become part of the graph. # become part of the graph.
fuseable_clients: FUSEABLE_MAPPING = defaultdict(list) fuseable_clients: FUSEABLE_MAPPING = defaultdict(set)
unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set) unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set)
for out, clients in fg.clients.items(): for out, clients in fg.clients.items():
# Old FunctionGraph nodes remain in the clients dictionary
# even after they are removed by rewrites
if not clients:
continue
out_maybe_fuseable = ( out_maybe_fuseable = (
out.owner out.owner is not None
and isinstance(out.owner.op, Elemwise) and isinstance(out.owner.op, Elemwise)
# and not isinstance(out.owner.op.scalar_op, ps.Composite) # and not isinstance(out.owner.op.scalar_op, ps.Composite)
and len(out.owner.outputs) == 1 and len(out.owner.outputs) == 1
and elemwise_scalar_op_has_c_code(out.owner) and elemwise_scalar_op_has_c_code(out.owner)
) )
for client, _ in clients: if out_maybe_fuseable:
if ( out_bcast = out.type.broadcastable
out_maybe_fuseable for client, _ in clients:
and isinstance(client.op, Elemwise) if (
# and not isinstance(client.op.scalar_op, ps.Composite) isinstance(client.op, Elemwise)
and len(client.outputs) == 1 # and not isinstance(client.op.scalar_op, ps.Composite)
and out.type.broadcastable and len(client.outputs) == 1
== client.outputs[0].type.broadcastable and out_bcast == client.outputs[0].type.broadcastable
and elemwise_scalar_op_has_c_code(client) and elemwise_scalar_op_has_c_code(client)
): ):
if client not in fuseable_clients[out]: fuseable_clients[out].add(client)
fuseable_clients[out].append(client) else:
else: unfuseable_clients[out].add(client)
unfuseable_clients[out].add(client) else:
unfuseable_clients[out] = {client for client, _ in clients}
return fuseable_clients, unfuseable_clients return fuseable_clients, unfuseable_clients
...@@ -630,16 +623,6 @@ class FusionOptimizer(GraphRewriter): ...@@ -630,16 +623,6 @@ class FusionOptimizer(GraphRewriter):
unfuseable_clients: UNFUSEABLE_MAPPING, unfuseable_clients: UNFUSEABLE_MAPPING,
toposort_index: dict[Apply, int], toposort_index: dict[Apply, int],
) -> tuple[list[Variable], list[Variable]]: ) -> tuple[list[Variable], list[Variable]]:
KT = TypeVar("KT")
VT = TypeVar("VT", list, set)
def shallow_clone_defaultdict(
d: defaultdict[KT, VT],
) -> defaultdict[KT, VT]:
new_dict: defaultdict[KT, VT] = defaultdict(d.default_factory)
new_dict.update({k: v.copy() for k, v in d.items()})
return new_dict
def variables_depend_on( def variables_depend_on(
variables, depend_on, stop_search_at=None variables, depend_on, stop_search_at=None
) -> bool: ) -> bool:
...@@ -657,17 +640,19 @@ class FusionOptimizer(GraphRewriter): ...@@ -657,17 +640,19 @@ class FusionOptimizer(GraphRewriter):
visited_nodes.add(starting_node) visited_nodes.add(starting_node)
continue continue
subgraph_inputs: list[Variable] = [] subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set
subgraph_outputs: list[Variable] = [] subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
unfuseable_clients_subgraph: set[Variable] = set() unfuseable_clients_subgraph: set[Variable] = set()
# Shallow cloning of maps so that they can be manipulated in place # Shallow cloning of maps so that they can be manipulated in place
fuseable_clients_temp = shallow_clone_defaultdict(fuseable_clients) fuseable_clients_clone: FUSEABLE_MAPPING = defaultdict(set)
unfuseable_clients_clone = shallow_clone_defaultdict( fuseable_clients_clone.update(
unfuseable_clients {k: v.copy() for k, v in fuseable_clients.items()}
)
unfuseable_clients_clone: UNFUSEABLE_MAPPING = defaultdict(set)
unfuseable_clients_clone.update(
{k: v.copy() for k, v in unfuseable_clients.items()}
) )
fuseable_nodes_to_visit = deque([starting_node])
# We now try to expand as much as possible towards the potentially # We now try to expand as much as possible towards the potentially
# fuseable clients and ancestors to detect the largest possible # fuseable clients and ancestors to detect the largest possible
...@@ -676,6 +661,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -676,6 +661,7 @@ class FusionOptimizer(GraphRewriter):
# some inputs or clients may depend on other nodes of the same # some inputs or clients may depend on other nodes of the same
# subgraph via a path that cannot be included in the Composite # subgraph via a path that cannot be included in the Composite
# (unfuseable) # (unfuseable)
fuseable_nodes_to_visit = deque([starting_node])
while fuseable_nodes_to_visit: while fuseable_nodes_to_visit:
next_node = fuseable_nodes_to_visit.popleft() next_node = fuseable_nodes_to_visit.popleft()
visited_nodes.add(next_node) visited_nodes.add(next_node)
...@@ -684,15 +670,14 @@ class FusionOptimizer(GraphRewriter): ...@@ -684,15 +670,14 @@ class FusionOptimizer(GraphRewriter):
# If the output variable of next_node has no fuseable clients # If the output variable of next_node has no fuseable clients
# or has unfuseable clients, then next_node must become an output # or has unfuseable clients, then next_node must become an output
# if it is to be fused. # if it is to be fused.
must_become_output = ( must_become_output = not fuseable_clients_clone.get(
next_out not in fuseable_clients_temp next_out
or next_out in unfuseable_clients_clone ) or unfuseable_clients_clone.get(next_out)
)
# We have backtracked to this node, and it may no longer be a viable output, # We have backtracked to this node, and it may no longer be a viable output,
# so we remove it and check again as if we had never seen this node # so we remove it and check again as if we had never seen this node
if must_become_output and next_out in subgraph_outputs: if must_become_output:
subgraph_outputs.remove(next_out) subgraph_outputs.pop(next_out, None)
required_unfuseable_inputs = [ required_unfuseable_inputs = [
inp inp
...@@ -744,18 +729,19 @@ class FusionOptimizer(GraphRewriter): ...@@ -744,18 +729,19 @@ class FusionOptimizer(GraphRewriter):
if ( if (
inp.owner in visited_nodes inp.owner in visited_nodes
# next_node could have the same input repeated # next_node could have the same input repeated
and next_node in fuseable_clients_temp[inp] and next_node in fuseable_clients_clone[inp]
): ):
fuseable_clients_temp[inp].remove(next_node) fuseable_clients_clone[inp].remove(next_node)
unfuseable_clients_clone[inp].add(next_node) unfuseable_clients_clone[inp].add(next_node)
# This input must become an output of the subgraph, # This input must become an output of the subgraph,
# because it can't be merged with next_node. # because it can't be merged with next_node.
# We will revisit it to make sure this is safe. # We will revisit it to make sure this is safe.
fuseable_nodes_to_visit.appendleft(inp.owner) fuseable_nodes_to_visit.appendleft(inp.owner)
for client in fuseable_clients_temp[next_out]: # need to convert to tuple not to change set size during iteration
for client in tuple(fuseable_clients_clone[next_out]):
if client in visited_nodes: if client in visited_nodes:
fuseable_clients_temp[next_out].remove(client) fuseable_clients_clone[next_out].remove(client)
unfuseable_clients_clone[next_out].add(client) unfuseable_clients_clone[next_out].add(client)
# next_out must become an input of the subgraph. # next_out must become an input of the subgraph.
# We will revisit any of its clients currently # We will revisit any of its clients currently
...@@ -771,74 +757,72 @@ class FusionOptimizer(GraphRewriter): ...@@ -771,74 +757,72 @@ class FusionOptimizer(GraphRewriter):
# mappings as if it next_node was part of it. # mappings as if it next_node was part of it.
# Useless inputs will be removed by the useless Composite rewrite # Useless inputs will be removed by the useless Composite rewrite
for inp in new_required_unfuseable_inputs: for inp in new_required_unfuseable_inputs:
if inp not in subgraph_inputs: subgraph_inputs[inp] = None
subgraph_inputs.append(inp)
if must_become_output: if must_become_output:
subgraph_outputs.append(next_out) subgraph_outputs[next_out] = None
unfuseable_clients_subgraph.update( unfuseable_clients_subgraph.update(
new_implied_unfuseable_clients new_implied_unfuseable_clients
) )
# Expand through unvisited fuseable ancestors # Expand through unvisited fuseable ancestors
for inp in sorted( fuseable_nodes_to_visit.extendleft(
( sorted(
inp (
for inp in next_node.inputs inp.owner
if ( for inp in next_node.inputs
inp not in required_unfuseable_inputs if (
and inp.owner not in visited_nodes inp not in required_unfuseable_inputs
) and inp.owner not in visited_nodes
), )
key=lambda inp: toposort_index[inp.owner], ),
reverse=True, key=toposort_index.get, # type: ignore[arg-type]
): )
fuseable_nodes_to_visit.appendleft(inp.owner) )
# Expand through unvisited fuseable clients # Expand through unvisited fuseable clients
for next_node in sorted( fuseable_nodes_to_visit.extend(
( sorted(
node (
for node in fuseable_clients_temp.get(next_out, ()) node
if node not in visited_nodes for node in fuseable_clients_clone.get(next_out, ())
), if node not in visited_nodes
key=lambda node: toposort_index[node], ),
): key=toposort_index.get, # type: ignore[arg-type]
fuseable_nodes_to_visit.append(next_node) )
)
# Don't return if final subgraph is just the original Elemwise # Don't return if final subgraph is just the original Elemwise
if len(subgraph_outputs) == 1 and set( if len(subgraph_outputs) == 1 and set(
subgraph_outputs[0].owner.inputs next(iter(subgraph_outputs)).owner.inputs
) == set(subgraph_inputs): ) == set(subgraph_inputs):
# Update global fuseable mappings # Update global fuseable mappings
# No input was actually fuseable # No input was actually fuseable
for inp in starting_node.inputs: for inp in starting_node.inputs:
if starting_node in fuseable_clients.get(inp, ()): fuseable_clients[inp].discard(starting_node)
fuseable_clients[inp].remove(starting_node) unfuseable_clients[inp].add(starting_node)
unfuseable_clients[inp].add(starting_node)
# No client was actually fuseable # No client was actually fuseable
unfuseable_clients[starting_out].update( unfuseable_clients[starting_out].update(
fuseable_clients.pop(starting_out, ()) fuseable_clients.pop(starting_out, ())
) )
continue continue
return subgraph_inputs, subgraph_outputs return list(subgraph_inputs), list(subgraph_outputs)
raise ValueError raise ValueError
def update_fuseable_mappings_after_fg_replace( def update_fuseable_mappings_after_fg_replace(
*, *,
fg: FunctionGraph,
visited_nodes: set[Apply], visited_nodes: set[Apply],
fuseable_clients: FUSEABLE_MAPPING, fuseable_clients: FUSEABLE_MAPPING,
unfuseable_clients: UNFUSEABLE_MAPPING, unfuseable_clients: UNFUSEABLE_MAPPING,
starting_nodes: set[Apply], starting_nodes: set[Apply],
updated_nodes: set[Apply],
) -> None: ) -> None:
# Find new composite node and dropped intermediate nodes # Find new composite node and dropped intermediate nodes
# by comparing the current fg.apply nodes with the cached # by comparing the current fg.apply nodes with the cached
# original nodes # original nodes
next_nodes = fg.apply_nodes (new_composite_node,) = updated_nodes - starting_nodes
(new_composite_node,) = next_nodes - starting_nodes dropped_nodes = starting_nodes - updated_nodes
dropped_nodes = starting_nodes - next_nodes
# Remove intermediate Composite nodes from mappings # Remove intermediate Composite nodes from mappings
for dropped_node in dropped_nodes: for dropped_node in dropped_nodes:
...@@ -850,11 +834,11 @@ class FusionOptimizer(GraphRewriter): ...@@ -850,11 +834,11 @@ class FusionOptimizer(GraphRewriter):
# Update fuseable information for subgraph inputs # Update fuseable information for subgraph inputs
for inp in subgraph_inputs: for inp in subgraph_inputs:
if inp in fuseable_clients: if inp in fuseable_clients:
new_fuseable_clients = [ new_fuseable_clients = {
client client
for client in fuseable_clients[inp] for client in fuseable_clients[inp]
if client not in dropped_nodes if client not in dropped_nodes
] }
if new_fuseable_clients: if new_fuseable_clients:
fuseable_clients[inp] = new_fuseable_clients fuseable_clients[inp] = new_fuseable_clients
else: else:
...@@ -898,13 +882,15 @@ class FusionOptimizer(GraphRewriter): ...@@ -898,13 +882,15 @@ class FusionOptimizer(GraphRewriter):
# generator. For large models (as in `TestFusion.test_big_fusion`) # generator. For large models (as in `TestFusion.test_big_fusion`)
# this can provide huge speedups # this can provide huge speedups
update_fuseable_mappings_after_fg_replace( update_fuseable_mappings_after_fg_replace(
fg=fg,
visited_nodes=visited_nodes, visited_nodes=visited_nodes,
fuseable_clients=fuseable_clients, fuseable_clients=fuseable_clients,
unfuseable_clients=unfuseable_clients, unfuseable_clients=unfuseable_clients,
starting_nodes=starting_nodes, starting_nodes=starting_nodes,
updated_nodes=fg.apply_nodes,
) )
max_operands = elemwise_max_operands_fct(None)
reason = self.__class__.__name__
nb_fused = 0 nb_fused = 0
nb_replacement = 0 nb_replacement = 0
for inputs, outputs in find_next_fuseable_subgraph(fgraph): for inputs, outputs in find_next_fuseable_subgraph(fgraph):
...@@ -923,13 +909,12 @@ class FusionOptimizer(GraphRewriter): ...@@ -923,13 +909,12 @@ class FusionOptimizer(GraphRewriter):
assert len(outputs) == len(composite_outputs) assert len(outputs) == len(composite_outputs)
for old_out, composite_out in zip(outputs, composite_outputs): for old_out, composite_out in zip(outputs, composite_outputs):
# Preserve any names on the original outputs # Preserve any names on the original outputs
if old_out.name: if old_name := old_out.name:
composite_out.name = old_out.name composite_out.name = old_name
starting_nodes = len(fgraph.apply_nodes) starting_nodes = len(fgraph.apply_nodes)
fgraph.replace_all_validate( fgraph.replace_all_validate(
list(zip(outputs, composite_outputs, strict=True)), tuple(zip(outputs, composite_outputs)), reason=reason
reason=self.__class__.__name__,
) )
nb_fused += 1 nb_fused += 1
nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1 nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论