提交 566af64d authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Use bitset to check ancestors more efficiently

上级 dc1e3b9c
...@@ -6,6 +6,7 @@ import typing ...@@ -6,6 +6,7 @@ import typing
from collections import defaultdict, deque from collections import defaultdict, deque
from collections.abc import Generator, Sequence from collections.abc import Generator, Sequence
from functools import cache, reduce from functools import cache, reduce
from operator import or_
from typing import Literal from typing import Literal
from warnings import warn from warnings import warn
...@@ -29,7 +30,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -29,7 +30,7 @@ from pytensor.graph.rewriting.basic import (
) )
from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.rewriting.unify import OpPattern from pytensor.graph.rewriting.unify import OpPattern
from pytensor.graph.traversal import ancestors, toposort from pytensor.graph.traversal import toposort
from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
...@@ -659,16 +660,9 @@ class FusionOptimizer(GraphRewriter): ...@@ -659,16 +660,9 @@ class FusionOptimizer(GraphRewriter):
visited_nodes: set[Apply], visited_nodes: set[Apply],
fuseable_clients: FUSEABLE_MAPPING, fuseable_clients: FUSEABLE_MAPPING,
unfuseable_clients: UNFUSEABLE_MAPPING, unfuseable_clients: UNFUSEABLE_MAPPING,
ancestors_bitset: dict[Apply, int],
toposort_index: dict[Apply, int], toposort_index: dict[Apply, int],
) -> tuple[list[Variable], list[Variable]]: ) -> tuple[list[Variable], list[Variable]]:
def variables_depend_on(
variables, depend_on, stop_search_at=None
) -> bool:
return any(
a in depend_on
for a in ancestors(variables, blockers=stop_search_at)
)
for starting_node in toposort_index: for starting_node in toposort_index:
if starting_node in visited_nodes: if starting_node in visited_nodes:
continue continue
...@@ -680,7 +674,8 @@ class FusionOptimizer(GraphRewriter): ...@@ -680,7 +674,8 @@ class FusionOptimizer(GraphRewriter):
subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
unfuseable_clients_subgraph: set[Variable] = set() subgraph_inputs_ancestors_bitset = 0
unfuseable_clients_subgraph_bitset = 0
# If we need to manipulate the maps in place, we'll do a shallow copy later # If we need to manipulate the maps in place, we'll do a shallow copy later
# For now we query on the original ones # For now we query on the original ones
...@@ -712,50 +707,32 @@ class FusionOptimizer(GraphRewriter): ...@@ -712,50 +707,32 @@ class FusionOptimizer(GraphRewriter):
if must_become_output: if must_become_output:
subgraph_outputs.pop(next_out, None) subgraph_outputs.pop(next_out, None)
required_unfuseable_inputs = [ # We need to check that any inputs required by this node
inp # do not depend on other outputs of the current subgraph,
for inp in next_node.inputs # via an unfuseable path.
if next_node in unfuseable_clients_clone.get(inp) must_backtrack = (
] ancestors_bitset[next_node]
new_required_unfuseable_inputs = [ & unfuseable_clients_subgraph_bitset
inp )
for inp in required_unfuseable_inputs
if inp not in subgraph_inputs
]
must_backtrack = False
if new_required_unfuseable_inputs and subgraph_outputs:
# We need to check that any new inputs required by this node
# do not depend on other outputs of the current subgraph,
# via an unfuseable path.
if variables_depend_on(
[next_out],
depend_on=unfuseable_clients_subgraph,
stop_search_at=subgraph_outputs,
):
must_backtrack = True
if not must_backtrack: if not must_backtrack:
implied_unfuseable_clients = { implied_unfuseable_clients_bitset = reduce(
c or_,
for client in unfuseable_clients_clone.get(next_out) (
if not isinstance(client.op, Output) 1 << toposort_index[client]
for c in client.outputs for client in unfuseable_clients_clone.get(next_out)
} if not isinstance(client.op, Output)
),
new_implied_unfuseable_clients = ( 0,
implied_unfuseable_clients - unfuseable_clients_subgraph
) )
if new_implied_unfuseable_clients and subgraph_inputs: # We need to check that any inputs of the current subgraph
# We need to check that any inputs of the current subgraph # do not depend on other clients of this node,
# do not depend on other clients of this node, # via an unfuseable path.
# via an unfuseable path. must_backtrack = (
if variables_depend_on( subgraph_inputs_ancestors_bitset
subgraph_inputs, & implied_unfuseable_clients_bitset
depend_on=new_implied_unfuseable_clients, )
):
must_backtrack = True
if must_backtrack: if must_backtrack:
for inp in next_node.inputs: for inp in next_node.inputs:
...@@ -796,29 +773,24 @@ class FusionOptimizer(GraphRewriter): ...@@ -796,29 +773,24 @@ class FusionOptimizer(GraphRewriter):
# immediate dependency problems. Update subgraph # immediate dependency problems. Update subgraph
# mappings as if next_node was part of it. # mappings as if next_node was part of it.
# Useless inputs will be removed by the useless Composite rewrite # Useless inputs will be removed by the useless Composite rewrite
for inp in new_required_unfuseable_inputs:
subgraph_inputs[inp] = None
if must_become_output: if must_become_output:
subgraph_outputs[next_out] = None subgraph_outputs[next_out] = None
unfuseable_clients_subgraph.update( unfuseable_clients_subgraph_bitset |= (
new_implied_unfuseable_clients implied_unfuseable_clients_bitset
) )
# Expand through unvisited fuseable ancestors for inp in sorted(
fuseable_nodes_to_visit.extendleft( next_node.inputs,
sorted( key=lambda x: toposort_index.get(x.owner, -1),
( ):
inp.owner if next_node in unfuseable_clients_clone.get(inp, ()):
for inp in next_node.inputs # input must become an input of the subgraph since it's unfuseable with new node
if ( subgraph_inputs_ancestors_bitset |= (
inp not in required_unfuseable_inputs ancestors_bitset.get(inp.owner, 0)
and inp.owner not in visited_nodes )
) subgraph_inputs[inp] = None
), elif inp.owner not in visited_nodes:
key=toposort_index.get, # type: ignore[arg-type] fuseable_nodes_to_visit.appendleft(inp.owner)
)
)
# Expand through unvisited fuseable clients # Expand through unvisited fuseable clients
fuseable_nodes_to_visit.extend( fuseable_nodes_to_visit.extend(
...@@ -855,6 +827,8 @@ class FusionOptimizer(GraphRewriter): ...@@ -855,6 +827,8 @@ class FusionOptimizer(GraphRewriter):
visited_nodes: set[Apply], visited_nodes: set[Apply],
fuseable_clients: FUSEABLE_MAPPING, fuseable_clients: FUSEABLE_MAPPING,
unfuseable_clients: UNFUSEABLE_MAPPING, unfuseable_clients: UNFUSEABLE_MAPPING,
toposort_index: dict[Apply, int],
ancestors_bitset: dict[Apply, int],
starting_nodes: set[Apply], starting_nodes: set[Apply],
updated_nodes: set[Apply], updated_nodes: set[Apply],
) -> None: ) -> None:
...@@ -865,11 +839,25 @@ class FusionOptimizer(GraphRewriter): ...@@ -865,11 +839,25 @@ class FusionOptimizer(GraphRewriter):
dropped_nodes = starting_nodes - updated_nodes dropped_nodes = starting_nodes - updated_nodes
# Remove intermediate Composite nodes from mappings # Remove intermediate Composite nodes from mappings
# And compute the ancestors bitset of the new composite node
# As well as the new toposort index for the new node
new_node_ancestor_bitset = 0
new_node_toposort_index = len(toposort_index)
for dropped_node in dropped_nodes: for dropped_node in dropped_nodes:
(dropped_out,) = dropped_node.outputs (dropped_out,) = dropped_node.outputs
fuseable_clients.pop(dropped_out, None) fuseable_clients.pop(dropped_out, None)
unfuseable_clients.pop(dropped_out, None) unfuseable_clients.pop(dropped_out, None)
visited_nodes.remove(dropped_node) visited_nodes.remove(dropped_node)
# The new composite ancestor bitset is the union
# of the ancestors of all the dropped nodes
new_node_ancestor_bitset |= ancestors_bitset[dropped_node]
# The new composite node can have the same order as the latest node that was absorbed into it
new_node_toposort_index = max(
new_node_toposort_index, toposort_index[dropped_node]
)
ancestors_bitset[new_composite_node] = new_node_ancestor_bitset
toposort_index[new_composite_node] = new_node_toposort_index
# Update fuseable information for subgraph inputs # Update fuseable information for subgraph inputs
for inp in subgraph_inputs: for inp in subgraph_inputs:
...@@ -901,12 +889,23 @@ class FusionOptimizer(GraphRewriter): ...@@ -901,12 +889,23 @@ class FusionOptimizer(GraphRewriter):
fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg) fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg)
visited_nodes: set[Apply] = set() visited_nodes: set[Apply] = set()
toposort_index = {node: i for i, node in enumerate(fgraph.toposort())} toposort_index = {node: i for i, node in enumerate(fgraph.toposort())}
# For each node, create a bitset of all its ancestors (including itself)
# This allows us to quickly check whether a node depends on any node in a set
ancestors_bitset: dict[Apply, int] = {}
for node, index in toposort_index.items():
node_ancestor_bitset = 1 << index
for inp in node.inputs:
if (inp_node := inp.owner) is not None:
node_ancestor_bitset |= ancestors_bitset[inp_node]
ancestors_bitset[node] = node_ancestor_bitset
while True: while True:
try: try:
subgraph_inputs, subgraph_outputs = find_fuseable_subgraph( subgraph_inputs, subgraph_outputs = find_fuseable_subgraph(
visited_nodes=visited_nodes, visited_nodes=visited_nodes,
fuseable_clients=fuseable_clients, fuseable_clients=fuseable_clients,
unfuseable_clients=unfuseable_clients, unfuseable_clients=unfuseable_clients,
ancestors_bitset=ancestors_bitset,
toposort_index=toposort_index, toposort_index=toposort_index,
) )
except ValueError: except ValueError:
...@@ -925,6 +924,8 @@ class FusionOptimizer(GraphRewriter): ...@@ -925,6 +924,8 @@ class FusionOptimizer(GraphRewriter):
visited_nodes=visited_nodes, visited_nodes=visited_nodes,
fuseable_clients=fuseable_clients, fuseable_clients=fuseable_clients,
unfuseable_clients=unfuseable_clients, unfuseable_clients=unfuseable_clients,
toposort_index=toposort_index,
ancestors_bitset=ancestors_bitset,
starting_nodes=starting_nodes, starting_nodes=starting_nodes,
updated_nodes=fg.apply_nodes, updated_nodes=fg.apply_nodes,
) )
......
...@@ -301,7 +301,8 @@ def test_debugprint(): ...@@ -301,7 +301,8 @@ def test_debugprint():
Gemv_op_name = "CGemv" if pytensor.config.blas__ldflags else "Gemv" Gemv_op_name = "CGemv" if pytensor.config.blas__ldflags else "Gemv"
exp_res = dedent( exp_res = dedent(
r""" r"""
Composite{(i2 + (i0 - i1))} 4 Composite{(i0 + (i1 - i2))} 4
├─ A
├─ ExpandDims{axis=0} v={0: [0]} 3 ├─ ExpandDims{axis=0} v={0: [0]} 3
""" """
f" │ └─ {Gemv_op_name}{{inplace}} d={{0: [0]}} 2" f" │ └─ {Gemv_op_name}{{inplace}} d={{0: [0]}} 2"
...@@ -313,17 +314,16 @@ def test_debugprint(): ...@@ -313,17 +314,16 @@ def test_debugprint():
│ ├─ B │ ├─ B
│ ├─ <Vector(float64, shape=(?,))> │ ├─ <Vector(float64, shape=(?,))>
│ └─ 0.0 │ └─ 0.0
├─ D └─ D
└─ A
Inner graphs: Inner graphs:
Composite{(i2 + (i0 - i1))} Composite{(i0 + (i1 - i2))}
← add 'o0' ← add 'o0'
├─ i2
└─ sub
├─ i0 ├─ i0
└─ i1 └─ sub
├─ i1
└─ i2
""" """
).lstrip() ).lstrip()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论