提交 566af64d authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Use bitset to check ancestors more efficiently

上级 dc1e3b9c
......@@ -6,6 +6,7 @@ import typing
from collections import defaultdict, deque
from collections.abc import Generator, Sequence
from functools import cache, reduce
from operator import or_
from typing import Literal
from warnings import warn
......@@ -29,7 +30,7 @@ from pytensor.graph.rewriting.basic import (
)
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.graph.traversal import ancestors, toposort
from pytensor.graph.traversal import toposort
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import (
......@@ -659,16 +660,9 @@ class FusionOptimizer(GraphRewriter):
visited_nodes: set[Apply],
fuseable_clients: FUSEABLE_MAPPING,
unfuseable_clients: UNFUSEABLE_MAPPING,
ancestors_bitset: dict[Apply, int],
toposort_index: dict[Apply, int],
) -> tuple[list[Variable], list[Variable]]:
def variables_depend_on(
variables, depend_on, stop_search_at=None
) -> bool:
return any(
a in depend_on
for a in ancestors(variables, blockers=stop_search_at)
)
for starting_node in toposort_index:
if starting_node in visited_nodes:
continue
......@@ -680,7 +674,8 @@ class FusionOptimizer(GraphRewriter):
subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
unfuseable_clients_subgraph: set[Variable] = set()
subgraph_inputs_ancestors_bitset = 0
unfuseable_clients_subgraph_bitset = 0
# If we need to manipulate the maps in place, we'll do a shallow copy later
# For now we query on the original ones
......@@ -712,50 +707,32 @@ class FusionOptimizer(GraphRewriter):
if must_become_output:
subgraph_outputs.pop(next_out, None)
required_unfuseable_inputs = [
inp
for inp in next_node.inputs
if next_node in unfuseable_clients_clone.get(inp)
]
new_required_unfuseable_inputs = [
inp
for inp in required_unfuseable_inputs
if inp not in subgraph_inputs
]
must_backtrack = False
if new_required_unfuseable_inputs and subgraph_outputs:
# We need to check that any new inputs required by this node
# do not depend on other outputs of the current subgraph,
# via an unfuseable path.
if variables_depend_on(
[next_out],
depend_on=unfuseable_clients_subgraph,
stop_search_at=subgraph_outputs,
):
must_backtrack = True
# We need to check that any inputs required by this node
# do not depend on other outputs of the current subgraph,
# via an unfuseable path.
must_backtrack = (
ancestors_bitset[next_node]
& unfuseable_clients_subgraph_bitset
)
if not must_backtrack:
implied_unfuseable_clients = {
c
for client in unfuseable_clients_clone.get(next_out)
if not isinstance(client.op, Output)
for c in client.outputs
}
new_implied_unfuseable_clients = (
implied_unfuseable_clients - unfuseable_clients_subgraph
implied_unfuseable_clients_bitset = reduce(
or_,
(
1 << toposort_index[client]
for client in unfuseable_clients_clone.get(next_out)
if not isinstance(client.op, Output)
),
0,
)
if new_implied_unfuseable_clients and subgraph_inputs:
# We need to check that any inputs of the current subgraph
# do not depend on other clients of this node,
# via an unfuseable path.
if variables_depend_on(
subgraph_inputs,
depend_on=new_implied_unfuseable_clients,
):
must_backtrack = True
# We need to check that any inputs of the current subgraph
# do not depend on other clients of this node,
# via an unfuseable path.
must_backtrack = (
subgraph_inputs_ancestors_bitset
& implied_unfuseable_clients_bitset
)
if must_backtrack:
for inp in next_node.inputs:
......@@ -796,29 +773,24 @@ class FusionOptimizer(GraphRewriter):
# immediate dependency problems. Update subgraph
# mappings as if next_node was part of it.
# Useless inputs will be removed by the useless Composite rewrite
for inp in new_required_unfuseable_inputs:
subgraph_inputs[inp] = None
if must_become_output:
subgraph_outputs[next_out] = None
unfuseable_clients_subgraph.update(
new_implied_unfuseable_clients
unfuseable_clients_subgraph_bitset |= (
implied_unfuseable_clients_bitset
)
# Expand through unvisited fuseable ancestors
fuseable_nodes_to_visit.extendleft(
sorted(
(
inp.owner
for inp in next_node.inputs
if (
inp not in required_unfuseable_inputs
and inp.owner not in visited_nodes
)
),
key=toposort_index.get, # type: ignore[arg-type]
)
)
for inp in sorted(
next_node.inputs,
key=lambda x: toposort_index.get(x.owner, -1),
):
if next_node in unfuseable_clients_clone.get(inp, ()):
# This input must become an input of the subgraph, since it is unfuseable with the new node
subgraph_inputs_ancestors_bitset |= (
ancestors_bitset.get(inp.owner, 0)
)
subgraph_inputs[inp] = None
elif inp.owner not in visited_nodes:
fuseable_nodes_to_visit.appendleft(inp.owner)
# Expand through unvisited fuseable clients
fuseable_nodes_to_visit.extend(
......@@ -855,6 +827,8 @@ class FusionOptimizer(GraphRewriter):
visited_nodes: set[Apply],
fuseable_clients: FUSEABLE_MAPPING,
unfuseable_clients: UNFUSEABLE_MAPPING,
toposort_index: dict[Apply, int],
ancestors_bitset: dict[Apply, int],
starting_nodes: set[Apply],
updated_nodes: set[Apply],
) -> None:
......@@ -865,11 +839,25 @@ class FusionOptimizer(GraphRewriter):
dropped_nodes = starting_nodes - updated_nodes
# Remove intermediate Composite nodes from mappings
# And compute the ancestors bitset of the new composite node
# As well as the new toposort index for the new node
new_node_ancestor_bitset = 0
new_node_toposort_index = len(toposort_index)
for dropped_node in dropped_nodes:
(dropped_out,) = dropped_node.outputs
fuseable_clients.pop(dropped_out, None)
unfuseable_clients.pop(dropped_out, None)
visited_nodes.remove(dropped_node)
# The new composite ancestor bitset is the union
# of the ancestors of all the dropped nodes
new_node_ancestor_bitset |= ancestors_bitset[dropped_node]
# The new composite node can have the same order as the latest node that was absorbed into it
new_node_toposort_index = max(
new_node_toposort_index, toposort_index[dropped_node]
)
ancestors_bitset[new_composite_node] = new_node_ancestor_bitset
toposort_index[new_composite_node] = new_node_toposort_index
# Update fuseable information for subgraph inputs
for inp in subgraph_inputs:
......@@ -901,12 +889,23 @@ class FusionOptimizer(GraphRewriter):
fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg)
visited_nodes: set[Apply] = set()
toposort_index = {node: i for i, node in enumerate(fgraph.toposort())}
# For each node, create a bitset marking the node itself and all of its ancestors
# This allows us to quickly check whether a node depends on any member of a set
ancestors_bitset: dict[Apply, int] = {}
for node, index in toposort_index.items():
node_ancestor_bitset = 1 << index
for inp in node.inputs:
if (inp_node := inp.owner) is not None:
node_ancestor_bitset |= ancestors_bitset[inp_node]
ancestors_bitset[node] = node_ancestor_bitset
while True:
try:
subgraph_inputs, subgraph_outputs = find_fuseable_subgraph(
visited_nodes=visited_nodes,
fuseable_clients=fuseable_clients,
unfuseable_clients=unfuseable_clients,
ancestors_bitset=ancestors_bitset,
toposort_index=toposort_index,
)
except ValueError:
......@@ -925,6 +924,8 @@ class FusionOptimizer(GraphRewriter):
visited_nodes=visited_nodes,
fuseable_clients=fuseable_clients,
unfuseable_clients=unfuseable_clients,
toposort_index=toposort_index,
ancestors_bitset=ancestors_bitset,
starting_nodes=starting_nodes,
updated_nodes=fg.apply_nodes,
)
......
......@@ -301,7 +301,8 @@ def test_debugprint():
Gemv_op_name = "CGemv" if pytensor.config.blas__ldflags else "Gemv"
exp_res = dedent(
r"""
Composite{(i2 + (i0 - i1))} 4
Composite{(i0 + (i1 - i2))} 4
├─ A
├─ ExpandDims{axis=0} v={0: [0]} 3
"""
f" │ └─ {Gemv_op_name}{{inplace}} d={{0: [0]}} 2"
......@@ -313,17 +314,16 @@ def test_debugprint():
│ ├─ B
│ ├─ <Vector(float64, shape=(?,))>
│ └─ 0.0
├─ D
└─ A
└─ D
Inner graphs:
Composite{(i2 + (i0 - i1))}
Composite{(i0 + (i1 - i2))}
← add 'o0'
├─ i2
└─ sub
├─ i0
└─ i1
└─ sub
├─ i1
└─ i2
"""
).lstrip()
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论