提交 b6a9b68e authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix FusionOptimizer bug

When a subgraph with multiple outputs is "implicitly" claimed, it can change the dependencies of remaining nodes. A node that depended only on a subset of the subgraph outputs now depends on all of them. Not taking this into account could lead to circular dependent Composites
上级 b68c74dc
...@@ -652,7 +652,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -652,7 +652,7 @@ class FusionOptimizer(GraphRewriter):
# `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0` # `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0`
nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())} nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())}
# Root variables have `None` as owner, which we can handle with a bitset of 0 # Root variables have `None` as owner, which we can handle with a bitset of 0
ancestors_bitsets = {None: 0} ancestors_bitsets: dict[Apply | None, int] = {None: 0}
for node, node_bitflag in nodes_bitflags.items(): for node, node_bitflag in nodes_bitflags.items():
# The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag # The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag
ancestors_bitsets[node] = reduce( ancestors_bitsets[node] = reduce(
...@@ -694,9 +694,13 @@ class FusionOptimizer(GraphRewriter): ...@@ -694,9 +694,13 @@ class FusionOptimizer(GraphRewriter):
# For simplicity, we always want to visit ancestors before clients # For simplicity, we always want to visit ancestors before clients
# For ancestors, we want to visit the later nodes first (those that have more dependencies) # For ancestors, we want to visit the later nodes first (those that have more dependencies)
# whereas for clients we want to visit earlier nodes first (those that have fewer dependencies) # whereas for clients we want to visit earlier nodes first (those that have fewer dependencies)
# To achieve this we use the bitflag as the sorting key (which encodes the topological order) # To achieve this we use the ancestors_bitset as the sorting key (which encodes the topological order)
# and negate it for ancestors. # and negate it for ancestors. We use the ancestors_bitset instead of the node bitflag because we
fuseables_nodes_queue = [(-starting_bitflag, starting_node)] # update the former when we find a fuseable subgraph, emulating the effect of recomputing the
# topological order on the remaining nodes.
fuseables_nodes_queue = [
(-ancestors_bitsets[starting_node], starting_bitflag, starting_node)
]
heapify(fuseables_nodes_queue) heapify(fuseables_nodes_queue)
# We keep 3 bitsets during the exploration of a new subgraph: # We keep 3 bitsets during the exploration of a new subgraph:
...@@ -715,10 +719,12 @@ class FusionOptimizer(GraphRewriter): ...@@ -715,10 +719,12 @@ class FusionOptimizer(GraphRewriter):
unfuseable_clients_bitset = 0 unfuseable_clients_bitset = 0
while fuseables_nodes_queue: while fuseables_nodes_queue:
node_bitflag, node = heappop(fuseables_nodes_queue) node_ancestors_bitset, node_bitflag, node = heappop(
is_ancestor = node_bitflag < 0 fuseables_nodes_queue
)
is_ancestor = node_ancestors_bitset < 0
if is_ancestor: if is_ancestor:
node_bitflag = -node_bitflag node_ancestors_bitset = -node_ancestors_bitset
if node_bitflag & subgraph_bitset: if node_bitflag & subgraph_bitset:
# Already part of the subgraph # Already part of the subgraph
...@@ -728,7 +734,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -728,7 +734,7 @@ class FusionOptimizer(GraphRewriter):
if node_bitflag & unfuseable_ancestors_bitset: if node_bitflag & unfuseable_ancestors_bitset:
# An unfuseable ancestor of the subgraph depends on this node, can't fuse # An unfuseable ancestor of the subgraph depends on this node, can't fuse
continue continue
elif ancestors_bitsets[node] & unfuseable_clients_bitset: elif node_ancestors_bitset & unfuseable_clients_bitset:
# This node depends on an unfuseable client of the subgraph, can't fuse # This node depends on an unfuseable client of the subgraph, can't fuse
continue continue
...@@ -749,7 +755,11 @@ class FusionOptimizer(GraphRewriter): ...@@ -749,7 +755,11 @@ class FusionOptimizer(GraphRewriter):
if node in fuseable_clients.get(ancestor_node, ()): if node in fuseable_clients.get(ancestor_node, ()):
heappush( heappush(
fuseables_nodes_queue, fuseables_nodes_queue,
(-ancestor_bitflag, ancestor_node), (
-ancestors_bitsets[ancestor_node],
ancestor_bitflag,
ancestor_node,
),
) )
else: else:
# If the node is not in the ancestor's fuseable clients set, it's not fuseable with it, # If the node is not in the ancestor's fuseable clients set, it's not fuseable with it,
...@@ -764,16 +774,17 @@ class FusionOptimizer(GraphRewriter): ...@@ -764,16 +774,17 @@ class FusionOptimizer(GraphRewriter):
if is_ancestor and (client_bitflag & subgraph_bitset): if is_ancestor and (client_bitflag & subgraph_bitset):
continue continue
if client in next_fuseable_clients: if client in next_fuseable_clients:
heappush(fuseables_nodes_queue, (client_bitflag, client)) heappush(
fuseables_nodes_queue,
(ancestors_bitsets[client], client_bitflag, client),
)
else: else:
# If a client is not in the node's fuseable clients set, it's nto fuseable with it, # If a client is not in the node's fuseable clients set, it's nto fuseable with it,
# nor any of its clients. But we don't need to keep track of those as any downstream # nor any of its clients. But we don't need to keep track of those as any downstream
# client we may consider later will also depend on this unfuseable client and be rejected # client we may consider later will also depend on this unfuseable client and be rejected
unfuseable_clients_bitset |= client_bitflag unfuseable_clients_bitset |= client_bitflag
# Finished exploring this subgraph # Finished expansion of subgraph
all_subgraphs_bitset |= subgraph_bitset
if subgraph_bitset == starting_bitflag: if subgraph_bitset == starting_bitflag:
# We ended were we started, no fusion possible # We ended were we started, no fusion possible
continue continue
...@@ -816,6 +827,18 @@ class FusionOptimizer(GraphRewriter): ...@@ -816,6 +827,18 @@ class FusionOptimizer(GraphRewriter):
for out in subgraph_outputs: for out in subgraph_outputs:
fuseable_clients.pop(out.owner, None) fuseable_clients.pop(out.owner, None)
# When we fuse multi-output subgraphs, we also need to fuse the dependencies of successor nodes.
# Nodes that previously depended on a subset of the fused outputs, now depend on all of them.
if len(subgraph_outputs) > 1:
subgraph_and_ancestors = (
subgraph_bitset | unfuseable_ancestors_bitset
)
ancestors_bitsets |= (
(node, node_ancestors_bitset | subgraph_and_ancestors)
for node, node_ancestors_bitset in ancestors_bitsets.items()
if node_ancestors_bitset & subgraph_bitset
)
# Add new subgraph to sorted_subgraphs # Add new subgraph to sorted_subgraphs
# Because we start from sink nodes in reverse topological order, most times new subgraphs # Because we start from sink nodes in reverse topological order, most times new subgraphs
# don't depend on previous subgraphs, so we can just append them at the end. # don't depend on previous subgraphs, so we can just append them at the end.
...@@ -828,8 +851,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -828,8 +851,7 @@ class FusionOptimizer(GraphRewriter):
else: else:
# But not here, so we need to find the right position for insertion. # But not here, so we need to find the right position for insertion.
# We iterate through the previous subgraphs in topological order (reverse of the stored order). # We iterate through the previous subgraphs in topological order (reverse of the stored order).
# We exclude cumulatively exclude each subgraph_bitset and perform the same dependency check again. # We cumulatively exclude each subgraph_bitset and perform the same dependency check again, until it passes.
# The (index + 1) of the firs iteration where the check passes is the correct insertion position.
remaining_subgraphs_bitset = all_subgraphs_bitset remaining_subgraphs_bitset = all_subgraphs_bitset
for index, (other_subgraph_bitset, _) in enumerate( for index, (other_subgraph_bitset, _) in enumerate(
reversed(sorted_subgraphs) reversed(sorted_subgraphs)
...@@ -840,12 +862,20 @@ class FusionOptimizer(GraphRewriter): ...@@ -840,12 +862,20 @@ class FusionOptimizer(GraphRewriter):
unfuseable_ancestors_bitset & remaining_subgraphs_bitset unfuseable_ancestors_bitset & remaining_subgraphs_bitset
): ):
break # bingo break # bingo
else: # no-break
raise RuntimeError(
"Failed to find insertion point for fused subgraph"
)
sorted_subgraphs.insert( sorted_subgraphs.insert(
-(index + 1), -(index + 1),
(subgraph_bitset, (subgraph_inputs, subgraph_outputs)), (subgraph_bitset, (subgraph_inputs, subgraph_outputs)),
) )
# yield from sorted_subgraphs, discarding the subgraph_bitset # Add subgraph to all_subgraphs_bitset
all_subgraphs_bitset |= subgraph_bitset
# Finished exploring the whole graph
# Yield from sorted_subgraphs, discarding the subgraph_bitset
yield from (io for _, io in sorted_subgraphs) yield from (io for _, io in sorted_subgraphs)
max_operands = elemwise_max_operands_fct(None) max_operands = elemwise_max_operands_fct(None)
......
...@@ -1426,6 +1426,43 @@ class TestFusion: ...@@ -1426,6 +1426,43 @@ class TestFusion:
np.log(1 - np.exp(-2)), np.log(1 - np.exp(-2)),
) )
def test_joint_circular_dependency(self):
# Test a case where fused subgraphs could induce a circular dependency
x = matrix("x")
neg = pt.neg(x)
eq = pt.eq(x.sum(axis=0), 0)
sub = pt.sub(eq, neg)
exp = pt.exp(neg.sum(axis=0))
# We test arbitrary add and output orders, to make sure our algorithm
# is robust to valid toposort variations.
for add_order in [(exp, eq), (eq, exp)]:
add = pt.add(*add_order)
# The naive fused graphs to consider are {sub, neg} and {add, exp, eq},
# which is not valid because sub depends on eq, while add/exp depends on neg.
# Instead, we can either fuse both {sub, neg} and {add, exp} or just {add, exp, eq}
for out_order in [(sub, add), (add, sub)]:
fgraph = FunctionGraph([x], out_order, clone=True)
_, nb_fused, nb_replaced, *_ = FusionOptimizer().apply(fgraph)
# (nb_fused, nb_replaced) would be (2, 5) if we did the invalid fusion
assert (nb_fused, nb_replaced) in ((2, 4), (1, 3))
fused_nodes = {
frozenset(
scalar_n.op for scalar_n in n.op.scalar_op.fgraph.apply_nodes
)
for n in fgraph.apply_nodes
if isinstance(n.op, Elemwise)
and isinstance(n.op.scalar_op, Composite)
}
if nb_fused == 1:
assert fused_nodes == {frozenset((ps.add, ps.exp, ps.eq))}
else:
assert fused_nodes == {
frozenset((ps.sub, ps.neg)),
frozenset((ps.add, ps.exp)),
}
class TimesN(ps.basic.UnaryScalarOp): class TimesN(ps.basic.UnaryScalarOp):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论