Commit 066307f0, authored by Ricardo Vieira, committed by Ricardo Vieira

Faster graph traversal functions

* Avoid reversing inputs as we traverse the graph
* Simplify io_toposort without ordering (and refactor it into its own function)
* Remove client side-effect from the previous toposort functions
* Remove duplicated logic across methods
Parent f1a2ac66
......@@ -5,7 +5,6 @@ import warnings
from collections.abc import (
Hashable,
Iterable,
Reversible,
Sequence,
)
from copy import copy
......@@ -961,7 +960,7 @@ def clone_node_and_cache(
def clone_get_equiv(
inputs: Iterable[Variable],
outputs: Reversible[Variable],
outputs: Iterable[Variable],
copy_inputs: bool = True,
copy_orphans: bool = True,
memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]
......@@ -1002,7 +1001,7 @@ def clone_get_equiv(
Keywords passed to `Apply.clone_with_new_inputs`.
"""
from pytensor.graph.traversal import io_toposort
from pytensor.graph.traversal import toposort
if memo is None:
memo = {}
......@@ -1018,7 +1017,7 @@ def clone_get_equiv(
memo.setdefault(input, input)
# go through the inputs -> outputs graph cloning as we go
for apply in io_toposort(inputs, outputs):
for apply in toposort(outputs, blockers=inputs):
for input in apply.inputs:
if input not in memo:
if not isinstance(input, Constant) and copy_orphans:
......
......@@ -10,7 +10,7 @@ import numpy as np
import pytensor
from pytensor.configdefaults import config
from pytensor.graph.basic import Variable
from pytensor.graph.traversal import io_toposort
from pytensor.graph.traversal import toposort
from pytensor.graph.utils import InconsistencyError
......@@ -340,11 +340,11 @@ class Feature:
class Bookkeeper(Feature):
    """Feature that replays import/prune events for the nodes already in a graph.

    On attach, every node reachable from the graph outputs is reported via
    `on_import`; on detach, the same nodes are reported via `on_prune`.
    """

    def on_attach(self, fgraph):
        # Visit existing nodes in topological order so imports see their
        # dependencies first.
        for node in toposort(fgraph.outputs):
            self.on_import(fgraph, node, "on_attach")

    def on_detach(self, fgraph):
        for node in toposort(fgraph.outputs):
            self.on_prune(fgraph, node, "Bookkeeper.detach")
......
......@@ -19,7 +19,8 @@ from pytensor.graph.op import Op
from pytensor.graph.traversal import (
applys_between,
graph_inputs,
io_toposort,
toposort,
toposort_with_orderings,
vars_between,
)
from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError
......@@ -366,7 +367,7 @@ class FunctionGraph(MetaObject):
# new nodes, so we use all variables we know of as if they were the
# input set. (The functions in the graph module only use the input set
# to know where to stop going down.)
new_nodes = io_toposort(self.variables, apply_node.outputs)
new_nodes = tuple(toposort(apply_node.outputs, blockers=self.variables))
if check:
for node in new_nodes:
......@@ -759,7 +760,7 @@ class FunctionGraph(MetaObject):
# No sorting is necessary
return list(self.apply_nodes)
return io_toposort(self.inputs, self.outputs, self.orderings())
return list(toposort_with_orderings(self.outputs, orderings=self.orderings()))
def orderings(self) -> dict[Apply, list[Apply]]:
"""Return a map of node to node evaluation dependencies.
......
......@@ -10,7 +10,10 @@ from pytensor.graph.basic import (
)
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.traversal import io_toposort, truncated_graph_inputs
from pytensor.graph.traversal import (
toposort,
truncated_graph_inputs,
)
ReplaceTypes = Iterable[tuple[Variable, Variable]] | dict[Variable, Variable]
......@@ -295,7 +298,7 @@ def vectorize_graph(
new_inputs = [replace.get(inp, inp) for inp in inputs]
vect_vars = dict(zip(inputs, new_inputs, strict=True))
for node in io_toposort(inputs, seq_outputs):
for node in toposort(seq_outputs, blockers=inputs):
vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
vect_node = vectorize_node(node, *vect_inputs)
for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True):
......
......@@ -27,7 +27,7 @@ from pytensor.graph.features import AlreadyThere, Feature
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars
from pytensor.graph.traversal import applys_between, io_toposort, vars_between
from pytensor.graph.traversal import applys_between, toposort, vars_between
from pytensor.graph.utils import AssocList, InconsistencyError
from pytensor.misc.ordered_set import OrderedSet
from pytensor.utils import flatten
......@@ -2010,7 +2010,7 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
callback_before = fgraph.execute_callbacks_time
nb_nodes_start = len(fgraph.apply_nodes)
t0 = time.perf_counter()
q = deque(io_toposort(fgraph.inputs, start_from))
q = deque(toposort(start_from))
io_t = time.perf_counter() - t0
def importer(node):
......@@ -2341,7 +2341,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
changed |= apply_cleanup(iter_cleanup_sub_profs)
topo_t0 = time.perf_counter()
q = deque(io_toposort(fgraph.inputs, start_from))
q = deque(toposort(start_from))
io_toposort_timing.append(time.perf_counter() - topo_t0)
nb_nodes.append(len(q))
......
from collections import deque
from collections.abc import (
Callable,
Collection,
Generator,
Iterable,
Iterator,
Reversible,
Sequence,
)
from typing import (
Literal,
TypeVar,
cast,
overload,
)
from pytensor.graph.basic import Apply, Constant, Node, Variable
from pytensor.misc.ordered_set import OrderedSet
T = TypeVar("T", bound=Node)
NodeAndChildren = tuple[T, Iterable[T] | None]
@overload
def walk(
nodes: Iterable[T],
expand: Callable[[T], Iterable[T] | None],
bfs: bool = True,
return_children: Literal[False] = False,
) -> Generator[T, None, None]: ...
@overload
def walk(
nodes: Iterable[T],
expand: Callable[[T], Iterable[T] | None],
bfs: bool,
return_children: Literal[True],
) -> Generator[NodeAndChildren, None, None]: ...
def walk(
    nodes: Iterable[T],
    expand: Callable[[T], Iterable[T] | None],
    bfs: bool = True,
    return_children: bool = False,
) -> Generator[T | NodeAndChildren, None, None]:
    r"""Walk through a graph, either breadth- or depth-first.

    Parameters
    ----------
    nodes
        The nodes from which to start walking.
    expand
        A callable that is applied to each node, the results of
        which are either new nodes to visit or ``None``.
    bfs
        If ``True``, breadth-first search is used; otherwise, depth-first.
    return_children
        If ``True``, each output node will be accompanied by the output of
        `expand` (i.e. the corresponding child nodes).

    Notes
    -----
    Nodes are deduplicated by identity/equality in a set, so each node is
    yielded at most once.
    """
    nodes = deque(nodes)
    rval_set: set[T] = set()
    # Bind the pop direction once: popleft gives BFS, pop gives DFS.
    nodes_pop: Callable[[], T] = nodes.popleft if bfs else nodes.pop
    node: T
    new_nodes: Iterable[T] | None
    # The two loops are duplicated so the `return_children` test is hoisted
    # out of the hot loop; termination is signalled by the pop raising
    # IndexError on an empty deque.
    try:
        if return_children:
            while True:
                node = nodes_pop()
                if node not in rval_set:
                    new_nodes = expand(node)
                    yield node, new_nodes
                    rval_set.add(node)
                    if new_nodes:
                        nodes.extend(new_nodes)
        else:
            while True:
                node = nodes_pop()
                if node not in rval_set:
                    yield node
                    rval_set.add(node)
                    new_nodes = expand(node)
                    if new_nodes:
                        nodes.extend(new_nodes)
    except IndexError:  # deque is exhausted
        return None
def ancestors(
    graphs: Iterable[Variable],
    blockers: Iterable[Variable] | None = None,
) -> Generator[Variable, None, None]:
    r"""Return the variables that contribute to those in given graphs (inclusive), stopping at blockers.

    Parameters
    ----------
    graphs
        Output `Variable`\s from which to start the backward traversal.
    blockers
        `Variable`\s at which the traversal stops; they are yielded but
        never expanded.

    Yields
    ------
    `Variable`\s
        All ancestor variables, in the order found by a right-recursive
        depth-first search started at the variables in `graphs`.
    """
    seen: set[Variable] = set()
    stack = list(graphs)
    # The two loops are duplicated to keep the blocker test out of the hot
    # path when there are no blockers; termination is signalled by the pop
    # raising IndexError on an empty stack.
    try:
        if blockers:
            blocker_set = frozenset(blockers)  # O(1) membership checks
            while True:
                if (var := stack.pop()) not in seen:
                    yield var
                    seen.add(var)
                    if var not in blocker_set and (apply := var.owner) is not None:
                        stack.extend(apply.inputs)
        else:
            while True:
                if (var := stack.pop()) not in seen:
                    yield var
                    seen.add(var)
                    if (apply := var.owner) is not None:
                        stack.extend(apply.inputs)
    except IndexError:  # stack is empty
        return


# Backwards-compatible alias
variable_ancestors = ancestors
def apply_ancestors(
    graphs: Iterable[Variable],
    blockers: Iterable[Variable] | None = None,
) -> Generator[Apply, None, None]:
    """Return the Apply nodes that contribute to those in given graphs (inclusive).

    Each `Apply` is yielded exactly once, even when several of its output
    variables are encountered during the traversal.
    """
    # Pre-seeding with ``None`` filters out owner-less (root) variables
    # without needing a separate ``is not None`` test per iteration.
    yielded = {None}
    for variable in ancestors(graphs, blockers):
        owner = variable.owner
        if owner not in yielded:
            yielded.add(owner)
            yield owner
def graph_inputs(
    graphs: Iterable[Variable], blockers: Iterable[Variable] | None = None
) -> Generator[Variable, None, None]:
    r"""Return the inputs required to compute the given Variables.

    Parameters
    ----------
    graphs
        Output `Variable`\s from which to search backward.
    blockers
        `Variable`\s at which the search stops.

    Yields
    ------
    Input nodes with no owner, in the order found by the depth-first
    search performed by `ancestors` starting at the nodes in `graphs`.
    """
    yield from (var for var in ancestors(graphs, blockers) if var.owner is None)
def explicit_graph_inputs(
......@@ -177,12 +218,12 @@ def explicit_graph_inputs(
from pytensor.compile.sharedvalue import SharedVariable
if isinstance(graph, Variable):
graph = [graph]
graph = (graph,)
return (
v
for v in graph_inputs(graph)
if isinstance(v, Variable) and not isinstance(v, Constant | SharedVariable)
var
for var in ancestors(graph)
if var.owner is None and not isinstance(var, Constant | SharedVariable)
)
......@@ -191,6 +232,11 @@ def vars_between(
) -> Generator[Variable, None, None]:
r"""Extract the `Variable`\s within the sub-graph between input and output nodes.
Notes
-----
This function is like ancestors(outs, blockers=ins),
except it can also yield disconnected output variables from multi-output apply nodes.
Parameters
----------
ins
......@@ -207,20 +253,19 @@ def vars_between(
"""
ins = set(ins)
def expand(r: Variable) -> Iterable[Variable] | None:
if r.owner and r not in ins:
return reversed(r.owner.inputs + r.owner.outputs)
def expand(var: Variable, ins=frozenset(ins)) -> Iterable[Variable] | None:
if var.owner is not None and var not in ins:
return (*var.owner.inputs, *var.owner.outputs)
return None
yield from cast(Generator[Variable, None, None], walk(outs, expand))
# With bfs = False, it iterates similarly to ancestors
yield from walk(outs, expand, bfs=False)
def orphans_between(
    ins: Iterable[Variable], outs: Iterable[Variable]
) -> Generator[Variable, None, None]:
    r"""Extract the root `Variable`\s not within the sub-graph between input and output nodes.

    Parameters
    ----------
    ins
        Input `Variable`\s of the sub-graph.
    outs
        Output `Variable`\s of the sub-graph.

    Yields
    ------
    `Variable`\s
        The root variables (i.e. those with no owner) reachable from
        `outs` that are not listed in `ins`.
    """
    ins = frozenset(ins)
    yield from (
        var
        for var in vars_between(ins, outs)
        if var.owner is None and var not in ins
    )
def applys_between(
    ins: Iterable[Variable], outs: Iterable[Variable]
) -> Generator[Apply, None, None]:
    r"""Extract the `Apply`\s contained within the sub-graph between given input and output variables.

    Notes
    -----
    This is identical to ``apply_ancestors(outs, blockers=ins)``.

    Parameters
    ----------
    ins
        Input `Variable`\s.
    outs
        Output `Variable`\s.

    Yields
    ------
    `Apply`\s
        The `Apply` nodes in the sub-graph, excluding the owners of the
        `Variable`\s in `ins`.
    """
    return apply_ancestors(outs, blockers=ins)
def apply_depends_on(apply: Apply, depends_on: Apply | Iterable[Apply]) -> bool:
    """Determine if any `depends_on` is in the graph given by ``apply``.

    Parameters
    ----------
    apply
        The `Apply` node whose dependency graph is searched.
    depends_on
        A single `Apply` node or an iterable of them.

    Returns
    -------
    bool
        ``True`` if `apply` itself, or any `Apply` it transitively depends
        on, is in `depends_on`.
    """
    if isinstance(depends_on, Apply):
        depends_on = frozenset((depends_on,))
    else:
        depends_on = frozenset(depends_on)
    if apply in depends_on:
        return True
    # Use a distinct loop variable so the `apply` argument is not shadowed
    # while scanning its ancestors.
    return any(ancestor in depends_on for ancestor in apply_ancestors(apply.inputs))
def variable_depends_on(
    variable: Variable, depends_on: Variable | Iterable[Variable]
) -> bool:
    """Determine if any `depends_on` is in the graph given by ``variable``.

    Notes
    -----
    The interpretation of dependency is done at a variable level.
    A variable may depend on some output variables from a multi-output
    apply node but not others.

    Parameters
    ----------
    variable
        Variable to check.
    depends_on
        Variables to check dependency on.

    Returns
    -------
    bool
    """
    if isinstance(depends_on, Variable):
        depends_on_set = frozenset((depends_on,))
    else:
        depends_on_set = frozenset(depends_on)
    return any(var in depends_on_set for var in variable_ancestors([variable]))
def truncated_graph_inputs(
outputs: Sequence[Variable],
ancestors_to_include: Collection[Variable] | None = None,
ancestors_to_include: Iterable[Variable] | None = None,
) -> list[Variable]:
"""Get the truncate graph inputs.
......@@ -345,9 +393,9 @@ def truncated_graph_inputs(
Parameters
----------
outputs : Collection[Variable]
outputs : Iterable[Variable]
Variable to get conditions for
ancestors_to_include : Optional[Collection[Variable]]
ancestors_to_include : Optional[Iterable[Variable]]
Additional ancestors to assume, by default None
Returns
......@@ -405,88 +453,136 @@ def truncated_graph_inputs(
n - (c) - (o/c)
"""
# simple case, no additional ancestors to include
truncated_inputs: list[Variable] = list()
# blockers have known independent variables and ancestors to include
candidates = list(outputs)
if not ancestors_to_include: # None or empty
seen: set[Variable] = set()
# simple case, no additional ancestors to include
if not ancestors_to_include:
# just filter out unique variables
for variable in candidates:
if variable not in truncated_inputs:
for variable in outputs:
if variable not in seen:
seen.add(variable)
truncated_inputs.append(variable)
# no more actions are needed
return truncated_inputs
# blockers have known independent variables and ancestors to include
blockers: set[Variable] = set(ancestors_to_include)
# variables that go here are under check already, do not repeat the loop for them
seen: set[Variable] = set()
# enforce O(1) check for variable in ancestors to include
ancestors_to_include = blockers.copy()
candidates = list(outputs)
try:
while True:
if (variable := candidates.pop()) not in seen:
seen.add(variable)
# check if the variable is independent, never go above blockers;
# blockers are independent variables and ancestors to include
if variable in ancestors_to_include:
# ancestors to include that are present in the graph (not disconnected)
# should be added to truncated_inputs
truncated_inputs.append(variable)
# if the ancestors to include is still dependent on other ancestors we need to go above,
# FIXME: This seems wrong? The other ancestors above are either redundant given this variable,
# or another path leads to them and the special casing isn't needed
# It seems the only reason we are expanding on these inputs is to find other ancestors_to_include
# (instead of treating them as disconnected), but this may yet cause other unrelated variables
# to become "independent" in the process
if variable_depends_on(variable, ancestors_to_include - {variable}):
# owner can never be None for a dependent variable
candidates.extend(
n for n in variable.owner.inputs if n not in seen
)
else:
# A regular variable to check
# if we've found an independent variable and it is not in blockers so far
# it is a new independent variable not present in ancestors to include
if variable_depends_on(variable, blockers):
# If it's not an independent variable, inputs become candidates
candidates.extend(variable.owner.inputs)
else:
# otherwise it's a truncated input itself
truncated_inputs.append(variable)
# all regular variables fall to blockers
# 1. it is dependent - we already expanded on the inputs, nothing to do if we find it again
# 2. it is independent - this is a truncated input, search for other nodes can stop here
blockers.add(variable)
except IndexError: # pop from an empty list
pass
while candidates:
# on any new candidate
variable = candidates.pop()
# we've looked into this variable already
if variable in seen:
continue
# check if the variable is independent, never go above blockers;
# blockers are independent variables and ancestors to include
elif variable in ancestors_to_include:
# The case where variable is in ancestors to include so we check if it depends on others
# it should be removed from the blockers to check against the rest
dependent = variable_depends_on(variable, ancestors_to_include - {variable})
# ancestors to include that are present in the graph (not disconnected)
# should be added to truncated_inputs
truncated_inputs.append(variable)
if dependent:
# if the ancestors to include is still dependent we need to go above, the search is not yet finished
# owner can never be None for a dependent variable
candidates.extend(n for n in variable.owner.inputs if n not in seen)
else:
# A regular variable to check
dependent = variable_depends_on(variable, blockers)
# all regular variables fall to blockers
# 1. it is dependent - further search irrelevant
# 2. it is independent - the search variable is inside the closure
blockers.add(variable)
# if we've found an independent variable and it is not in blockers so far
# it is a new independent variable not present in ancestors to include
if dependent:
# populate search if it's not an independent variable
# owner can never be None for a dependent variable
candidates.extend(n for n in variable.owner.inputs if n not in seen)
else:
# otherwise, do not search beyond
truncated_inputs.append(variable)
# add variable to seen, no point in checking it once more
seen.add(variable)
return truncated_inputs
@overload
def general_toposort(
outputs: Iterable[T],
deps: Callable[[T], OrderedSet | list[T]],
compute_deps_cache: None,
deps_cache: None,
clients: dict[T, list[T]] | None,
) -> list[T]: ...
def walk_toposort(
    graphs: Iterable[T],
    deps: Callable[[T], Iterable[T] | None],
) -> Generator[T, None, None]:
    """Perform a topological sort of all nodes reachable from the given nodes.

    Parameters
    ----------
    graphs
        An iterable of nodes from which to start the topological sort.
    deps : callable
        A function that takes a node and returns an iterable of its
        dependencies (or ``None``/falsy when it has none).

    Yields
    ------
    Nodes in an order such that every node appears only after all of its
    dependencies.

    Raises
    ------
    ValueError
        If the graph contains a cycle.

    Notes
    -----
    ``deps(i)`` should behave like a pure function (no funny business with
    internal state).

    The order of the yielded nodes is determined by the order of nodes
    returned by the `deps` function.
    """
    # Cache the dependencies (ancestors) as we discover nodes with `walk`.
    deps_cache: dict[T, list[T]] = {}

    def compute_deps_cache(obj, deps_cache=deps_cache):
        if obj in deps_cache:
            return deps_cache[obj]
        d = deps_cache[obj] = deps(obj) or []
        return d

    clients: dict[T, list[T]] = {}
    sources: deque[T] = deque()
    total_nodes = 0
    for node, children in walk(
        graphs, compute_deps_cache, bfs=False, return_children=True
    ):
        total_nodes += 1
        # `children` is never None here: `compute_deps_cache` maps falsy
        # results to `[]` (mypy can't see that).
        for child in children:  # type: ignore
            clients.setdefault(child, []).append(node)
        if not deps_cache[node]:
            # Nodes without dependencies seed the ready queue.
            sources.append(node)

    rset: set[T] = set()
    try:
        while True:
            if (node := sources.popleft()) not in rset:
                yield node
                total_nodes -= 1
                rset.add(node)
                # For every client (node that depends on the current node),
                # drop the satisfied dependency; once none remain the client
                # becomes ready.
                for client in clients.get(node, []):
                    d = deps_cache[client] = [
                        a for a in deps_cache[client] if a is not node
                    ]
                    if not d:
                        sources.append(client)
    except IndexError:  # ready queue is exhausted
        pass
    if total_nodes != 0:
        raise ValueError("graph contains cycles")
def general_toposort(
    outputs: Iterable[T],
    deps: Callable[[T], Iterable[T] | None],
    compute_deps_cache: Callable[[T], Iterable[T] | None] | None = None,
    deps_cache: dict[T, list[T]] | None = None,
    clients: dict[T, list[T]] | None = None,
) -> list[T]:
    """Perform a general topological sort starting from `outputs`.

    Parameters
    ----------
    outputs
        Nodes from which to start the sort.
    deps : callable
        A function that takes a node and returns its dependencies.
    compute_deps_cache, deps_cache, clients
        No longer supported; passing a non-``None`` value raises
        `ValueError`.

    Notes
    -----
    This is a simple wrapper around `walk_toposort` kept for backwards
    compatibility.

    ``deps(i)`` should behave like a pure function (no funny business with
    internal state).

    The order of the return value list is determined by the order of nodes
    returned by the `deps` function.
    """
    # TODO: Deprecate me later
    if compute_deps_cache is not None:
        raise ValueError("compute_deps_cache is no longer supported")
    if deps_cache is not None:
        raise ValueError("deps_cache is no longer supported")
    if clients is not None:
        raise ValueError("clients is no longer supported")
    return list(walk_toposort(outputs, deps))
def toposort(
    graphs: Iterable[Variable],
    blockers: Iterable[Variable] | None = None,
) -> Generator[Apply, None, None]:
    """Topologically sort the Apply nodes between graphs (outputs) and blockers (inputs).

    Parameters
    ----------
    graphs
        Output variables from which to start the backward traversal.
    blockers
        Variables at which traversal stops; their owners are never yielded.

    Yields
    ------
    `Apply` nodes in a valid dependency order: every node is yielded only
    after the owners of all of its non-root, non-blocked inputs.

    Notes
    -----
    This is a streamlined version of `toposort_with_orderings` for when no
    additional ordering constraints are needed.
    """
    # Blocker variables can be seeded directly into ``computed`` because
    # only Apply nodes are ever yielded; their presence simply stops the
    # expansion of the graph.
    computed = set(blockers or ())
    todo = list(graphs)
    try:
        while True:
            if (cur := todo.pop()) not in computed and (apply := cur.owner) is not None:
                uncomputed_inputs = tuple(
                    i
                    for i in apply.inputs
                    if (i not in computed and i.owner is not None)
                )
                if not uncomputed_inputs:
                    yield apply
                    computed.update(apply.outputs)
                else:
                    # Revisit ``cur`` after its pending inputs are resolved.
                    todo.append(cur)
                    todo.extend(uncomputed_inputs)
    except IndexError:  # todo queue is empty
        return
def toposort_with_orderings(
    graphs: Iterable[Variable],
    *,
    blockers: Iterable[Variable] | None = None,
    orderings: dict[Apply, list[Apply]] | None = None,
) -> Generator[Apply, None, None]:
    """Topologically sort the nodes between blocker (input) and graphs (output) variables with arbitrary extra orderings.

    Extra orderings can be used to force sorting of variables that are not
    naturally related in the graph. This can be used by inplace
    optimizations to ensure a variable is only destroyed after all other
    uses. Those other uses show up as dependencies of the destroying node
    in the orderings dictionary.

    Parameters
    ----------
    graphs
        Graph outputs.
    blockers
        Graph inputs; traversal never expands past them.
    orderings : dict
        Keys are `Apply` or `Variable` instances, values are lists of
        `Apply` or `Variable` instances that must be sorted first.

    Yields
    ------
    `Apply` nodes in a dependency order that also respects `orderings`.
    """
    if not orderings:
        # Faster branch when there are no extra ordering constraints
        yield from toposort(graphs, blockers=blockers)
        return

    if blockers:
        # the blockers decide where to stop expanding the graph
        def compute_deps(obj, blocker_set=frozenset(blockers), orderings=orderings):
            if obj in blocker_set:
                return None
            if isinstance(obj, Apply):
                return [*obj.inputs, *orderings.get(obj, [])]
            if (apply := obj.owner) is not None:
                return [apply, *orderings.get(apply, [])]
            return orderings.get(obj, [])

    else:
        # mypy doesn't like conditional functions with different signatures,
        # but passing the enclosing values as defaults is faster
        def compute_deps(obj, orderings=orderings):  # type: ignore[misc]
            if isinstance(obj, Apply):
                return [*obj.inputs, *orderings.get(obj, [])]
            if (apply := obj.owner) is not None:
                return [apply, *orderings.get(apply, [])]
            return orderings.get(obj, [])

    yield from (
        node
        for node in walk_toposort(graphs, deps=compute_deps)
        # The walk visits both Apply and Variable nodes; only Apply nodes
        # belong in the result.
        if isinstance(node, Apply)  # type: ignore
    )
def io_toposort(
......@@ -594,7 +714,11 @@ def io_toposort(
orderings: dict[Apply, list[Apply]] | None = None,
clients: dict[Variable, list[Variable]] | None = None,
) -> list[Apply]:
"""Perform topological sort from input and output nodes.
"""Perform topological of nodes between input and output variables.
Notes
-----
    This is just a wrapper around `toposort_with_orderings` for backwards compatibility.
Parameters
----------
......@@ -604,96 +728,16 @@ def io_toposort(
Graph outputs.
orderings : dict
Keys are `Apply` instances, values are lists of `Apply` instances.
clients : dict
If provided, it will be filled with mappings of nodes-to-clients for
each node in the subgraph that is sorted.
"""
if not orderings and clients is None: # ordering can be None or empty dict
# Specialized function that is faster when more then ~10 nodes
# when no ordering.
# Do a new stack implementation with the vm algo.
# This will change the order returned.
computed = set(inputs)
todo = [o.owner for o in reversed(outputs) if o.owner]
order = []
while todo:
cur = todo.pop()
if all(out in computed for out in cur.outputs):
continue
if all(i in computed or i.owner is None for i in cur.inputs):
computed.update(cur.outputs)
order.append(cur)
else:
todo.append(cur)
todo.extend(
i.owner for i in cur.inputs if (i.owner and i not in computed)
)
return order
iset = set(inputs)
if not orderings: # ordering can be None or empty dict
# Specialized function that is faster when no ordering.
# Also include the cache in the function itself for speed up.
deps_cache: dict = {}
def compute_deps_cache(obj):
if obj in deps_cache:
return deps_cache[obj]
rval = []
if obj not in iset:
if isinstance(obj, Variable):
if obj.owner:
rval = [obj.owner]
elif isinstance(obj, Apply):
rval = list(obj.inputs)
if rval:
deps_cache[obj] = list(rval)
else:
deps_cache[obj] = rval
else:
deps_cache[obj] = rval
return rval
topo = general_toposort(
outputs,
deps=None,
compute_deps_cache=compute_deps_cache,
deps_cache=deps_cache,
clients=clients,
)
# TODO: Deprecate me later
if clients is not None:
raise ValueError("clients is no longer supported")
else:
# the inputs are used only here in the function that decides what
# 'predecessors' to explore
def compute_deps(obj):
rval = []
if obj not in iset:
if isinstance(obj, Variable):
if obj.owner:
rval = [obj.owner]
elif isinstance(obj, Apply):
rval = list(obj.inputs)
rval.extend(orderings.get(obj, []))
else:
assert not orderings.get(obj, None)
return rval
topo = general_toposort(
outputs,
deps=compute_deps,
compute_deps_cache=None,
deps_cache=None,
clients=clients,
)
return [o for o in topo if isinstance(o, Apply)]
return list(toposort_with_orderings(outputs, blockers=inputs, orderings=orderings))
def get_var_by_name(
    graphs: Iterable[Variable], target_var_id: str, ids: str = "CHAR"
) -> tuple[Variable, ...]:
    r"""Get variables in a graph using their names.

    Parameters
    ----------
    graphs
        The graph, or graphs, to search.
    target_var_id
        The name to match against both ``Variable.name`` and
        ``Variable.auto_name``.
    ids
        Unused; kept so existing callers passing it keep working.
        TODO: deprecate with a warning.

    Returns
    -------
    A tuple containing all the `Variable`\s that match `target_var_id`,
    including variables inside `HasInnerGraph` ops.
    """
    from pytensor.graph.op import HasInnerGraph

    def expand(r: Variable) -> list[Variable] | None:
        if (apply := r.owner) is not None:
            # Inner-graph ops are searched through their inner outputs as
            # well as their outer inputs.
            if isinstance(apply.op, HasInnerGraph):
                return [*apply.inputs, *apply.op.inner_outputs]
            # Mypy doesn't know these will never be None
            return apply.inputs  # type: ignore
        return None

    return tuple(
        var
        for var in walk(graphs, expand)
        if (target_var_id == var.name or target_var_id == var.auto_name)
    )
......@@ -21,7 +21,7 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import HasInnerGraph, Op, StorageMapType
from pytensor.graph.traversal import graph_inputs, io_toposort
from pytensor.graph.traversal import graph_inputs, toposort
from pytensor.graph.utils import Scratchpad
......@@ -1103,7 +1103,7 @@ class PPrinter(Printer):
)
inv_updates = {b: a for (a, b) in updates.items()}
i = 1
for node in io_toposort([*inputs, *updates], [*outputs, *updates.values()]):
for node in toposort([*outputs, *updates.values()], [*inputs, *updates]):
for output in node.outputs:
if output in inv_updates:
name = str(inv_updates[output])
......
......@@ -13,7 +13,6 @@ from pytensor import tensor as pt
from pytensor.compile import optdb
from pytensor.compile.function.types import deep_copy_op
from pytensor.configdefaults import config
from pytensor.graph import ancestors, graph_inputs
from pytensor.graph.basic import (
Apply,
Constant,
......@@ -35,7 +34,11 @@ from pytensor.graph.rewriting.basic import (
)
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
from pytensor.graph.rewriting.utils import get_clients_at_depth
from pytensor.graph.traversal import apply_depends_on, io_toposort
from pytensor.graph.traversal import (
ancestors,
apply_depends_on,
graph_inputs,
)
from pytensor.graph.type import HasShape
from pytensor.graph.utils import InconsistencyError
from pytensor.raise_op import Assert
......@@ -220,7 +223,7 @@ def scan_push_out_non_seq(fgraph, node):
"""
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_topo = node.op.fgraph.toposort()
local_fgraph_outs_set = set(node_outputs)
local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
......@@ -427,7 +430,7 @@ def scan_push_out_seq(fgraph, node):
"""
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_topo = node.op.fgraph.toposort()
local_fgraph_outs_set = set(node_outputs)
local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
......@@ -840,22 +843,42 @@ def scan_push_out_add(fgraph, node):
# apply_ancestors(args.inner_outputs)
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use
args = ScanArgs(
node.inputs, node.outputs, op.inner_inputs, op.inner_outputs, op.info
)
add_of_dot_nodes = [
n
for n in op.fgraph.apply_nodes
if
(
# We have an Add
isinstance(n.op, Elemwise)
and isinstance(n.op.scalar_op, ps.Add)
and any(
(
# With a Dot input that's only used in the Add
n_inp.owner is not None
and isinstance(n_inp.owner.op, Dot)
and len(op.fgraph.clients[n_inp]) == 1
)
for n_inp in n.inputs
)
)
]
clients = {}
local_fgraph_topo = io_toposort(
args.inner_inputs, args.inner_outputs, clients=clients
if not add_of_dot_nodes:
return False
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of access
args = ScanArgs(
node.inputs,
node.outputs,
op.inner_inputs,
op.inner_outputs,
op.info,
clone=False,
)
for nd in local_fgraph_topo:
for nd in add_of_dot_nodes:
if (
isinstance(nd.op, Elemwise)
and isinstance(nd.op.scalar_op, ps.Add)
and nd.out in args.inner_out_sit_sot
nd.out in args.inner_out_sit_sot
# FIXME: This function doesn't handle `sitsot_out[1:][-1]` pattern
and inner_sitsot_only_last_step_used(fgraph, nd.out, args)
):
......@@ -863,27 +886,17 @@ def scan_push_out_add(fgraph, node):
# the add from a previous iteration of the inner function
sitsot_idx = args.inner_out_sit_sot.index(nd.out)
if args.inner_in_sit_sot[sitsot_idx] in nd.inputs:
# Ensure that the other input to the add is a dot product
# between 2 matrices which will become a tensor3 and a
# matrix if pushed outside of the scan. Also make sure
# that the output of the Dot is ONLY used by the 'add'
# otherwise doing a Dot in the outer graph will only
# duplicate computation.
sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[sitsot_idx])
# 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
dot_in_idx = 1 - sitsot_in_idx
dot_input = nd.inputs[dot_in_idx]
assert dot_input.owner is not None and isinstance(
dot_input.owner.op, Dot
)
if (
dot_input.owner is not None
and isinstance(dot_input.owner.op, Dot)
and len(clients[dot_input]) == 1
and dot_input.owner.inputs[0].ndim == 2
and dot_input.owner.inputs[1].ndim == 2
and get_outer_ndim(dot_input.owner.inputs[0], args) == 3
get_outer_ndim(dot_input.owner.inputs[0], args) == 3
and get_outer_ndim(dot_input.owner.inputs[1], args) == 3
):
# The optimization can be applied in this case.
......
......@@ -59,7 +59,7 @@ import time
import numpy as np
from pytensor.graph.traversal import io_toposort
from pytensor.graph.traversal import toposort
from pytensor.tensor.rewriting.basic import register_specialize
......@@ -460,6 +460,9 @@ class GemmOptimizer(GraphRewriter):
callbacks_before = fgraph.execute_callbacks_times.copy()
callback_before = fgraph.execute_callbacks_time
nodelist = list(toposort(fgraph.outputs))
nodelist.reverse()
def on_import(new_node):
if new_node is not node:
nodelist.append(new_node)
......@@ -471,10 +474,8 @@ class GemmOptimizer(GraphRewriter):
while did_something:
nb_iter += 1
t0 = time.perf_counter()
nodelist = io_toposort(fgraph.inputs, fgraph.outputs)
time_toposort += time.perf_counter() - t0
did_something = False
nodelist.reverse()
for node in nodelist:
if not (
isinstance(node.op, Elemwise)
......
......@@ -50,23 +50,14 @@ class TestProfiling:
the_string = buf.getvalue()
lines1 = [l for l in the_string.split("\n") if "Max if linker" in l]
lines2 = [l for l in the_string.split("\n") if "Minimum peak" in l]
if config.device == "cpu":
assert "CPU: 4112KB (4104KB)" in the_string, (lines1, lines2)
assert "CPU: 8204KB (8196KB)" in the_string, (lines1, lines2)
assert "CPU: 8208KB" in the_string, (lines1, lines2)
assert (
"Minimum peak from all valid apply node order is 4104KB"
in the_string
), (lines1, lines2)
else:
assert "CPU: 16KB (16KB)" in the_string, (lines1, lines2)
assert "GPU: 8204KB (8204KB)" in the_string, (lines1, lines2)
assert "GPU: 12300KB (12300KB)" in the_string, (lines1, lines2)
assert "GPU: 8212KB" in the_string, (lines1, lines2)
assert (
"Minimum peak from all valid apply node order is 4116KB"
in the_string
), (lines1, lines2)
# NOTE: The specific numbers can change for distinct (but correct) toposort orderings
# Update the test values if a different algorithm is used
assert "CPU: 4112KB (4112KB)" in the_string, (lines1, lines2)
assert "CPU: 8204KB (8204KB)" in the_string, (lines1, lines2)
assert "CPU: 8208KB" in the_string, (lines1, lines2)
assert (
"Minimum peak from all valid apply node order is 4104KB" in the_string
), (lines1, lines2)
finally:
config.profile = config1
......
......@@ -160,7 +160,7 @@ def test_KanrenRelationSub_dot():
assert expr_opt.owner.op == pt.add
assert isinstance(expr_opt.owner.inputs[0].owner.op, Dot)
assert fgraph_opt.inputs[0] is A_pt
assert fgraph_opt.inputs[-1] is A_pt
assert expr_opt.owner.inputs[0].owner.inputs[0].name == "A"
assert expr_opt.owner.inputs[1].owner.op == pt.add
assert isinstance(expr_opt.owner.inputs[1].owner.inputs[0].owner.op, Dot)
......
......@@ -56,7 +56,7 @@ class TestFunctionGraph:
with pytest.raises(TypeError, match="'Variable' object is not iterable"):
FunctionGraph(var1, [var2])
with pytest.raises(TypeError, match="'Variable' object is not reversible"):
with pytest.raises(TypeError, match="'Variable' object is not iterable"):
FunctionGraph([var1], var2)
with pytest.raises(
......
......@@ -28,7 +28,7 @@ class TestCloneReplace:
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, rebuild_strict=True, copy_inputs_over=True)
f2_inp = graph_inputs([f2])
f2_inp = tuple(graph_inputs([f2]))
assert z in f2_inp
assert x in f2_inp
......@@ -65,7 +65,7 @@ class TestCloneReplace:
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=True, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
f2_inp = tuple(graph_inputs([f2]))
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
......@@ -83,7 +83,7 @@ class TestCloneReplace:
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=False, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
f2_inp = tuple(graph_inputs([f2]))
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
......
......@@ -4,13 +4,17 @@ from pytensor import Variable, shared
from pytensor import tensor as pt
from pytensor.graph import Apply, ancestors, graph_inputs
from pytensor.graph.traversal import (
apply_ancestors,
apply_depends_on,
explicit_graph_inputs,
general_toposort,
get_var_by_name,
io_toposort,
orphans_between,
toposort,
toposort_with_orderings,
truncated_graph_inputs,
variable_ancestors,
variable_depends_on,
vars_between,
walk,
......@@ -36,23 +40,17 @@ class TestToposort:
o2 = MyOp(o, r5)
o2.name = "o2"
clients = {}
res = general_toposort([o2], self.prenode, clients=clients)
assert clients == {
o2.owner: [o2],
o: [o2.owner],
r5: [o2.owner],
o.owner: [o],
r1: [o.owner],
r2: [o.owner],
}
res = general_toposort([o2], self.prenode)
assert res == [r5, r2, r1, o.owner, o, o2.owner, o2]
with pytest.raises(ValueError):
general_toposort(
[o2], self.prenode, compute_deps_cache=lambda x: None, deps_cache=None
)
def circular_dependency(obj):
if obj is o:
# o2 depends on o, so o cannot depend on o2
return [o2, *self.prenode(obj)]
return self.prenode(obj)
with pytest.raises(ValueError, match="graph contains cycles"):
general_toposort([o2], circular_dependency)
res = io_toposort([r5], [o2])
assert res == [o.owner, o2.owner]
......@@ -181,16 +179,16 @@ def test_ancestors():
res = ancestors([o2], blockers=None)
res_list = list(res)
assert res_list == [o2, r3, o1, r1, r2]
assert res_list == [o2, o1, r2, r1, r3]
res = ancestors([o2], blockers=None)
assert r3 in res
assert o1 in res
res_list = list(res)
assert res_list == [o1, r1, r2]
assert res_list == [r2, r1, r3]
res = ancestors([o2], blockers=[o1])
res_list = list(res)
assert res_list == [o2, r3, o1]
assert res_list == [o2, o1, r3]
def test_graph_inputs():
......@@ -202,7 +200,7 @@ def test_graph_inputs():
res = graph_inputs([o2], blockers=None)
res_list = list(res)
assert res_list == [r3, r1, r2]
assert res_list == [r2, r1, r3]
def test_explicit_graph_inputs():
......@@ -231,7 +229,7 @@ def test_variables_and_orphans():
vars_res_list = list(vars_res)
orphans_res_list = list(orphans_res)
assert vars_res_list == [o2, o1, r3, r2, r1]
assert vars_res_list == [o2, o1, r2, r1, r3]
assert orphans_res_list == [r3]
......@@ -408,3 +406,37 @@ def test_get_var_by_name():
exp_res = igo.fgraph.outputs[0]
assert res == exp_res
@pytest.mark.parametrize(
    "func",
    [
        lambda x: all(variable_ancestors([x])),
        lambda x: all(variable_ancestors([x], blockers=[x.clone()])),
        lambda x: all(apply_ancestors([x])),
        lambda x: all(apply_ancestors([x], blockers=[x.clone()])),
        lambda x: all(toposort([x])),
        lambda x: all(toposort([x], blockers=[x.clone()])),
        lambda x: all(toposort_with_orderings([x], orderings={x: []})),
        lambda x: all(
            toposort_with_orderings([x], blockers=[x.clone()], orderings={x: []})
        ),
    ],
    ids=[
        "variable_ancestors",
        "variable_ancestors_with_blockers",
        "apply_ancestors",
        "apply_ancestors_with_blockers",
        "toposort",
        "toposort_with_blockers",
        "toposort_with_orderings",
        "toposort_with_orderings_and_blockers",
    ],
)
def test_traversal_benchmark(func, benchmark):
    """Benchmark graph-traversal helpers on a deep graph with heavy sharing.

    Builds a 50-level chain where each node takes the previous output twice,
    so traversals must deduplicate already-visited nodes to stay fast.
    """
    r1 = MyVariable(1)
    out = r1
    # Each level reuses ``out`` as both inputs, exercising visited-set logic.
    for _ in range(50):
        out = MyOp(out, out)
    benchmark(func, out)
from itertools import chain
import numpy as np
import pytest
......@@ -490,6 +492,7 @@ def test_inplace_taps(n_steps_constant):
if isinstance(node.op, Scan)
]
# Collect inner inputs we expect to be destroyed by the step function
# Scan reorders inputs internally, so we need to check its ordering
inner_inps = scan_op.fgraph.inputs
mit_sot_inps = scan_op.inner_mitsot(inner_inps)
......@@ -501,28 +504,22 @@ def test_inplace_taps(n_steps_constant):
]
[sit_sot_inp] = scan_op.inner_sitsot(inner_inps)
inner_outs = scan_op.fgraph.outputs
mit_sot_outs = scan_op.inner_mitsot_outs(inner_outs)
[sit_sot_out] = scan_op.inner_sitsot_outs(inner_outs)
[nit_sot_out] = scan_op.inner_nitsot_outs(inner_outs)
destroyed_inputs = []
for inner_out in scan_op.fgraph.outputs:
node = inner_out.owner
dm = node.op.destroy_map
if dm:
destroyed_inputs.extend(
node.inputs[idx] for idx in chain.from_iterable(dm.values())
)
if n_steps_constant:
assert mit_sot_outs[0].owner.op.destroy_map == {
0: [mit_sot_outs[0].owner.inputs.index(oldest_mit_sot_inps[0])]
}
assert mit_sot_outs[1].owner.op.destroy_map == {
0: [mit_sot_outs[1].owner.inputs.index(oldest_mit_sot_inps[1])]
}
assert sit_sot_out.owner.op.destroy_map == {
0: [sit_sot_out.owner.inputs.index(sit_sot_inp)]
}
assert len(destroyed_inputs) == 3
assert set(destroyed_inputs) == {*oldest_mit_sot_inps, sit_sot_inp}
else:
# This is not a feature, but a current limitation
# https://github.com/pymc-devs/pytensor/issues/1283
assert mit_sot_outs[0].owner.op.destroy_map == {}
assert mit_sot_outs[1].owner.op.destroy_map == {}
assert sit_sot_out.owner.op.destroy_map == {}
assert nit_sot_out.owner.op.destroy_map == {}
assert not destroyed_inputs
@pytest.mark.parametrize(
......
......@@ -1170,8 +1170,8 @@ class TestHyp2F1Grad:
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, ScalarLoop)
]
assert scalar_loop_op1.nin == 10 + 3 * 2 # wrt=[0, 1]
assert scalar_loop_op2.nin == 10 + 3 * 1 # wrt=[2]
assert scalar_loop_op1.nin == 10 + 3 * 1 # wrt=[2]
assert scalar_loop_op2.nin == 10 + 3 * 2 # wrt=[0, 1]
else:
[scalar_loop_op] = [
node.op.scalar_op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论