提交 b68c74dc authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Small tweaks to FusionOptimizer

上级 1d825dd6
......@@ -652,12 +652,12 @@ class FusionOptimizer(GraphRewriter):
# `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0`
nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())}
# Root variables have `None` as owner, which we can handle with a bitset of 0
ancestors_bitset = {None: 0}
ancestors_bitsets = {None: 0}
for node, node_bitflag in nodes_bitflags.items():
# The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag
ancestors_bitset[node] = reduce(
ancestors_bitsets[node] = reduce(
or_,
(ancestors_bitset[inp.owner] for inp in node.inputs),
(ancestors_bitsets[inp.owner] for inp in node.inputs),
node_bitflag,
)
# Handle root and leaf nodes gracefully
......@@ -666,10 +666,12 @@ class FusionOptimizer(GraphRewriter):
nodes_bitflags[None] = 0
# Nothing ever depends on the special Output nodes, so just use a new bit for all of them
out_bitflag = 1 << len(nodes_bitflags)
for out in fg.outputs:
for client, _ in fg_clients[out]:
if isinstance(client.op, Output):
nodes_bitflags[client] = out_bitflag
nodes_bitflags |= (
(client, out_bitflag)
for out in fg.outputs
for client, _ in fg_clients[out]
if isinstance(client.op, Output)
)
# Start main loop to find collection of fuseable subgraphs
# We store the collection in `sorted_subgraphs`, in reverse topological order
......@@ -726,7 +728,7 @@ class FusionOptimizer(GraphRewriter):
if node_bitflag & unfuseable_ancestors_bitset:
# An unfuseable ancestor of the subgraph depends on this node, can't fuse
continue
elif ancestors_bitset[node] & unfuseable_clients_bitset:
elif ancestors_bitsets[node] & unfuseable_clients_bitset:
# This node depends on an unfuseable client of the subgraph, can't fuse
continue
......@@ -742,7 +744,7 @@ class FusionOptimizer(GraphRewriter):
for inp in node.inputs:
ancestor_node = inp.owner
ancestor_bitflag = nodes_bitflags[ancestor_node]
if ancestor_bitflag & subgraph_bitset:
if (not is_ancestor) and (ancestor_bitflag & subgraph_bitset):
continue
if node in fuseable_clients.get(ancestor_node, ()):
heappush(
......@@ -752,14 +754,14 @@ class FusionOptimizer(GraphRewriter):
else:
# If the node is not in the ancestor's fuseable clients set, it's not fuseable with it,
# nor with any of the ancestor's ancestors
unfuseable_ancestors_bitset |= ancestors_bitset[
unfuseable_ancestors_bitset |= ancestors_bitsets[
ancestor_node
]
next_fuseable_clients = fuseable_clients.get(node, ())
for client, _ in fg_clients[node.outputs[0]]:
client_bitflag = nodes_bitflags[client]
if client_bitflag & subgraph_bitset:
if is_ancestor and (client_bitflag & subgraph_bitset):
continue
if client in next_fuseable_clients:
heappush(fuseables_nodes_queue, (client_bitflag, client))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论