提交 b68c74dc authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Small tweaks to FusionOptimizer

上级 1d825dd6
...@@ -652,12 +652,12 @@ class FusionOptimizer(GraphRewriter): ...@@ -652,12 +652,12 @@ class FusionOptimizer(GraphRewriter):
# `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0` # `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0`
nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())} nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())}
# Root variables have `None` as owner, which we can handle with a bitset of 0 # Root variables have `None` as owner, which we can handle with a bitset of 0
ancestors_bitset = {None: 0} ancestors_bitsets = {None: 0}
for node, node_bitflag in nodes_bitflags.items(): for node, node_bitflag in nodes_bitflags.items():
# The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag # The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag
ancestors_bitset[node] = reduce( ancestors_bitsets[node] = reduce(
or_, or_,
(ancestors_bitset[inp.owner] for inp in node.inputs), (ancestors_bitsets[inp.owner] for inp in node.inputs),
node_bitflag, node_bitflag,
) )
# Handle root and leaf nodes gracefully # Handle root and leaf nodes gracefully
...@@ -666,10 +666,12 @@ class FusionOptimizer(GraphRewriter): ...@@ -666,10 +666,12 @@ class FusionOptimizer(GraphRewriter):
nodes_bitflags[None] = 0 nodes_bitflags[None] = 0
# Nothing ever depends on the special Output nodes, so just use a new bit for all of them # Nothing ever depends on the special Output nodes, so just use a new bit for all of them
out_bitflag = 1 << len(nodes_bitflags) out_bitflag = 1 << len(nodes_bitflags)
for out in fg.outputs: nodes_bitflags |= (
for client, _ in fg_clients[out]: (client, out_bitflag)
if isinstance(client.op, Output): for out in fg.outputs
nodes_bitflags[client] = out_bitflag for client, _ in fg_clients[out]
if isinstance(client.op, Output)
)
# Start main loop to find collection of fuseable subgraphs # Start main loop to find collection of fuseable subgraphs
# We store the collection in `sorted_subgraphs`, in reverse topological order # We store the collection in `sorted_subgraphs`, in reverse topological order
...@@ -726,7 +728,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -726,7 +728,7 @@ class FusionOptimizer(GraphRewriter):
if node_bitflag & unfuseable_ancestors_bitset: if node_bitflag & unfuseable_ancestors_bitset:
# An unfuseable ancestor of the subgraph depends on this node, can't fuse # An unfuseable ancestor of the subgraph depends on this node, can't fuse
continue continue
elif ancestors_bitset[node] & unfuseable_clients_bitset: elif ancestors_bitsets[node] & unfuseable_clients_bitset:
# This node depends on an unfuseable client of the subgraph, can't fuse # This node depends on an unfuseable client of the subgraph, can't fuse
continue continue
...@@ -742,7 +744,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -742,7 +744,7 @@ class FusionOptimizer(GraphRewriter):
for inp in node.inputs: for inp in node.inputs:
ancestor_node = inp.owner ancestor_node = inp.owner
ancestor_bitflag = nodes_bitflags[ancestor_node] ancestor_bitflag = nodes_bitflags[ancestor_node]
if ancestor_bitflag & subgraph_bitset: if (not is_ancestor) and (ancestor_bitflag & subgraph_bitset):
continue continue
if node in fuseable_clients.get(ancestor_node, ()): if node in fuseable_clients.get(ancestor_node, ()):
heappush( heappush(
...@@ -752,14 +754,14 @@ class FusionOptimizer(GraphRewriter): ...@@ -752,14 +754,14 @@ class FusionOptimizer(GraphRewriter):
else: else:
# If the node is not in the ancestor's fuseable clients set, it's not fuseable with it, # If the node is not in the ancestor's fuseable clients set, it's not fuseable with it,
# nor with any of the ancestor's ancestors # nor with any of the ancestor's ancestors
unfuseable_ancestors_bitset |= ancestors_bitset[ unfuseable_ancestors_bitset |= ancestors_bitsets[
ancestor_node ancestor_node
] ]
next_fuseable_clients = fuseable_clients.get(node, ()) next_fuseable_clients = fuseable_clients.get(node, ())
for client, _ in fg_clients[node.outputs[0]]: for client, _ in fg_clients[node.outputs[0]]:
client_bitflag = nodes_bitflags[client] client_bitflag = nodes_bitflags[client]
if client_bitflag & subgraph_bitset: if is_ancestor and (client_bitflag & subgraph_bitset):
continue continue
if client in next_fuseable_clients: if client in next_fuseable_clients:
heappush(fuseables_nodes_queue, (client_bitflag, client)) heappush(fuseables_nodes_queue, (client_bitflag, client))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论