提交 d5d298a5 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Cleanup FusionOptimizer code

上级 71618f60
......@@ -5,7 +5,7 @@ import sys
from collections import defaultdict, deque
from collections.abc import Generator, Sequence
from functools import cache, reduce
from typing import TypeVar
from typing import Literal
from warnings import warn
import pytensor.scalar.basic as ps
......@@ -555,8 +555,6 @@ class FusionOptimizer(GraphRewriter):
callbacks_before = fgraph.execute_callbacks_times.copy()
callback_before = fgraph.execute_callbacks_time
max_operands = elemwise_max_operands_fct(None)
def find_next_fuseable_subgraph(
fg: FunctionGraph,
) -> Generator[tuple[list[Variable], list[Variable]], None, None]:
......@@ -568,8 +566,7 @@ class FusionOptimizer(GraphRewriter):
This generator assumes that such subgraph is replaced by a single
Elemwise Composite before being accessed again in the next iteration.
"""
FUSEABLE_MAPPING = defaultdict[Variable, list[Apply]]
FUSEABLE_MAPPING = defaultdict[Variable, set[Apply]]
UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]]
def initialize_fuseable_mappings(
......@@ -591,35 +588,31 @@ class FusionOptimizer(GraphRewriter):
# to ensure the rewrite remains deterministic.
# This is not a problem for unfuseable ones, as they can never
# become part of the graph.
fuseable_clients: FUSEABLE_MAPPING = defaultdict(list)
fuseable_clients: FUSEABLE_MAPPING = defaultdict(set)
unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set)
for out, clients in fg.clients.items():
# Old FunctionGraph nodes remain in the clients dictionary
# even after they are removed by rewrites
if not clients:
continue
out_maybe_fuseable = (
out.owner
out.owner is not None
and isinstance(out.owner.op, Elemwise)
# and not isinstance(out.owner.op.scalar_op, ps.Composite)
and len(out.owner.outputs) == 1
and elemwise_scalar_op_has_c_code(out.owner)
)
if out_maybe_fuseable:
out_bcast = out.type.broadcastable
for client, _ in clients:
if (
out_maybe_fuseable
and isinstance(client.op, Elemwise)
isinstance(client.op, Elemwise)
# and not isinstance(client.op.scalar_op, ps.Composite)
and len(client.outputs) == 1
and out.type.broadcastable
== client.outputs[0].type.broadcastable
and out_bcast == client.outputs[0].type.broadcastable
and elemwise_scalar_op_has_c_code(client)
):
if client not in fuseable_clients[out]:
fuseable_clients[out].append(client)
fuseable_clients[out].add(client)
else:
unfuseable_clients[out].add(client)
else:
unfuseable_clients[out] = {client for client, _ in clients}
return fuseable_clients, unfuseable_clients
......@@ -630,16 +623,6 @@ class FusionOptimizer(GraphRewriter):
unfuseable_clients: UNFUSEABLE_MAPPING,
toposort_index: dict[Apply, int],
) -> tuple[list[Variable], list[Variable]]:
KT = TypeVar("KT")
VT = TypeVar("VT", list, set)
def shallow_clone_defaultdict(
d: defaultdict[KT, VT],
) -> defaultdict[KT, VT]:
new_dict: defaultdict[KT, VT] = defaultdict(d.default_factory)
new_dict.update({k: v.copy() for k, v in d.items()})
return new_dict
def variables_depend_on(
variables, depend_on, stop_search_at=None
) -> bool:
......@@ -657,17 +640,19 @@ class FusionOptimizer(GraphRewriter):
visited_nodes.add(starting_node)
continue
subgraph_inputs: list[Variable] = []
subgraph_outputs: list[Variable] = []
subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
unfuseable_clients_subgraph: set[Variable] = set()
# Shallow cloning of maps so that they can be manipulated in place
fuseable_clients_temp = shallow_clone_defaultdict(fuseable_clients)
unfuseable_clients_clone = shallow_clone_defaultdict(
unfuseable_clients
fuseable_clients_clone: FUSEABLE_MAPPING = defaultdict(set)
fuseable_clients_clone.update(
{k: v.copy() for k, v in fuseable_clients.items()}
)
unfuseable_clients_clone: UNFUSEABLE_MAPPING = defaultdict(set)
unfuseable_clients_clone.update(
{k: v.copy() for k, v in unfuseable_clients.items()}
)
fuseable_nodes_to_visit = deque([starting_node])
# We now try to expand as much as possible towards the potentially
# fuseable clients and ancestors to detect the largest possible
......@@ -676,6 +661,7 @@ class FusionOptimizer(GraphRewriter):
# some inputs or clients may depend on other nodes of the same
# subgraph via a path that cannot be included in the Composite
# (unfuseable)
fuseable_nodes_to_visit = deque([starting_node])
while fuseable_nodes_to_visit:
next_node = fuseable_nodes_to_visit.popleft()
visited_nodes.add(next_node)
......@@ -684,15 +670,14 @@ class FusionOptimizer(GraphRewriter):
# If the output variable of next_node has no fuseable clients
# or has unfuseable clients, then next_node must become an output
# if it is to be fused.
must_become_output = (
next_out not in fuseable_clients_temp
or next_out in unfuseable_clients_clone
)
must_become_output = not fuseable_clients_clone.get(
next_out
) or unfuseable_clients_clone.get(next_out)
# We have backtracked to this node, and it may no longer be a viable output,
# so we remove it and check again as if we had never seen this node
if must_become_output and next_out in subgraph_outputs:
subgraph_outputs.remove(next_out)
if must_become_output:
subgraph_outputs.pop(next_out, None)
required_unfuseable_inputs = [
inp
......@@ -744,18 +729,19 @@ class FusionOptimizer(GraphRewriter):
if (
inp.owner in visited_nodes
# next_node could have the same input repeated
and next_node in fuseable_clients_temp[inp]
and next_node in fuseable_clients_clone[inp]
):
fuseable_clients_temp[inp].remove(next_node)
fuseable_clients_clone[inp].remove(next_node)
unfuseable_clients_clone[inp].add(next_node)
# This input must become an output of the subgraph,
# because it can't be merged with next_node.
# We will revisit it to make sure this is safe.
fuseable_nodes_to_visit.appendleft(inp.owner)
for client in fuseable_clients_temp[next_out]:
# Iterate over a tuple copy so that the set does not change size during iteration
for client in tuple(fuseable_clients_clone[next_out]):
if client in visited_nodes:
fuseable_clients_temp[next_out].remove(client)
fuseable_clients_clone[next_out].remove(client)
unfuseable_clients_clone[next_out].add(client)
# next_out must become an input of the subgraph.
# We will revisit any of its clients currently
......@@ -771,50 +757,49 @@ class FusionOptimizer(GraphRewriter):
# mappings as if next_node was part of it.
# Useless inputs will be removed by the useless Composite rewrite
for inp in new_required_unfuseable_inputs:
if inp not in subgraph_inputs:
subgraph_inputs.append(inp)
subgraph_inputs[inp] = None
if must_become_output:
subgraph_outputs.append(next_out)
subgraph_outputs[next_out] = None
unfuseable_clients_subgraph.update(
new_implied_unfuseable_clients
)
# Expand through unvisited fuseable ancestors
for inp in sorted(
fuseable_nodes_to_visit.extendleft(
sorted(
(
inp
inp.owner
for inp in next_node.inputs
if (
inp not in required_unfuseable_inputs
and inp.owner not in visited_nodes
)
),
key=lambda inp: toposort_index[inp.owner],
reverse=True,
):
fuseable_nodes_to_visit.appendleft(inp.owner)
key=toposort_index.get, # type: ignore[arg-type]
)
)
# Expand through unvisited fuseable clients
for next_node in sorted(
fuseable_nodes_to_visit.extend(
sorted(
(
node
for node in fuseable_clients_temp.get(next_out, ())
for node in fuseable_clients_clone.get(next_out, ())
if node not in visited_nodes
),
key=lambda node: toposort_index[node],
):
fuseable_nodes_to_visit.append(next_node)
key=toposort_index.get, # type: ignore[arg-type]
)
)
# Don't return if final subgraph is just the original Elemwise
if len(subgraph_outputs) == 1 and set(
subgraph_outputs[0].owner.inputs
next(iter(subgraph_outputs)).owner.inputs
) == set(subgraph_inputs):
# Update global fuseable mappings
# No input was actually fuseable
for inp in starting_node.inputs:
if starting_node in fuseable_clients.get(inp, ()):
fuseable_clients[inp].remove(starting_node)
fuseable_clients[inp].discard(starting_node)
unfuseable_clients[inp].add(starting_node)
# No client was actually fuseable
unfuseable_clients[starting_out].update(
......@@ -822,23 +807,22 @@ class FusionOptimizer(GraphRewriter):
)
continue
return subgraph_inputs, subgraph_outputs
return list(subgraph_inputs), list(subgraph_outputs)
raise ValueError
def update_fuseable_mappings_after_fg_replace(
*,
fg: FunctionGraph,
visited_nodes: set[Apply],
fuseable_clients: FUSEABLE_MAPPING,
unfuseable_clients: UNFUSEABLE_MAPPING,
starting_nodes: set[Apply],
updated_nodes: set[Apply],
) -> None:
# Find new composite node and dropped intermediate nodes
# by comparing the current fg.apply nodes with the cached
# original nodes
next_nodes = fg.apply_nodes
(new_composite_node,) = next_nodes - starting_nodes
dropped_nodes = starting_nodes - next_nodes
(new_composite_node,) = updated_nodes - starting_nodes
dropped_nodes = starting_nodes - updated_nodes
# Remove intermediate Composite nodes from mappings
for dropped_node in dropped_nodes:
......@@ -850,11 +834,11 @@ class FusionOptimizer(GraphRewriter):
# Update fuseable information for subgraph inputs
for inp in subgraph_inputs:
if inp in fuseable_clients:
new_fuseable_clients = [
new_fuseable_clients = {
client
for client in fuseable_clients[inp]
if client not in dropped_nodes
]
}
if new_fuseable_clients:
fuseable_clients[inp] = new_fuseable_clients
else:
......@@ -898,13 +882,15 @@ class FusionOptimizer(GraphRewriter):
# generator. For large models (as in `TestFusion.test_big_fusion`)
# this can provide huge speedups
update_fuseable_mappings_after_fg_replace(
fg=fg,
visited_nodes=visited_nodes,
fuseable_clients=fuseable_clients,
unfuseable_clients=unfuseable_clients,
starting_nodes=starting_nodes,
updated_nodes=fg.apply_nodes,
)
max_operands = elemwise_max_operands_fct(None)
reason = self.__class__.__name__
nb_fused = 0
nb_replacement = 0
for inputs, outputs in find_next_fuseable_subgraph(fgraph):
......@@ -923,13 +909,12 @@ class FusionOptimizer(GraphRewriter):
assert len(outputs) == len(composite_outputs)
for old_out, composite_out in zip(outputs, composite_outputs):
# Preserve any names on the original outputs
if old_out.name:
composite_out.name = old_out.name
if old_name := old_out.name:
composite_out.name = old_name
starting_nodes = len(fgraph.apply_nodes)
fgraph.replace_all_validate(
list(zip(outputs, composite_outputs, strict=True)),
reason=self.__class__.__name__,
tuple(zip(outputs, composite_outputs)), reason=reason
)
nb_fused += 1
nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论