提交 d26374cd authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Avoid backtracking in FusionOptimizer

The change in number of fused kernels has to do with the order of iteration, and could be replicated in the old approach by iterating in topological order. It was an accident that it happen to visit in an order where it connected two branches, instead of keeping them separate. The underlying limitation already existed and is described in https://github.com/pymc-devs/pytensor/issues/249
上级 566af64d
......@@ -2,12 +2,10 @@ import abc
import itertools
import operator
import sys
import typing
from collections import defaultdict, deque
from collections.abc import Generator, Sequence
from functools import cache, reduce
from heapq import heapify, heappop, heappush
from operator import or_
from typing import Literal
from warnings import warn
import pytensor.scalar.basic as ps
......@@ -524,43 +522,6 @@ def elemwise_max_operands_fct(node) -> int:
return 1024
class CopyOnWriteDictOfSets:
__slots__ = ("d", "d_copy")
def __init__(self, d: dict[typing.Any, set]):
self.d = d
self.d_copy: dict[typing.Any, set] = {}
def __getitem__(self, key):
try:
return self.d_copy[key]
except KeyError:
return self.d[key]
def get(self, key, default=frozenset()):
try:
return self.d_copy[key]
except KeyError:
try:
return self.d[key]
except KeyError:
return default
def remove_from_key(self, key, value):
try:
self.d_copy[key].remove(value)
except KeyError:
self.d_copy[key] = copied_value = self.d[key].copy()
copied_value.remove(value)
def add_to_key(self, key, value):
try:
self.d_copy[key].add(value)
except KeyError:
self.d_copy[key] = copied_value = self.d[key].copy()
copied_value.add(value)
class FusionOptimizer(GraphRewriter):
"""Graph optimizer that fuses consecutive Elemwise operations."""
......@@ -594,353 +555,300 @@ class FusionOptimizer(GraphRewriter):
callbacks_before = fgraph.execute_callbacks_times.copy()
callback_before = fgraph.execute_callbacks_time
def find_next_fuseable_subgraph(
def find_fuseable_subgraphs(
fg: FunctionGraph,
) -> Generator[tuple[list[Variable], list[Variable]], None, None]:
"""Find all subgraphs in a FunctionGraph that can be fused together
Yields
-------
List of inputs and outputs that determine subgraphs which can be fused.
This generator assumes that such subgraph is replaced by a single
Elemwise Composite before being accessed again in the next iteration.
) -> Generator[tuple[tuple[Variable], tuple[Variable]], None, None]:
"""Find subgraphs of Elemwise nodes that can be fused together.
In general, there is no single solution. We try to find large subgraphs eagerly
Any two consecutive Elemwise nodes that have the same broadcasting pattern,
and a C-implementation (historical accident that should be revisited), are potentially fuseable.
However, not all collections of fuseable pairs make a valid fused subgraph.
A valid fused subgraph must be "convex", meaning that no two nodes in the subgraph
are connected via a path that goes outside the subgraph, either because they
are connected via unfuseable nodes, or nodes that have been claimed by another fused subgraph.
For example the subgraph add(sin(exp(x)), sum(exp(x)) cannot be fused together,
because the sum node breaks the convexity of the subgraph {exp, sin, add}.
However, we can fuse {exp, sin}, and perhaps fuse add with something else.
This function yields subgraph in reverse topological order so they can be safely replaced one at a time
"""
FUSEABLE_MAPPING = defaultdict[Variable, set[Apply]]
UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]]
def initialize_fuseable_mappings(
*, fg: FunctionGraph
) -> tuple[FUSEABLE_MAPPING, UNFUSEABLE_MAPPING]:
@cache
def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
def elemwise_scalar_op_has_c_code(
node: Apply, optimizer_verbose=config.optimizer_verbose
) -> bool:
# TODO: This should not play a role in non-c backends!
if node.op.scalar_op.supports_c_code(node.inputs, node.outputs):
return True
else:
if config.optimizer_verbose:
elif optimizer_verbose:
warn(
f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation."
)
return False
# Fuseable nodes have to be accessed in a deterministic manner
# to ensure the rewrite remains deterministic.
# This is not a problem from unfuseable ones, as they can never
# become part of the graph.
fuseable_clients: FUSEABLE_MAPPING = defaultdict(set)
unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set)
for out, clients in fg.clients.items():
out_maybe_fuseable = (
out.owner is not None
and isinstance(out.owner.op, Elemwise)
# and not isinstance(out.owner.op.scalar_op, ps.Composite)
and len(out.owner.outputs) == 1
and elemwise_scalar_op_has_c_code(out.owner)
)
if out_maybe_fuseable:
# Create a map from node to a set of fuseable client (successor) nodes
# A node and a client are fuseable if they are both single output Elemwise
# (with C-implementation) and have the same output broadcastable pattern
# Nodes that have no fuseable clients are not included
fuseable_clients: dict[Apply, set[Apply]] = {}
# We also create a set with candidate nodes from which to start a subgraph expansion
# These are Single output Elemwise nodes (with C-implementation) that may or not
# have fuseable ancestors/clients at the start.
candidate_starting_nodes = set()
fg_clients = fg.clients
for out, clients_and_indices in fg_clients.items():
out_node = out.owner
if not (
out_node is not None
and len(out_node.outputs) == 1
and isinstance(out_node.op, Elemwise)
and elemwise_scalar_op_has_c_code(out_node)
):
continue
candidate_starting_nodes.add(out_node)
out_bcast = out.type.broadcastable
for client, _ in clients:
out_fuseable_clients = {
client
for client, _ in clients_and_indices
if (
isinstance(client.op, Elemwise)
# and not isinstance(client.op.scalar_op, ps.Composite)
and len(client.outputs) == 1
len(client.outputs) == 1
and isinstance(client.op, Elemwise)
and out_bcast == client.outputs[0].type.broadcastable
and elemwise_scalar_op_has_c_code(client)
)
}
if out_fuseable_clients:
fuseable_clients[out_node] = out_fuseable_clients
if not candidate_starting_nodes:
return None
# To enable fast dependency queries, we create a bitset of ancestors for each node.
# Each node is first represented by a bit flag of it's position in the toposort
# This can be achieved with python integers, via 1 << toposort_idx (equivalent to slower 2 ** toposort_idx)
# The ancestors bitsets of each node are obtained by bitwise OR of the ancestor bitsets
# of each of the nodes' inputs, and the bit flag of the node itself.
#
# Example: With three variables {a, b, c} owned by nodes {A, B, C}, where a is an input of b, and b an input of c,
# the nodes bit flags would be {A: 0b001, B: 0b010, C: 0b100} (integers {A: 1, B: 2, C: 4})
# and the ancestors bitset would be {A: 0b001, B: 0b011, C: 0b111} (integers {A: 1, B: 3, C: 7})
#
# This allows us to quickly ask if one or more variables are ancestors of a node by a simple bitwise AND
# For example, to ask if A is an ancestor of C we can do `ancestors_bitset[C] & node_bitset[A] != 0`
# We can also easily handle multiple nodes at once, for example to ask if A or B are ancestors of C we can do
# `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0`
nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())}
# Root variables have `None` as owner, which we can handle with a bitset of 0
ancestors_bitset = {None: 0}
for node, node_bitflag in nodes_bitflags.items():
# The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag
ancestors_bitset[node] = reduce(
or_,
(ancestors_bitset[inp.owner] for inp in node.inputs),
node_bitflag,
)
# Handle root and leaf nodes gracefully
# We do it after the ancestors_bitset are built to simplify the previous loop.
# Root variables have `None` as owner, which we can handle with a bitflag of 0
nodes_bitflags[None] = 0
# Nothing ever depends on the special Output nodes, so just use a new bit for all of them
out_bitflag = 1 << len(nodes_bitflags)
for out in fg.outputs:
for client, _ in fg_clients[out]:
if isinstance(client.op, Output):
nodes_bitflags[client] = out_bitflag
# Start main loop to find collection of fuseable subgraphs
# We store the collection in `sorted_subgraphs`, in reverse topological order
sorted_subgraphs: list[
tuple[int, tuple[tuple[Variable], tuple[Variable]]]
] = []
# Keep a bitset of nodes that have been claimed by subgraphs
all_subgraphs_bitset = 0
# Start exploring in reverse topological order from candidate sink nodes
# Sink nodes, are nodes that don't have any potential fuseable clients
for starting_node, starting_bitflag in reversed(nodes_bitflags.items()):
if (
starting_bitflag & all_subgraphs_bitset
or starting_node not in candidate_starting_nodes
or starting_node in fuseable_clients
):
fuseable_clients[out].add(client)
else:
unfuseable_clients[out].add(client)
else:
unfuseable_clients[out] = {client for client, _ in clients}
return fuseable_clients, unfuseable_clients
def find_fuseable_subgraph(
*,
visited_nodes: set[Apply],
fuseable_clients: FUSEABLE_MAPPING,
unfuseable_clients: UNFUSEABLE_MAPPING,
ancestors_bitset: dict[Apply, int],
toposort_index: dict[Apply, int],
) -> tuple[list[Variable], list[Variable]]:
for starting_node in toposort_index:
if starting_node in visited_nodes:
continue
starting_out = starting_node.outputs[0]
if not fuseable_clients.get(starting_out):
visited_nodes.add(starting_node)
# We use an ordered queue to control the direction in which we expand the subgraph
# For simplicity, we always want to visit ancestors before clients
# For ancestors, we want to visit the later nodes first (those that have more dependencies)
# whereas for clients we want to visit earlier nodes first (those that have fewer dependencies)
# To achieve this we use the bitflag as the sorting key (which encodes the topological order)
# and negate it for ancestors.
fuseables_nodes_queue = [(-starting_bitflag, starting_node)]
heapify(fuseables_nodes_queue)
# We keep 3 bitsets during the exploration of a new subgraph:
# - the nodes that are part of the subgraph
# - the unfuseable ancestors of the subgraph (i.e., ancestors that are not fuseable with a node in the subgraph)
# - the unfuseable clients of the subgraph (i.e., clients that are not fuseable with a node in the subgraph)
# Whenever we visit a candidate node, we check if the subgraph's unfuseable ancestors depend on it,
# or if it depends on one of the subgraphs' unfuseable client, in which case we can't add it.
# If we can add it, we then add its unfuseable ancestors/clients to the respective bitsets
# and add its fuseable ancestors/clients to the queue to explore later.
# To work correctly, we must visit candidate subgraph nodes in the order described by the queue above.
# Otherwise, we would need to perform more complex dependency checks in every iteration and/or backtrack.
subgraph_nodes = []
subgraph_bitset = 0
unfuseable_ancestors_bitset = 0
unfuseable_clients_bitset = 0
while fuseables_nodes_queue:
node_bitflag, node = heappop(fuseables_nodes_queue)
is_ancestor = node_bitflag < 0
if is_ancestor:
node_bitflag = -node_bitflag
if node_bitflag & subgraph_bitset:
# Already part of the subgraph
continue
subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
subgraph_inputs_ancestors_bitset = 0
unfuseable_clients_subgraph_bitset = 0
# If we need to manipulate the maps in place, we'll do a shallow copy later
# For now we query on the original ones
fuseable_clients_clone = CopyOnWriteDictOfSets(fuseable_clients)
unfuseable_clients_clone = CopyOnWriteDictOfSets(unfuseable_clients)
# We now try to expand as much as possible towards the potentially
# fuseable clients and ancestors to detect the largest possible
# subgraph that can be Composed together into a single `Op`. The
# largest issue to watch out is for cyclical dependencies, where
# some inputs or clients may depend on other nodes of the same
# subgraph via a path that cannot be included in the Composite
# (unfuseable)
fuseable_nodes_to_visit = deque([starting_node])
while fuseable_nodes_to_visit:
next_node = fuseable_nodes_to_visit.popleft()
visited_nodes.add(next_node)
next_out = next_node.outputs[0]
# If the output variable of next_node has no fuseable clients
# or has unfuseable clients, then next_node must become an output
# if it is to be fused.
must_become_output = not fuseable_clients_clone.get(
next_out
) or unfuseable_clients_clone.get(next_out)
# We have backtracked to this node, and it may no longer be a viable output,
# so we remove it and check again as if we had never seen this node
if must_become_output:
subgraph_outputs.pop(next_out, None)
# We need to check that any inputs required by this node
# do not depend on other outputs of the current subgraph,
# via an unfuseable path.
must_backtrack = (
ancestors_bitset[next_node]
& unfuseable_clients_subgraph_bitset
)
if is_ancestor:
if node_bitflag & unfuseable_ancestors_bitset:
# An unfuseable ancestor of the subgraph depends on this node, can't fuse
continue
elif ancestors_bitset[node] & unfuseable_clients_bitset:
# This node depends on an unfuseable client of the subgraph, can't fuse
continue
if not must_backtrack:
implied_unfuseable_clients_bitset = reduce(
or_,
(
1 << toposort_index[client]
for client in unfuseable_clients_clone.get(next_out)
if not isinstance(client.op, Output)
),
0,
)
# Add node to subgraph
subgraph_nodes.append(node)
subgraph_bitset |= node_bitflag
# We need to check that any inputs of the current subgraph
# do not depend on other clients of this node,
# via an unfuseable path.
must_backtrack = (
subgraph_inputs_ancestors_bitset
& implied_unfuseable_clients_bitset
# Expand through ancestors and client nodes
# A node can either be:
# - already part of the subgraph (skip)
# - fuseable (add to queue)
# - unfuseable (add to respective unfuseable bitset)
for inp in node.inputs:
ancestor_node = inp.owner
ancestor_bitflag = nodes_bitflags[ancestor_node]
if ancestor_bitflag & subgraph_bitset:
continue
if node in fuseable_clients.get(ancestor_node, ()):
heappush(
fuseables_nodes_queue,
(-ancestor_bitflag, ancestor_node),
)
else:
# If the node is not in the ancestor's fuseable clients set, it's not fuseable with it,
# nor with any of the ancestor's ancestors
unfuseable_ancestors_bitset |= ancestors_bitset[
ancestor_node
]
if must_backtrack:
for inp in next_node.inputs:
if inp.owner in visited_nodes:
if next_node not in fuseable_clients_clone[inp]:
# This can happen when next node has repeated inputs
next_fuseable_clients = fuseable_clients.get(node, ())
for client, _ in fg_clients[node.outputs[0]]:
client_bitflag = nodes_bitflags[client]
if client_bitflag & subgraph_bitset:
continue
fuseable_clients_clone.remove_from_key(
inp, next_node
)
unfuseable_clients_clone.add_to_key(inp, next_node)
# This input must become an output of the subgraph,
# because it can't be merged with next_node.
# We will revisit it to make sure this is safe.
fuseable_nodes_to_visit.appendleft(inp.owner)
# need to convert to tuple not to change set size during iteration
for client in tuple(fuseable_clients_clone[next_out]):
if client in visited_nodes:
fuseable_clients_clone.remove_from_key(
next_out, client
)
unfuseable_clients_clone.add_to_key(
next_out, client
)
if client in next_fuseable_clients:
heappush(fuseables_nodes_queue, (client_bitflag, client))
else:
# If a client is not in the node's fuseable clients set, it's nto fuseable with it,
# nor any of its clients. But we don't need to keep track of those as any downstream
# client we may consider later will also depend on this unfuseable client and be rejected
unfuseable_clients_bitset |= client_bitflag
# next_out must become an input of the subgraph.
# We will revisit any of its clients currently
# in the subgraph to make sure this is safe.
fuseable_nodes_to_visit.appendleft(client)
# Finished exploring this subgraph
all_subgraphs_bitset |= subgraph_bitset
# Revisit node at a later time
visited_nodes.remove(next_node)
if subgraph_bitset == starting_bitflag:
# We ended were we started, no fusion possible
continue
# Adding next_node to subgraph does not result in any
# immediate dependency problems. Update subgraph
# mappings as if it next_node was part of it.
# Useless inputs will be removed by the useless Composite rewrite
if must_become_output:
subgraph_outputs[next_out] = None
unfuseable_clients_subgraph_bitset |= (
implied_unfuseable_clients_bitset
)
for inp in sorted(
next_node.inputs,
key=lambda x: toposort_index.get(x.owner, -1),
):
if next_node in unfuseable_clients_clone.get(inp, ()):
# input must become an input of the subgraph since it's unfuseable with new node
subgraph_inputs_ancestors_bitset |= (
ancestors_bitset.get(inp.owner, 0)
)
subgraph_inputs[inp] = None
elif inp.owner not in visited_nodes:
fuseable_nodes_to_visit.appendleft(inp.owner)
# Expand through unvisited fuseable clients
fuseable_nodes_to_visit.extend(
sorted(
(
node
for node in fuseable_clients_clone.get(next_out)
if node not in visited_nodes
),
key=toposort_index.get, # type: ignore[arg-type]
# Find out the actual inputs/outputs variables of the subgraph
not_subgraph_bitset = ~subgraph_bitset
# Inputs are variables whose nodes are not part of the subgraph (including root variables without nodes)
# Use a dict to deduplicate while preserving order
subgraph_inputs = tuple(
dict.fromkeys(
inp
for node in subgraph_nodes
for inp in node.inputs
if (inp_node := inp.owner) is None
or nodes_bitflags[inp_node] & not_subgraph_bitset
)
)
# Don't return if final subgraph is just the original Elemwise
if len(subgraph_outputs) == 1 and set(
next(iter(subgraph_outputs)).owner.inputs
) == set(subgraph_inputs):
# Update global fuseable mappings
# No input was actually fuseable
for inp in starting_node.inputs:
fuseable_clients[inp].discard(starting_node)
unfuseable_clients[inp].add(starting_node)
# No client was actually fuseable
unfuseable_clients[starting_out].update(
fuseable_clients.pop(starting_out, ())
# Outputs are variables with client nodes that are not part of the subgraph (including special fgraph output nodes)
# Outputs are unique, no need to deduplicate
subgraph_outputs = tuple(
node.outputs[0]
for node in subgraph_nodes
if any(
nodes_bitflags[client] & not_subgraph_bitset
for client, _ in fg_clients[node.outputs[0]]
)
continue
return list(subgraph_inputs), list(subgraph_outputs)
raise ValueError
def update_fuseable_mappings_after_fg_replace(
*,
visited_nodes: set[Apply],
fuseable_clients: FUSEABLE_MAPPING,
unfuseable_clients: UNFUSEABLE_MAPPING,
toposort_index: dict[Apply, int],
ancestors_bitset: dict[Apply, int],
starting_nodes: set[Apply],
updated_nodes: set[Apply],
) -> None:
# Find new composite node and dropped intermediate nodes
# by comparing the current fg.apply nodes with the cached
# original nodes
(new_composite_node,) = updated_nodes - starting_nodes
dropped_nodes = starting_nodes - updated_nodes
# Remove intermediate Composite nodes from mappings
# And compute the ancestors bitset of the new composite node
# As well as the new toposort index for the new node
new_node_ancestor_bitset = 0
new_node_toposort_index = len(toposort_index)
for dropped_node in dropped_nodes:
(dropped_out,) = dropped_node.outputs
fuseable_clients.pop(dropped_out, None)
unfuseable_clients.pop(dropped_out, None)
visited_nodes.remove(dropped_node)
# The new composite ancestor bitset is the union
# of the ancestors of all the dropped nodes
new_node_ancestor_bitset |= ancestors_bitset[dropped_node]
# The new composite node can have the same order as the latest node that was absorbed into it
new_node_toposort_index = max(
new_node_toposort_index, toposort_index[dropped_node]
)
ancestors_bitset[new_composite_node] = new_node_ancestor_bitset
toposort_index[new_composite_node] = new_node_toposort_index
# Update fuseable information for subgraph inputs
# Update fuseable clients mapping for subgraph inputs and outputs
# Inputs cannot be fused with nodes in the subgraph
for inp in subgraph_inputs:
if inp in fuseable_clients:
new_fuseable_clients = {
client
for client in fuseable_clients[inp]
if client not in dropped_nodes
}
if new_fuseable_clients:
fuseable_clients[inp] = new_fuseable_clients
else:
fuseable_clients.pop(inp)
unfuseable_clients[inp] = (
unfuseable_clients[inp] - dropped_nodes
) | {new_composite_node}
# Update fuseable information for subgraph outputs
for out in new_composite_node.outputs:
unfuseable_clients[out] = {client for client, _ in fg.clients[out]}
visited_nodes.add(new_composite_node)
return
# We start by creating two maps, 1) from each node to each potentially
# fuseable client (both nodes must be single output Elemwise with same
# broadcast type) and 2) from each node to each certainly unfuseable
# client (those that don't fit into 1))
fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg)
visited_nodes: set[Apply] = set()
toposort_index = {node: i for i, node in enumerate(fgraph.toposort())}
# Create a bitset for each node of all its ancestors
# This allows to quickly check if a variable depends on a set
ancestors_bitset: dict[Apply, int] = {}
for node, index in toposort_index.items():
node_ancestor_bitset = 1 << index
for inp in node.inputs:
if (inp_node := inp.owner) is not None:
node_ancestor_bitset |= ancestors_bitset[inp_node]
ancestors_bitset[node] = node_ancestor_bitset
while True:
try:
subgraph_inputs, subgraph_outputs = find_fuseable_subgraph(
visited_nodes=visited_nodes,
fuseable_clients=fuseable_clients,
unfuseable_clients=unfuseable_clients,
ancestors_bitset=ancestors_bitset,
toposort_index=toposort_index,
if (inp_node := inp.owner) is not None and (
inp_fuseable_clients := fuseable_clients.get(inp_node)
):
inp_fuseable_clients.difference_update(subgraph_nodes)
# If there are no fuseable_clients left for this input delete it's entry
if not inp_fuseable_clients:
del fuseable_clients[inp_node]
# Outputs cannot be fused with anything else
for out in subgraph_outputs:
fuseable_clients.pop(out.owner, None)
# Add new subgraph to sorted_subgraphs
# Because we start from sink nodes in reverse topological order, most times new subgraphs
# don't depend on previous subgraphs, so we can just append them at the end.
if not (unfuseable_ancestors_bitset & all_subgraphs_bitset):
# That's the case here
# None of the unfuseable_ancestors (i.e, the ancestors) are present in the previous collected subgraphs
sorted_subgraphs.append(
(subgraph_bitset, (subgraph_inputs, subgraph_outputs))
)
except ValueError:
return
else:
# The caller is now expected to update fg in place,
# by replacing the subgraph with a Composite Op
starting_nodes = fg.apply_nodes.copy()
yield subgraph_inputs, subgraph_outputs
# This is where we avoid repeated work by using a stateful
# generator. For large models (as in `TestFusion.test_big_fusion`)
# this can provide huge speedups
update_fuseable_mappings_after_fg_replace(
visited_nodes=visited_nodes,
fuseable_clients=fuseable_clients,
unfuseable_clients=unfuseable_clients,
toposort_index=toposort_index,
ancestors_bitset=ancestors_bitset,
starting_nodes=starting_nodes,
updated_nodes=fg.apply_nodes,
# But not here, so we need to find the right position for insertion.
# We iterate through the previous subgraphs in topological order (reverse of the stored order).
# We exclude cumulatively exclude each subgraph_bitset and perform the same dependency check again.
# The (index + 1) of the firs iteration where the check passes is the correct insertion position.
remaining_subgraphs_bitset = all_subgraphs_bitset
for index, (other_subgraph_bitset, _) in enumerate(
reversed(sorted_subgraphs)
):
# Exclude subgraph bitset
remaining_subgraphs_bitset &= ~other_subgraph_bitset
if not (
unfuseable_ancestors_bitset & remaining_subgraphs_bitset
):
break # bingo
sorted_subgraphs.insert(
-(index + 1),
(subgraph_bitset, (subgraph_inputs, subgraph_outputs)),
)
# yield from sorted_subgraphs, discarding the subgraph_bitset
yield from (io for _, io in sorted_subgraphs)
max_operands = elemwise_max_operands_fct(None)
reason = self.__class__.__name__
nb_fused = 0
nb_replacement = 0
for inputs, outputs in find_next_fuseable_subgraph(fgraph):
for inputs, outputs in find_fuseable_subgraphs(fgraph):
if (len(inputs) + len(outputs)) > max_operands:
warn(
"Loop fusion failed because the resulting node would exceed "
"the kernel argument limit."
"Loop fusion failed because the resulting node would exceed the kernel argument limit."
)
break
continue
scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs)
composite_outputs = Elemwise(
......@@ -955,7 +863,8 @@ class FusionOptimizer(GraphRewriter):
starting_nodes = len(fgraph.apply_nodes)
fgraph.replace_all_validate(
tuple(zip(outputs, composite_outputs)), reason=reason
tuple(zip(outputs, composite_outputs)),
reason=reason,
)
nb_fused += 1
nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1
......
......@@ -319,6 +319,26 @@ class TestFusion:
assert nb_fused == 1
assert nb_replacement == 4
def test_expansion_order(self):
# This test is designed to fail if we don't use the right expansion order in the current implementation
# It may be considered irrelevant if the algorithm changes and this is no longer a concern.
# In that case the test can be tweaked or removed
a = pt.vector("a")
b = pt.exp(a)
# Unique creates an unfuesable path between b and d/e
c = pt.unique(b)
d = pt.log(c)
# The critical aspect of the current implementation, is that we must visit d before c,
# so we learn about the unfuseable path by the time we visit c
e1 = b + d
e2 = d + b # test both orders
fg = FunctionGraph([a], [e1, e2], clone=False)
_, nb_fused, nb_replacement, *_ = FusionOptimizer().apply(fg)
fg.dprint()
assert nb_fused == 1
assert nb_replacement == 3
@pytest.mark.parametrize(
"case",
[
......@@ -1374,7 +1394,7 @@ class TestFusion:
"graph_fn, n, expected_n_repl",
[
("deep_small_kernels", 20, (20, 60)),
("large_fuseable_graph", 25, (103, 876)),
("large_fuseable_graph", 25, (128, 876)),
],
)
def test_rewrite_benchmark(self, graph_fn, n, expected_n_repl, benchmark):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论