提交 d26374cd authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Avoid backtracking in FusionOptimizer

The change in number of fused kernels has to do with the order of iteration, and could be replicated in the old approach by iterating in topological order. It was an accident that it happened to visit in an order where it connected two branches, instead of keeping them separate. The underlying limitation already existed and is described in https://github.com/pymc-devs/pytensor/issues/249
上级 566af64d
...@@ -2,12 +2,10 @@ import abc ...@@ -2,12 +2,10 @@ import abc
import itertools import itertools
import operator import operator
import sys import sys
import typing
from collections import defaultdict, deque
from collections.abc import Generator, Sequence from collections.abc import Generator, Sequence
from functools import cache, reduce from functools import cache, reduce
from heapq import heapify, heappop, heappush
from operator import or_ from operator import or_
from typing import Literal
from warnings import warn from warnings import warn
import pytensor.scalar.basic as ps import pytensor.scalar.basic as ps
...@@ -524,43 +522,6 @@ def elemwise_max_operands_fct(node) -> int: ...@@ -524,43 +522,6 @@ def elemwise_max_operands_fct(node) -> int:
return 1024 return 1024
class CopyOnWriteDictOfSets:
__slots__ = ("d", "d_copy")
def __init__(self, d: dict[typing.Any, set]):
self.d = d
self.d_copy: dict[typing.Any, set] = {}
def __getitem__(self, key):
try:
return self.d_copy[key]
except KeyError:
return self.d[key]
def get(self, key, default=frozenset()):
try:
return self.d_copy[key]
except KeyError:
try:
return self.d[key]
except KeyError:
return default
def remove_from_key(self, key, value):
try:
self.d_copy[key].remove(value)
except KeyError:
self.d_copy[key] = copied_value = self.d[key].copy()
copied_value.remove(value)
def add_to_key(self, key, value):
try:
self.d_copy[key].add(value)
except KeyError:
self.d_copy[key] = copied_value = self.d[key].copy()
copied_value.add(value)
class FusionOptimizer(GraphRewriter): class FusionOptimizer(GraphRewriter):
"""Graph optimizer that fuses consecutive Elemwise operations.""" """Graph optimizer that fuses consecutive Elemwise operations."""
...@@ -594,353 +555,300 @@ class FusionOptimizer(GraphRewriter): ...@@ -594,353 +555,300 @@ class FusionOptimizer(GraphRewriter):
callbacks_before = fgraph.execute_callbacks_times.copy() callbacks_before = fgraph.execute_callbacks_times.copy()
callback_before = fgraph.execute_callbacks_time callback_before = fgraph.execute_callbacks_time
def find_next_fuseable_subgraph( def find_fuseable_subgraphs(
fg: FunctionGraph, fg: FunctionGraph,
) -> Generator[tuple[list[Variable], list[Variable]], None, None]: ) -> Generator[tuple[tuple[Variable], tuple[Variable]], None, None]:
"""Find all subgraphs in a FunctionGraph that can be fused together """Find subgraphs of Elemwise nodes that can be fused together.
Yields In general, there is no single solution. We try to find large subgraphs eagerly
-------
List of inputs and outputs that determine subgraphs which can be fused. Any two consecutive Elemwise nodes that have the same broadcasting pattern,
This generator assumes that such subgraph is replaced by a single and a C-implementation (historical accident that should be revisited), are potentially fuseable.
Elemwise Composite before being accessed again in the next iteration.
However, not all collections of fuseable pairs make a valid fused subgraph.
A valid fused subgraph must be "convex", meaning that no two nodes in the subgraph
are connected via a path that goes outside the subgraph, either because they
are connected via unfuseable nodes, or nodes that have been claimed by another fused subgraph.
For example the subgraph add(sin(exp(x)), sum(exp(x))) cannot be fused together,
because the sum node breaks the convexity of the subgraph {exp, sin, add}.
However, we can fuse {exp, sin}, and perhaps fuse add with something else.
This function yields subgraphs in reverse topological order so they can be safely replaced one at a time
""" """
FUSEABLE_MAPPING = defaultdict[Variable, set[Apply]]
UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]] @cache
def elemwise_scalar_op_has_c_code(
def initialize_fuseable_mappings( node: Apply, optimizer_verbose=config.optimizer_verbose
*, fg: FunctionGraph ) -> bool:
) -> tuple[FUSEABLE_MAPPING, UNFUSEABLE_MAPPING]: # TODO: This should not play a role in non-c backends!
@cache if node.op.scalar_op.supports_c_code(node.inputs, node.outputs):
def elemwise_scalar_op_has_c_code(node: Apply) -> bool: return True
# TODO: This should not play a role in non-c backends! elif optimizer_verbose:
if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): warn(
return True f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation."
else:
if config.optimizer_verbose:
warn(
f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation."
)
return False
# Fuseable nodes have to be accessed in a deterministic manner
# to ensure the rewrite remains deterministic.
# This is not a problem from unfuseable ones, as they can never
# become part of the graph.
fuseable_clients: FUSEABLE_MAPPING = defaultdict(set)
unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set)
for out, clients in fg.clients.items():
out_maybe_fuseable = (
out.owner is not None
and isinstance(out.owner.op, Elemwise)
# and not isinstance(out.owner.op.scalar_op, ps.Composite)
and len(out.owner.outputs) == 1
and elemwise_scalar_op_has_c_code(out.owner)
) )
if out_maybe_fuseable: return False
out_bcast = out.type.broadcastable
for client, _ in clients: # Create a map from node to a set of fuseable client (successor) nodes
if ( # A node and a client are fuseable if they are both single output Elemwise
isinstance(client.op, Elemwise) # (with C-implementation) and have the same output broadcastable pattern
# and not isinstance(client.op.scalar_op, ps.Composite) # Nodes that have no fuseable clients are not included
and len(client.outputs) == 1 fuseable_clients: dict[Apply, set[Apply]] = {}
and out_bcast == client.outputs[0].type.broadcastable # We also create a set with candidate nodes from which to start a subgraph expansion
and elemwise_scalar_op_has_c_code(client) # These are Single output Elemwise nodes (with C-implementation) that may or may not
): # have fuseable ancestors/clients at the start.
fuseable_clients[out].add(client) candidate_starting_nodes = set()
else: fg_clients = fg.clients
unfuseable_clients[out].add(client) for out, clients_and_indices in fg_clients.items():
else: out_node = out.owner
unfuseable_clients[out] = {client for client, _ in clients}
if not (
return fuseable_clients, unfuseable_clients out_node is not None
and len(out_node.outputs) == 1
def find_fuseable_subgraph( and isinstance(out_node.op, Elemwise)
*, and elemwise_scalar_op_has_c_code(out_node)
visited_nodes: set[Apply], ):
fuseable_clients: FUSEABLE_MAPPING, continue
unfuseable_clients: UNFUSEABLE_MAPPING,
ancestors_bitset: dict[Apply, int], candidate_starting_nodes.add(out_node)
toposort_index: dict[Apply, int], out_bcast = out.type.broadcastable
) -> tuple[list[Variable], list[Variable]]: out_fuseable_clients = {
for starting_node in toposort_index: client
if starting_node in visited_nodes: for client, _ in clients_and_indices
continue if (
len(client.outputs) == 1
and isinstance(client.op, Elemwise)
and out_bcast == client.outputs[0].type.broadcastable
and elemwise_scalar_op_has_c_code(client)
)
}
if out_fuseable_clients:
fuseable_clients[out_node] = out_fuseable_clients
if not candidate_starting_nodes:
return None
# To enable fast dependency queries, we create a bitset of ancestors for each node.
# Each node is first represented by a bit flag of its position in the toposort
# This can be achieved with python integers, via 1 << toposort_idx (equivalent to slower 2 ** toposort_idx)
# The ancestors bitsets of each node are obtained by bitwise OR of the ancestor bitsets
# of each of the nodes' inputs, and the bit flag of the node itself.
#
# Example: With three variables {a, b, c} owned by nodes {A, B, C}, where a is an input of b, and b an input of c,
# the nodes bit flags would be {A: 0b001, B: 0b010, C: 0b100} (integers {A: 1, B: 2, C: 4})
# and the ancestors bitset would be {A: 0b001, B: 0b011, C: 0b111} (integers {A: 1, B: 3, C: 7})
#
# This allows us to quickly ask if one or more variables are ancestors of a node by a simple bitwise AND
# For example, to ask if A is an ancestor of C we can do `ancestors_bitset[C] & node_bitset[A] != 0`
# We can also easily handle multiple nodes at once, for example to ask if A or B are ancestors of C we can do
# `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0`
nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())}
# Root variables have `None` as owner, which we can handle with a bitset of 0
ancestors_bitset = {None: 0}
for node, node_bitflag in nodes_bitflags.items():
# The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag
ancestors_bitset[node] = reduce(
or_,
(ancestors_bitset[inp.owner] for inp in node.inputs),
node_bitflag,
)
# Handle root and leaf nodes gracefully
# We do it after the ancestors_bitset are built to simplify the previous loop.
# Root variables have `None` as owner, which we can handle with a bitflag of 0
nodes_bitflags[None] = 0
# Nothing ever depends on the special Output nodes, so just use a new bit for all of them
out_bitflag = 1 << len(nodes_bitflags)
for out in fg.outputs:
for client, _ in fg_clients[out]:
if isinstance(client.op, Output):
nodes_bitflags[client] = out_bitflag
# Start main loop to find collection of fuseable subgraphs
# We store the collection in `sorted_subgraphs`, in reverse topological order
sorted_subgraphs: list[
tuple[int, tuple[tuple[Variable], tuple[Variable]]]
] = []
# Keep a bitset of nodes that have been claimed by subgraphs
all_subgraphs_bitset = 0
# Start exploring in reverse topological order from candidate sink nodes
# Sink nodes, are nodes that don't have any potential fuseable clients
for starting_node, starting_bitflag in reversed(nodes_bitflags.items()):
if (
starting_bitflag & all_subgraphs_bitset
or starting_node not in candidate_starting_nodes
or starting_node in fuseable_clients
):
continue
starting_out = starting_node.outputs[0] # We use an ordered queue to control the direction in which we expand the subgraph
if not fuseable_clients.get(starting_out): # For simplicity, we always want to visit ancestors before clients
visited_nodes.add(starting_node) # For ancestors, we want to visit the later nodes first (those that have more dependencies)
# whereas for clients we want to visit earlier nodes first (those that have fewer dependencies)
# To achieve this we use the bitflag as the sorting key (which encodes the topological order)
# and negate it for ancestors.
fuseables_nodes_queue = [(-starting_bitflag, starting_node)]
heapify(fuseables_nodes_queue)
# We keep 3 bitsets during the exploration of a new subgraph:
# - the nodes that are part of the subgraph
# - the unfuseable ancestors of the subgraph (i.e., ancestors that are not fuseable with a node in the subgraph)
# - the unfuseable clients of the subgraph (i.e., clients that are not fuseable with a node in the subgraph)
# Whenever we visit a candidate node, we check if the subgraph's unfuseable ancestors depend on it,
# or if it depends on one of the subgraphs' unfuseable client, in which case we can't add it.
# If we can add it, we then add its unfuseable ancestors/clients to the respective bitsets
# and add its fuseable ancestors/clients to the queue to explore later.
# To work correctly, we must visit candidate subgraph nodes in the order described by the queue above.
# Otherwise, we would need to perform more complex dependency checks in every iteration and/or backtrack.
subgraph_nodes = []
subgraph_bitset = 0
unfuseable_ancestors_bitset = 0
unfuseable_clients_bitset = 0
while fuseables_nodes_queue:
node_bitflag, node = heappop(fuseables_nodes_queue)
is_ancestor = node_bitflag < 0
if is_ancestor:
node_bitflag = -node_bitflag
if node_bitflag & subgraph_bitset:
# Already part of the subgraph
continue continue
subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set if is_ancestor:
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set if node_bitflag & unfuseable_ancestors_bitset:
subgraph_inputs_ancestors_bitset = 0 # An unfuseable ancestor of the subgraph depends on this node, can't fuse
unfuseable_clients_subgraph_bitset = 0 continue
elif ancestors_bitset[node] & unfuseable_clients_bitset:
# If we need to manipulate the maps in place, we'll do a shallow copy later # This node depends on an unfuseable client of the subgraph, can't fuse
# For now we query on the original ones continue
fuseable_clients_clone = CopyOnWriteDictOfSets(fuseable_clients)
unfuseable_clients_clone = CopyOnWriteDictOfSets(unfuseable_clients)
# We now try to expand as much as possible towards the potentially
# fuseable clients and ancestors to detect the largest possible
# subgraph that can be Composed together into a single `Op`. The
# largest issue to watch out is for cyclical dependencies, where
# some inputs or clients may depend on other nodes of the same
# subgraph via a path that cannot be included in the Composite
# (unfuseable)
fuseable_nodes_to_visit = deque([starting_node])
while fuseable_nodes_to_visit:
next_node = fuseable_nodes_to_visit.popleft()
visited_nodes.add(next_node)
next_out = next_node.outputs[0]
# If the output variable of next_node has no fuseable clients
# or has unfuseable clients, then next_node must become an output
# if it is to be fused.
must_become_output = not fuseable_clients_clone.get(
next_out
) or unfuseable_clients_clone.get(next_out)
# We have backtracked to this node, and it may no longer be a viable output,
# so we remove it and check again as if we had never seen this node
if must_become_output:
subgraph_outputs.pop(next_out, None)
# We need to check that any inputs required by this node
# do not depend on other outputs of the current subgraph,
# via an unfuseable path.
must_backtrack = (
ancestors_bitset[next_node]
& unfuseable_clients_subgraph_bitset
)
if not must_backtrack:
implied_unfuseable_clients_bitset = reduce(
or_,
(
1 << toposort_index[client]
for client in unfuseable_clients_clone.get(next_out)
if not isinstance(client.op, Output)
),
0,
)
# We need to check that any inputs of the current subgraph # Add node to subgraph
# do not depend on other clients of this node, subgraph_nodes.append(node)
# via an unfuseable path. subgraph_bitset |= node_bitflag
must_backtrack = (
subgraph_inputs_ancestors_bitset # Expand through ancestors and client nodes
& implied_unfuseable_clients_bitset # A node can either be:
# - already part of the subgraph (skip)
# - fuseable (add to queue)
# - unfuseable (add to respective unfuseable bitset)
for inp in node.inputs:
ancestor_node = inp.owner
ancestor_bitflag = nodes_bitflags[ancestor_node]
if ancestor_bitflag & subgraph_bitset:
continue
if node in fuseable_clients.get(ancestor_node, ()):
heappush(
fuseables_nodes_queue,
(-ancestor_bitflag, ancestor_node),
) )
else:
if must_backtrack: # If the node is not in the ancestor's fuseable clients set, it's not fuseable with it,
for inp in next_node.inputs: # nor with any of the ancestor's ancestors
if inp.owner in visited_nodes: unfuseable_ancestors_bitset |= ancestors_bitset[
if next_node not in fuseable_clients_clone[inp]: ancestor_node
# This can happen when next node has repeated inputs ]
continue
fuseable_clients_clone.remove_from_key( next_fuseable_clients = fuseable_clients.get(node, ())
inp, next_node for client, _ in fg_clients[node.outputs[0]]:
) client_bitflag = nodes_bitflags[client]
unfuseable_clients_clone.add_to_key(inp, next_node) if client_bitflag & subgraph_bitset:
# This input must become an output of the subgraph,
# because it can't be merged with next_node.
# We will revisit it to make sure this is safe.
fuseable_nodes_to_visit.appendleft(inp.owner)
# need to convert to tuple not to change set size during iteration
for client in tuple(fuseable_clients_clone[next_out]):
if client in visited_nodes:
fuseable_clients_clone.remove_from_key(
next_out, client
)
unfuseable_clients_clone.add_to_key(
next_out, client
)
# next_out must become an input of the subgraph.
# We will revisit any of its clients currently
# in the subgraph to make sure this is safe.
fuseable_nodes_to_visit.appendleft(client)
# Revisit node at a later time
visited_nodes.remove(next_node)
continue continue
if client in next_fuseable_clients:
heappush(fuseables_nodes_queue, (client_bitflag, client))
else:
# If a client is not in the node's fuseable clients set, it's not fuseable with it,
# nor any of its clients. But we don't need to keep track of those as any downstream
# client we may consider later will also depend on this unfuseable client and be rejected
unfuseable_clients_bitset |= client_bitflag
# Adding next_node to subgraph does not result in any # Finished exploring this subgraph
# immediate dependency problems. Update subgraph all_subgraphs_bitset |= subgraph_bitset
# mappings as if it next_node was part of it.
# Useless inputs will be removed by the useless Composite rewrite
if must_become_output:
subgraph_outputs[next_out] = None
unfuseable_clients_subgraph_bitset |= (
implied_unfuseable_clients_bitset
)
for inp in sorted( if subgraph_bitset == starting_bitflag:
next_node.inputs, # We ended where we started, no fusion possible
key=lambda x: toposort_index.get(x.owner, -1), continue
):
if next_node in unfuseable_clients_clone.get(inp, ()):
# input must become an input of the subgraph since it's unfuseable with new node
subgraph_inputs_ancestors_bitset |= (
ancestors_bitset.get(inp.owner, 0)
)
subgraph_inputs[inp] = None
elif inp.owner not in visited_nodes:
fuseable_nodes_to_visit.appendleft(inp.owner)
# Expand through unvisited fuseable clients
fuseable_nodes_to_visit.extend(
sorted(
(
node
for node in fuseable_clients_clone.get(next_out)
if node not in visited_nodes
),
key=toposort_index.get, # type: ignore[arg-type]
)
)
# Don't return if final subgraph is just the original Elemwise
if len(subgraph_outputs) == 1 and set(
next(iter(subgraph_outputs)).owner.inputs
) == set(subgraph_inputs):
# Update global fuseable mappings
# No input was actually fuseable
for inp in starting_node.inputs:
fuseable_clients[inp].discard(starting_node)
unfuseable_clients[inp].add(starting_node)
# No client was actually fuseable
unfuseable_clients[starting_out].update(
fuseable_clients.pop(starting_out, ())
)
continue
return list(subgraph_inputs), list(subgraph_outputs) # Find out the actual inputs/outputs variables of the subgraph
raise ValueError not_subgraph_bitset = ~subgraph_bitset
# Inputs are variables whose nodes are not part of the subgraph (including root variables without nodes)
def update_fuseable_mappings_after_fg_replace( # Use a dict to deduplicate while preserving order
*, subgraph_inputs = tuple(
visited_nodes: set[Apply], dict.fromkeys(
fuseable_clients: FUSEABLE_MAPPING, inp
unfuseable_clients: UNFUSEABLE_MAPPING, for node in subgraph_nodes
toposort_index: dict[Apply, int], for inp in node.inputs
ancestors_bitset: dict[Apply, int], if (inp_node := inp.owner) is None
starting_nodes: set[Apply], or nodes_bitflags[inp_node] & not_subgraph_bitset
updated_nodes: set[Apply],
) -> None:
# Find new composite node and dropped intermediate nodes
# by comparing the current fg.apply nodes with the cached
# original nodes
(new_composite_node,) = updated_nodes - starting_nodes
dropped_nodes = starting_nodes - updated_nodes
# Remove intermediate Composite nodes from mappings
# And compute the ancestors bitset of the new composite node
# As well as the new toposort index for the new node
new_node_ancestor_bitset = 0
new_node_toposort_index = len(toposort_index)
for dropped_node in dropped_nodes:
(dropped_out,) = dropped_node.outputs
fuseable_clients.pop(dropped_out, None)
unfuseable_clients.pop(dropped_out, None)
visited_nodes.remove(dropped_node)
# The new composite ancestor bitset is the union
# of the ancestors of all the dropped nodes
new_node_ancestor_bitset |= ancestors_bitset[dropped_node]
# The new composite node can have the same order as the latest node that was absorbed into it
new_node_toposort_index = max(
new_node_toposort_index, toposort_index[dropped_node]
) )
)
# Outputs are variables with client nodes that are not part of the subgraph (including special fgraph output nodes)
# Outputs are unique, no need to deduplicate
subgraph_outputs = tuple(
node.outputs[0]
for node in subgraph_nodes
if any(
nodes_bitflags[client] & not_subgraph_bitset
for client, _ in fg_clients[node.outputs[0]]
)
)
ancestors_bitset[new_composite_node] = new_node_ancestor_bitset # Update fuseable clients mapping for subgraph inputs and outputs
toposort_index[new_composite_node] = new_node_toposort_index # Inputs cannot be fused with nodes in the subgraph
# Update fuseable information for subgraph inputs
for inp in subgraph_inputs: for inp in subgraph_inputs:
if inp in fuseable_clients: if (inp_node := inp.owner) is not None and (
new_fuseable_clients = { inp_fuseable_clients := fuseable_clients.get(inp_node)
client ):
for client in fuseable_clients[inp] inp_fuseable_clients.difference_update(subgraph_nodes)
if client not in dropped_nodes # If there are no fuseable_clients left for this input delete its entry
} if not inp_fuseable_clients:
if new_fuseable_clients: del fuseable_clients[inp_node]
fuseable_clients[inp] = new_fuseable_clients # Outputs cannot be fused with anything else
else: for out in subgraph_outputs:
fuseable_clients.pop(inp) fuseable_clients.pop(out.owner, None)
unfuseable_clients[inp] = (
unfuseable_clients[inp] - dropped_nodes # Add new subgraph to sorted_subgraphs
) | {new_composite_node} # Because we start from sink nodes in reverse topological order, most times new subgraphs
# don't depend on previous subgraphs, so we can just append them at the end.
# Update fuseable information for subgraph outputs if not (unfuseable_ancestors_bitset & all_subgraphs_bitset):
for out in new_composite_node.outputs: # That's the case here
unfuseable_clients[out] = {client for client, _ in fg.clients[out]} # None of the unfuseable_ancestors (i.e, the ancestors) are present in the previous collected subgraphs
sorted_subgraphs.append(
visited_nodes.add(new_composite_node) (subgraph_bitset, (subgraph_inputs, subgraph_outputs))
return
# We start by creating two maps, 1) from each node to each potentially
# fuseable client (both nodes must be single output Elemwise with same
# broadcast type) and 2) from each node to each certainly unfuseable
# client (those that don't fit into 1))
fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg)
visited_nodes: set[Apply] = set()
toposort_index = {node: i for i, node in enumerate(fgraph.toposort())}
# Create a bitset for each node of all its ancestors
# This allows to quickly check if a variable depends on a set
ancestors_bitset: dict[Apply, int] = {}
for node, index in toposort_index.items():
node_ancestor_bitset = 1 << index
for inp in node.inputs:
if (inp_node := inp.owner) is not None:
node_ancestor_bitset |= ancestors_bitset[inp_node]
ancestors_bitset[node] = node_ancestor_bitset
while True:
try:
subgraph_inputs, subgraph_outputs = find_fuseable_subgraph(
visited_nodes=visited_nodes,
fuseable_clients=fuseable_clients,
unfuseable_clients=unfuseable_clients,
ancestors_bitset=ancestors_bitset,
toposort_index=toposort_index,
) )
except ValueError:
return
else: else:
# The caller is now expected to update fg in place, # But not here, so we need to find the right position for insertion.
# by replacing the subgraph with a Composite Op # We iterate through the previous subgraphs in topological order (reverse of the stored order).
starting_nodes = fg.apply_nodes.copy() # We cumulatively exclude each subgraph_bitset and perform the same dependency check again.
# The (index + 1) of the first iteration where the check passes is the correct insertion position.
yield subgraph_inputs, subgraph_outputs remaining_subgraphs_bitset = all_subgraphs_bitset
for index, (other_subgraph_bitset, _) in enumerate(
# This is where we avoid repeated work by using a stateful reversed(sorted_subgraphs)
# generator. For large models (as in `TestFusion.test_big_fusion`) ):
# this can provide huge speedups # Exclude subgraph bitset
update_fuseable_mappings_after_fg_replace( remaining_subgraphs_bitset &= ~other_subgraph_bitset
visited_nodes=visited_nodes, if not (
fuseable_clients=fuseable_clients, unfuseable_ancestors_bitset & remaining_subgraphs_bitset
unfuseable_clients=unfuseable_clients, ):
toposort_index=toposort_index, break # bingo
ancestors_bitset=ancestors_bitset, sorted_subgraphs.insert(
starting_nodes=starting_nodes, -(index + 1),
updated_nodes=fg.apply_nodes, (subgraph_bitset, (subgraph_inputs, subgraph_outputs)),
) )
# yield from sorted_subgraphs, discarding the subgraph_bitset
yield from (io for _, io in sorted_subgraphs)
max_operands = elemwise_max_operands_fct(None) max_operands = elemwise_max_operands_fct(None)
reason = self.__class__.__name__ reason = self.__class__.__name__
nb_fused = 0 nb_fused = 0
nb_replacement = 0 nb_replacement = 0
for inputs, outputs in find_next_fuseable_subgraph(fgraph): for inputs, outputs in find_fuseable_subgraphs(fgraph):
if (len(inputs) + len(outputs)) > max_operands: if (len(inputs) + len(outputs)) > max_operands:
warn( warn(
"Loop fusion failed because the resulting node would exceed " "Loop fusion failed because the resulting node would exceed the kernel argument limit."
"the kernel argument limit."
) )
break continue
scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs) scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs)
composite_outputs = Elemwise( composite_outputs = Elemwise(
...@@ -955,7 +863,8 @@ class FusionOptimizer(GraphRewriter): ...@@ -955,7 +863,8 @@ class FusionOptimizer(GraphRewriter):
starting_nodes = len(fgraph.apply_nodes) starting_nodes = len(fgraph.apply_nodes)
fgraph.replace_all_validate( fgraph.replace_all_validate(
tuple(zip(outputs, composite_outputs)), reason=reason tuple(zip(outputs, composite_outputs)),
reason=reason,
) )
nb_fused += 1 nb_fused += 1
nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1 nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1
......
...@@ -319,6 +319,26 @@ class TestFusion: ...@@ -319,6 +319,26 @@ class TestFusion:
assert nb_fused == 1 assert nb_fused == 1
assert nb_replacement == 4 assert nb_replacement == 4
def test_expansion_order(self):
# This test is designed to fail if we don't use the right expansion order in the current implementation
# It may be considered irrelevant if the algorithm changes and this is no longer a concern.
# In that case the test can be tweaked or removed
a = pt.vector("a")
b = pt.exp(a)
# Unique creates an unfuseable path between b and d/e
c = pt.unique(b)
d = pt.log(c)
# The critical aspect of the current implementation, is that we must visit d before c,
# so we learn about the unfuseable path by the time we visit c
e1 = b + d
e2 = d + b # test both orders
fg = FunctionGraph([a], [e1, e2], clone=False)
_, nb_fused, nb_replacement, *_ = FusionOptimizer().apply(fg)
fg.dprint()
assert nb_fused == 1
assert nb_replacement == 3
@pytest.mark.parametrize( @pytest.mark.parametrize(
"case", "case",
[ [
...@@ -1374,7 +1394,7 @@ class TestFusion: ...@@ -1374,7 +1394,7 @@ class TestFusion:
"graph_fn, n, expected_n_repl", "graph_fn, n, expected_n_repl",
[ [
("deep_small_kernels", 20, (20, 60)), ("deep_small_kernels", 20, (20, 60)),
("large_fuseable_graph", 25, (103, 876)), ("large_fuseable_graph", 25, (128, 876)),
], ],
) )
def test_rewrite_benchmark(self, graph_fn, n, expected_n_repl, benchmark): def test_rewrite_benchmark(self, graph_fn, n, expected_n_repl, benchmark):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论