Commit 066307f0 authored by Ricardo Vieira, committed by Ricardo Vieira

Faster graph traversal functions

* Avoid reversing inputs as we traverse graph * Simplify io_toposort without ordering (and refactor into its own function) * Removes client side-effect on previous toposort functions * Remove duplicated logic across methods
Parent f1a2ac66
......@@ -5,7 +5,6 @@ import warnings
from collections.abc import (
Hashable,
Iterable,
Reversible,
Sequence,
)
from copy import copy
......@@ -961,7 +960,7 @@ def clone_node_and_cache(
def clone_get_equiv(
inputs: Iterable[Variable],
outputs: Reversible[Variable],
outputs: Iterable[Variable],
copy_inputs: bool = True,
copy_orphans: bool = True,
memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]
......@@ -1002,7 +1001,7 @@ def clone_get_equiv(
Keywords passed to `Apply.clone_with_new_inputs`.
"""
from pytensor.graph.traversal import io_toposort
from pytensor.graph.traversal import toposort
if memo is None:
memo = {}
......@@ -1018,7 +1017,7 @@ def clone_get_equiv(
memo.setdefault(input, input)
# go through the inputs -> outputs graph cloning as we go
for apply in io_toposort(inputs, outputs):
for apply in toposort(outputs, blockers=inputs):
for input in apply.inputs:
if input not in memo:
if not isinstance(input, Constant) and copy_orphans:
......
......@@ -10,7 +10,7 @@ import numpy as np
import pytensor
from pytensor.configdefaults import config
from pytensor.graph.basic import Variable
from pytensor.graph.traversal import io_toposort
from pytensor.graph.traversal import toposort
from pytensor.graph.utils import InconsistencyError
......@@ -340,11 +340,11 @@ class Feature:
class Bookkeeper(Feature):
def on_attach(self, fgraph):
for node in io_toposort(fgraph.inputs, fgraph.outputs):
for node in toposort(fgraph.outputs):
self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph):
for node in io_toposort(fgraph.inputs, fgraph.outputs):
for node in toposort(fgraph.outputs):
self.on_prune(fgraph, node, "Bookkeeper.detach")
......
......@@ -19,7 +19,8 @@ from pytensor.graph.op import Op
from pytensor.graph.traversal import (
applys_between,
graph_inputs,
io_toposort,
toposort,
toposort_with_orderings,
vars_between,
)
from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError
......@@ -366,7 +367,7 @@ class FunctionGraph(MetaObject):
# new nodes, so we use all variables we know of as if they were the
# input set. (The functions in the graph module only use the input set
# to know where to stop going down.)
new_nodes = io_toposort(self.variables, apply_node.outputs)
new_nodes = tuple(toposort(apply_node.outputs, blockers=self.variables))
if check:
for node in new_nodes:
......@@ -759,7 +760,7 @@ class FunctionGraph(MetaObject):
# No sorting is necessary
return list(self.apply_nodes)
return io_toposort(self.inputs, self.outputs, self.orderings())
return list(toposort_with_orderings(self.outputs, orderings=self.orderings()))
def orderings(self) -> dict[Apply, list[Apply]]:
"""Return a map of node to node evaluation dependencies.
......
......@@ -10,7 +10,10 @@ from pytensor.graph.basic import (
)
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.traversal import io_toposort, truncated_graph_inputs
from pytensor.graph.traversal import (
toposort,
truncated_graph_inputs,
)
ReplaceTypes = Iterable[tuple[Variable, Variable]] | dict[Variable, Variable]
......@@ -295,7 +298,7 @@ def vectorize_graph(
new_inputs = [replace.get(inp, inp) for inp in inputs]
vect_vars = dict(zip(inputs, new_inputs, strict=True))
for node in io_toposort(inputs, seq_outputs):
for node in toposort(seq_outputs, blockers=inputs):
vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
vect_node = vectorize_node(node, *vect_inputs)
for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True):
......
......@@ -27,7 +27,7 @@ from pytensor.graph.features import AlreadyThere, Feature
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars
from pytensor.graph.traversal import applys_between, io_toposort, vars_between
from pytensor.graph.traversal import applys_between, toposort, vars_between
from pytensor.graph.utils import AssocList, InconsistencyError
from pytensor.misc.ordered_set import OrderedSet
from pytensor.utils import flatten
......@@ -2010,7 +2010,7 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
callback_before = fgraph.execute_callbacks_time
nb_nodes_start = len(fgraph.apply_nodes)
t0 = time.perf_counter()
q = deque(io_toposort(fgraph.inputs, start_from))
q = deque(toposort(start_from))
io_t = time.perf_counter() - t0
def importer(node):
......@@ -2341,7 +2341,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
changed |= apply_cleanup(iter_cleanup_sub_profs)
topo_t0 = time.perf_counter()
q = deque(io_toposort(fgraph.inputs, start_from))
q = deque(toposort(start_from))
io_toposort_timing.append(time.perf_counter() - topo_t0)
nb_nodes.append(len(q))
......
from collections import deque
from collections.abc import (
Callable,
Collection,
Generator,
Iterable,
Iterator,
Reversible,
Sequence,
)
from typing import (
Literal,
TypeVar,
cast,
overload,
)
from pytensor.graph.basic import Apply, Constant, Node, Variable
from pytensor.misc.ordered_set import OrderedSet
T = TypeVar("T", bound=Node)
NodeAndChildren = tuple[T, Iterable[T] | None]
@overload
def walk(
nodes: Iterable[T],
expand: Callable[[T], Iterable[T] | None],
bfs: bool = True,
return_children: Literal[False] = False,
) -> Generator[T, None, None]: ...
@overload
def walk(
nodes: Iterable[T],
expand: Callable[[T], Iterable[T] | None],
bfs: bool,
return_children: Literal[True],
) -> Generator[NodeAndChildren, None, None]: ...
def walk(
    nodes: Iterable[T],
    expand: Callable[[T], Iterable[T] | None],
    bfs: bool = True,
    return_children: bool = False,
    hash_fn: Callable[[T], int] = id,
) -> Generator[T | NodeAndChildren, None, None]:
    r"""Walk through a graph, either breadth- or depth-first.

    Parameters
    ----------
    nodes
        The nodes from which to start walking.
    expand
        A callable applied to each visited node; it returns the next nodes
        to visit (or ``None``/empty when there are none).
    bfs
        If ``True``, visit nodes breadth-first; otherwise, depth-first.
    return_children
        If ``True``, each output node will be accompanied by the output of
        `expand` (i.e. the corresponding child nodes).
    hash_fn
        Unused; kept for backward compatibility.  Deduplication now relies
        on the nodes' own hash/equality.

    Notes
    -----
    Each reachable node is yielded at most once, even when it is reachable
    via multiple paths.
    """
    rval_set: set[T] = set()
    nodes = deque(nodes)
    # Bind the pop method once: popping left gives BFS order, popping right gives DFS.
    nodes_pop: Callable[[], T] = nodes.popleft if bfs else nodes.pop
    node: T
    new_nodes: Iterable[T] | None
    try:
        if return_children:
            while True:
                node = nodes_pop()
                if node not in rval_set:
                    new_nodes = expand(node)
                    yield node, new_nodes
                    rval_set.add(node)
                    if new_nodes:
                        nodes.extend(new_nodes)
        else:
            while True:
                node = nodes_pop()
                if node not in rval_set:
                    yield node
                    rval_set.add(node)
                    new_nodes = expand(node)
                    if new_nodes:
                        nodes.extend(new_nodes)
    except IndexError:
        # The deque is exhausted.
        return None
def ancestors(
    graphs: Iterable[Variable],
    blockers: Iterable[Variable] | None = None,
) -> Generator[Variable, None, None]:
    r"""Return the variables that contribute to those in given graphs (inclusive), stopping at blockers.

    Parameters
    ----------
    graphs
        Output `Variable`\s from which to start the search.
    blockers
        Variables whose own ancestors should not be explored.  A blocker
        itself is still yielded when reached.

    Yields
    ------
    `Variable`\s
        All ancestor variables, in the order found by a right-recursive
        depth-first search started at the variables in `graphs`.
    """
    seen = set()
    queue = list(graphs)
    try:
        # Two near-identical loops so the common no-blockers path skips the
        # membership test entirely.
        if blockers:
            blockers = frozenset(blockers)
            while True:
                if (var := queue.pop()) not in seen:
                    yield var
                    seen.add(var)
                    if var not in blockers and (apply := var.owner) is not None:
                        queue.extend(apply.inputs)
        else:
            while True:
                if (var := queue.pop()) not in seen:
                    yield var
                    seen.add(var)
                    if (apply := var.owner) is not None:
                        queue.extend(apply.inputs)
    except IndexError:
        # The queue is exhausted.
        return


variable_ancestors = ancestors
def apply_ancestors(
    graphs: Iterable[Variable],
    blockers: Iterable[Variable] | None = None,
) -> Generator[Apply, None, None]:
    """Return the Apply nodes that contribute to those in given graphs (inclusive)."""
    # Seeding with ``None`` filters out root variables (those without an owner).
    emitted = {None}
    for variable in ancestors(graphs, blockers):
        owner = variable.owner
        if owner in emitted:
            # Root variable, or an output of an Apply we already yielded
            # (multi-output nodes surface once per output variable).
            continue
        yield owner
        emitted.add(owner)
def graph_inputs(
    graphs: Iterable[Variable], blockers: Iterable[Variable] | None = None
) -> Generator[Variable, None, None]:
    r"""Return the inputs required to compute the given Variables.

    Parameters
    ----------
    graphs
        Output `Variable`\s from which to search backward.
    blockers
        Variables whose own ancestors should not be explored.

    Yields
    ------
    Input nodes with no owner, in the order found by the `ancestors`
    traversal started at the nodes in `graphs`.
    """
    yield from (var for var in ancestors(graphs, blockers) if var.owner is None)
def explicit_graph_inputs(
......@@ -177,12 +218,12 @@ def explicit_graph_inputs(
from pytensor.compile.sharedvalue import SharedVariable
if isinstance(graph, Variable):
graph = [graph]
graph = (graph,)
return (
v
for v in graph_inputs(graph)
if isinstance(v, Variable) and not isinstance(v, Constant | SharedVariable)
var
for var in ancestors(graph)
if var.owner is None and not isinstance(var, Constant | SharedVariable)
)
......@@ -191,6 +232,11 @@ def vars_between(
) -> Generator[Variable, None, None]:
r"""Extract the `Variable`\s within the sub-graph between input and output nodes.
Notes
-----
This function is like ancestors(outs, blockers=ins),
except it can also yield disconnected output variables from multi-output apply nodes.
Parameters
----------
ins
......@@ -207,20 +253,19 @@ def vars_between(
"""
ins = set(ins)
def expand(r: Variable) -> Iterable[Variable] | None:
if r.owner and r not in ins:
return reversed(r.owner.inputs + r.owner.outputs)
def expand(var: Variable, ins=frozenset(ins)) -> Iterable[Variable] | None:
if var.owner is not None and var not in ins:
return (*var.owner.inputs, *var.owner.outputs)
return None
yield from cast(Generator[Variable, None, None], walk(outs, expand))
# With bfs = False, it iterates similarly to ancestors
yield from walk(outs, expand, bfs=False)
def orphans_between(
    ins: Iterable[Variable], outs: Iterable[Variable]
) -> Generator[Variable, None, None]:
    r"""Extract the root `Variable`\s not within the sub-graph between input and output nodes.

    Parameters
    ----------
    ins
        Input `Variable`\s.
    outs
        Output `Variable`\s.

    Yields
    ------
    `Variable`\s
        The ownerless `Variable`\s upon which the sub-graph between `ins`
        and `outs` depends, excluding the `ins` themselves
        (e.g. for ``ins=[x]`` and ``outs=[x + y]`` this yields ``[y]``).
    """
    ins = frozenset(ins)
    yield from (
        var
        for var in vars_between(ins, outs)
        if var.owner is None and var not in ins
    )
def applys_between(
    ins: Iterable[Variable], outs: Iterable[Variable]
) -> Generator[Apply, None, None]:
    r"""Extract the `Apply`\s contained within the sub-graph between given input and output variables.

    Notes
    -----
    This is identical to ``apply_ancestors(outs, blockers=ins)``.

    Parameters
    ----------
    ins
        Input `Variable`\s.
    outs
        Output `Variable`\s.

    Yields
    ------
    The `Apply`\s that lie between `ins` and `outs`, including the owners
    of the `Variable`\s in `outs`, but not the owners of the `Variable`\s
    in `ins`.
    """
    return apply_ancestors(outs, blockers=ins)
def apply_depends_on(apply: Apply, depends_on: Apply | Iterable[Apply]) -> bool:
    """Determine if any `depends_on` is in the graph given by ``apply``.

    Parameters
    ----------
    apply
        The `Apply` node whose ancestry is searched.
    depends_on
        A single `Apply` node, or an iterable of them, to search for.

    Returns
    -------
    bool
        ``True`` if `apply` is, or depends on, any node in `depends_on`.
    """
    if isinstance(depends_on, Apply):
        depends_on = frozenset((depends_on,))
    else:
        depends_on = frozenset(depends_on)
    # Check the node itself first, then every Apply in its ancestry.
    return apply in depends_on or any(
        ancestor in depends_on for ancestor in apply_ancestors(apply.inputs)
    )
def variable_depends_on(
    variable: Variable, depends_on: Variable | Iterable[Variable]
) -> bool:
    """Determine if any `depends_on` is in the graph given by ``variable``.

    Notes
    -----
    The interpretation of dependency is done at a variable level.
    A variable may depend on some output variables from a multi-output
    apply node but not others.

    Parameters
    ----------
    variable
        Variable to check.
    depends_on
        A single `Variable`, or an iterable of them, to check dependency on.

    Returns
    -------
    bool
    """
    if isinstance(depends_on, Variable):
        depends_on_set = frozenset((depends_on,))
    else:
        depends_on_set = frozenset(depends_on)
    return any(var in depends_on_set for var in variable_ancestors([variable]))
def truncated_graph_inputs(
    outputs: Sequence[Variable],
    ancestors_to_include: Iterable[Variable] | None = None,
) -> list[Variable]:
    """Get the truncated graph inputs.

    Parameters
    ----------
    outputs : Iterable[Variable]
        Variables to get conditions for.
    ancestors_to_include : Optional[Iterable[Variable]]
        Additional ancestors to assume, by default None.

    Returns
    -------
    list[Variable]
        The truncated inputs: variables at which the graph of `outputs` can
        be cut, given the assumed `ancestors_to_include`.
    """
    truncated_inputs: list[Variable] = []
    seen: set[Variable] = set()

    # Simple case: no additional ancestors to include, just deduplicate outputs.
    if not ancestors_to_include:
        for variable in outputs:
            if variable not in seen:
                seen.add(variable)
                truncated_inputs.append(variable)
        # No more actions are needed.
        return truncated_inputs

    # Blockers hold known independent variables and the ancestors to include.
    blockers: set[Variable] = set(ancestors_to_include)
    # Enforce O(1) membership checks for the ancestors to include.
    ancestors_to_include = blockers.copy()

    candidates = list(outputs)
    try:
        while True:
            if (variable := candidates.pop()) not in seen:
                seen.add(variable)
                if variable in ancestors_to_include:
                    # Ancestors to include that are present in the graph
                    # (not disconnected) become truncated inputs.
                    truncated_inputs.append(variable)
                    # If this ancestor still depends on other ancestors to
                    # include, we must keep searching above it.
                    # FIXME: This seems wrong? The other ancestors above are
                    # either redundant given this variable, or another path
                    # leads to them and the special casing isn't needed.
                    # Expanding here may cause unrelated variables to become
                    # "independent" in the process.
                    if variable_depends_on(
                        variable, ancestors_to_include - {variable}
                    ):
                        # owner can never be None for a dependent variable
                        candidates.extend(
                            n for n in variable.owner.inputs if n not in seen
                        )
                else:
                    if variable_depends_on(variable, blockers):
                        # Not independent: its inputs become candidates.
                        candidates.extend(variable.owner.inputs)
                    else:
                        # Otherwise it is a truncated input itself.
                        truncated_inputs.append(variable)
                    # All regular variables fall to blockers:
                    # 1. dependent - inputs already expanded, nothing to do on re-visit
                    # 2. independent - a truncated input; the search can stop here
                    # NOTE(review): placement of this `blockers.add` inside the
                    # else-branch mirrors the pre-refactor code — confirm.
                    blockers.add(variable)
    except IndexError:  # pop from an empty list
        pass
    return truncated_inputs
def walk_toposort(
    graphs: Iterable[T],
    deps: Callable[[T], Iterable[T] | None],
) -> Generator[T, None, None]:
    """Perform a topological sort of all nodes starting from the given nodes.

    Parameters
    ----------
    graphs
        An iterable of nodes from which to start the topological sort.
    deps : callable
        A Python function that takes a node as input and returns its
        dependencies.

    Notes
    -----
    ``deps(i)`` should behave like a pure function (no funny business with
    internal state).

    The order of the yielded nodes is determined by the order of nodes
    returned by the `deps` function.

    Raises
    ------
    ValueError
        If the graph contains a cycle (some nodes can never be emitted).
    """
    # Cache the dependencies (ancestors) as we iterate over the nodes.
    deps_cache: dict[T, list[T]] = {}

    def compute_deps_cache(obj, deps_cache=deps_cache):
        if obj in deps_cache:
            return deps_cache[obj]
        d = deps_cache[obj] = deps(obj) or []
        return d

    clients: dict[T, list[T]] = {}
    sources: deque[T] = deque()
    total_nodes = 0
    for node, children in walk(
        graphs, compute_deps_cache, bfs=False, return_children=True
    ):
        total_nodes += 1
        # Mypy doesn't know `children` can't be `None` because of the
        # `or []` in `compute_deps_cache`.
        for child in children:  # type: ignore
            clients.setdefault(child, []).append(node)
        if not deps_cache[node]:
            # Nodes without dependencies seed the sort.
            sources.append(node)

    rset: set[T] = set()
    try:
        while True:
            if (node := sources.popleft()) not in rset:
                yield node
                total_nodes -= 1
                rset.add(node)
                # Iterate over each client (i.e. each node that depends on this one).
                for client in clients.get(node, []):
                    # Remove this node from the client's pending dependencies.
                    d = deps_cache[client] = [
                        a for a in deps_cache[client] if a is not node
                    ]
                    if not d:
                        # No dependencies left: the client is ready to be emitted.
                        sources.append(client)
    except IndexError:
        pass
    if total_nodes != 0:
        raise ValueError("graph contains cycles")
def general_toposort(
    outputs: Iterable[T],
    deps: Callable[[T], Iterable[T] | None],
    compute_deps_cache: Callable[[T], Iterable[T] | None] | None = None,
    deps_cache: dict[T, list[T]] | None = None,
    clients: dict[T, list[T]] | None = None,
) -> list[T]:
    """Perform a general topological sort.

    Parameters
    ----------
    outputs
        Nodes from which to start the sort.
    deps
        A callable that takes a node and returns its dependencies.
    compute_deps_cache
        Deprecated; passing a value raises ``ValueError``.
    deps_cache
        Deprecated; passing a value raises ``ValueError``.
    clients
        Deprecated; passing a value raises ``ValueError``.

    Notes
    -----
    This is a simple wrapper around `walk_toposort` for backwards
    compatibility.

    ``deps(i)`` should behave like a pure function (no funny business with
    internal state); its results are cached internally.
    """
    # TODO: Deprecate me later
    if compute_deps_cache is not None:
        raise ValueError("compute_deps_cache is no longer supported")
    if deps_cache is not None:
        raise ValueError("deps_cache is no longer supported")
    if clients is not None:
        raise ValueError("clients is no longer supported")
    return list(walk_toposort(outputs, deps))
def toposort(
    graphs: Iterable[Variable],
    blockers: Iterable[Variable] | None = None,
) -> Generator[Apply, None, None]:
    """Topologically sort the Apply nodes between graphs (outputs) and blockers (inputs).

    This is a streamlined version of `toposort_with_orderings` for the
    common case where no additional ordering constraints are needed.

    Parameters
    ----------
    graphs
        Output variables whose ancestor `Apply` nodes should be sorted.
    blockers
        Variables at which the backward traversal stops.

    Yields
    ------
    `Apply` nodes in topological order (dependencies before dependents).
    """
    # Blocker variables can go straight into `computed`: only Apply nodes are
    # ever yielded, so marking them as done simply stops the traversal there.
    computed = set(blockers or ())
    todo = list(graphs)
    try:
        while True:
            if (cur := todo.pop()) not in computed and (
                apply := cur.owner
            ) is not None:
                uncomputed_inputs = tuple(
                    i
                    for i in apply.inputs
                    if (i not in computed and i.owner is not None)
                )
                if not uncomputed_inputs:
                    yield apply
                    computed.update(apply.outputs)
                else:
                    # Revisit this variable after its pending inputs.
                    todo.append(cur)
                    todo.extend(uncomputed_inputs)
    except IndexError:  # queue is empty
        return
def toposort_with_orderings(
    graphs: Iterable[Variable],
    *,
    blockers: Iterable[Variable] | None = None,
    orderings: dict[Apply, list[Apply]] | None = None,
) -> Generator[Apply, None, None]:
    """Topologically sort nodes between blocker (input) and graphs (output) variables with arbitrary extra orderings.

    Extra orderings can be used to force sorting of variables that are not
    naturally related in the graph.  This can be used by inplace
    optimizations to ensure a variable is only destroyed after all other
    uses.  Those other uses show up as dependencies of the destroying node,
    in the orderings dictionary.

    Parameters
    ----------
    graphs : list or tuple of Variable instances
        Graph outputs.
    blockers
        Variables at which the backward traversal stops.
    orderings : dict
        Keys are `Apply` or `Variable` instances, values are lists of
        `Apply` or `Variable` instances.

    Yields
    ------
    `Apply` nodes in topological order.
    """
    if not orderings:
        # Faster branch
        yield from toposort(graphs, blockers=blockers)
    else:
        if blockers:
            # The blockers decide where to stop expanding.
            def compute_deps(
                obj, blocker_set=frozenset(blockers), orderings=orderings
            ):
                if obj in blocker_set:
                    return None
                if isinstance(obj, Apply):
                    return [*obj.inputs, *orderings.get(obj, [])]
                elif (apply := obj.owner) is not None:
                    return [apply, *orderings.get(apply, [])]
                else:
                    return orderings.get(obj, [])

        else:
            # mypy doesn't like conditional functions with different
            # signatures, but binding the globals as defaults is faster.
            def compute_deps(obj, orderings=orderings):  # type: ignore[misc]
                if isinstance(obj, Apply):
                    return [*obj.inputs, *orderings.get(obj, [])]
                elif (apply := obj.owner) is not None:
                    return [apply, *orderings.get(apply, [])]
                else:
                    return orderings.get(obj, [])

        yield from (
            node
            for node in walk_toposort(graphs, deps=compute_deps)
            # The walk yields both Apply and Variable nodes; only return Applies.
            if isinstance(node, Apply)  # type: ignore
        )
def io_toposort(
    inputs: Iterable[Variable],
    outputs: Iterable[Variable],
    orderings: dict[Apply, list[Apply]] | None = None,
    clients: dict[Variable, list[Variable]] | None = None,
) -> list[Apply]:
    """Perform a topological sort of the nodes between input and output variables.

    Notes
    -----
    This is just a wrapper around `toposort_with_orderings` kept for
    backwards compatibility.

    Parameters
    ----------
    inputs
        Graph inputs.
    outputs
        Graph outputs.
    orderings : dict
        Keys are `Apply` instances, values are lists of `Apply` instances.
    clients
        Deprecated; passing a value raises ``ValueError``.
    """
    # TODO: Deprecate me later
    if clients is not None:
        raise ValueError("clients is no longer supported")
    return list(
        toposort_with_orderings(outputs, blockers=inputs, orderings=orderings)
    )
def get_var_by_name(
    graphs: Iterable[Variable], target_var_id: str
) -> tuple[Variable, ...]:
    r"""Get variables in a graph using their names.

    Parameters
    ----------
    graphs
        The graphs to search.
    target_var_id
        The name to match against both ``Variable.name`` and
        ``Variable.auto_name``.

    Returns
    -------
    A ``tuple`` containing all the `Variable`\s with the given name.
    """
    from pytensor.graph.op import HasInnerGraph

    def expand(r: Variable) -> list[Variable] | None:
        if (apply := r.owner) is not None:
            if isinstance(apply.op, HasInnerGraph):
                # Also descend into the inner graphs of ops that carry one.
                return [*apply.inputs, *apply.op.inner_outputs]
            # Mypy doesn't know these will never be None
            return apply.inputs  # type: ignore
        return None

    return tuple(
        var
        for var in walk(graphs, expand)
        if target_var_id == var.name or target_var_id == var.auto_name
    )
......@@ -21,7 +21,7 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import HasInnerGraph, Op, StorageMapType
from pytensor.graph.traversal import graph_inputs, io_toposort
from pytensor.graph.traversal import graph_inputs, toposort
from pytensor.graph.utils import Scratchpad
......@@ -1103,7 +1103,7 @@ class PPrinter(Printer):
)
inv_updates = {b: a for (a, b) in updates.items()}
i = 1
for node in io_toposort([*inputs, *updates], [*outputs, *updates.values()]):
for node in toposort([*outputs, *updates.values()], [*inputs, *updates]):
for output in node.outputs:
if output in inv_updates:
name = str(inv_updates[output])
......
......@@ -13,7 +13,6 @@ from pytensor import tensor as pt
from pytensor.compile import optdb
from pytensor.compile.function.types import deep_copy_op
from pytensor.configdefaults import config
from pytensor.graph import ancestors, graph_inputs
from pytensor.graph.basic import (
Apply,
Constant,
......@@ -35,7 +34,11 @@ from pytensor.graph.rewriting.basic import (
)
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
from pytensor.graph.rewriting.utils import get_clients_at_depth
from pytensor.graph.traversal import apply_depends_on, io_toposort
from pytensor.graph.traversal import (
ancestors,
apply_depends_on,
graph_inputs,
)
from pytensor.graph.type import HasShape
from pytensor.graph.utils import InconsistencyError
from pytensor.raise_op import Assert
......@@ -220,7 +223,7 @@ def scan_push_out_non_seq(fgraph, node):
"""
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_topo = node.op.fgraph.toposort()
local_fgraph_outs_set = set(node_outputs)
local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
......@@ -427,7 +430,7 @@ def scan_push_out_seq(fgraph, node):
"""
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_topo = node.op.fgraph.toposort()
local_fgraph_outs_set = set(node_outputs)
local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
......@@ -840,22 +843,42 @@ def scan_push_out_add(fgraph, node):
# apply_ancestors(args.inner_outputs)
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use
args = ScanArgs(
node.inputs, node.outputs, op.inner_inputs, op.inner_outputs, op.info
add_of_dot_nodes = [
n
for n in op.fgraph.apply_nodes
if
(
# We have an Add
isinstance(n.op, Elemwise)
and isinstance(n.op.scalar_op, ps.Add)
and any(
(
# With a Dot input that's only used in the Add
n_inp.owner is not None
and isinstance(n_inp.owner.op, Dot)
and len(op.fgraph.clients[n_inp]) == 1
)
for n_inp in n.inputs
)
)
]
if not add_of_dot_nodes:
return False
clients = {}
local_fgraph_topo = io_toposort(
args.inner_inputs, args.inner_outputs, clients=clients
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of access
args = ScanArgs(
node.inputs,
node.outputs,
op.inner_inputs,
op.inner_outputs,
op.info,
clone=False,
)
for nd in local_fgraph_topo:
for nd in add_of_dot_nodes:
if (
isinstance(nd.op, Elemwise)
and isinstance(nd.op.scalar_op, ps.Add)
and nd.out in args.inner_out_sit_sot
nd.out in args.inner_out_sit_sot
# FIXME: This function doesn't handle `sitsot_out[1:][-1]` pattern
and inner_sitsot_only_last_step_used(fgraph, nd.out, args)
):
......@@ -863,27 +886,17 @@ def scan_push_out_add(fgraph, node):
# the add from a previous iteration of the inner function
sitsot_idx = args.inner_out_sit_sot.index(nd.out)
if args.inner_in_sit_sot[sitsot_idx] in nd.inputs:
# Ensure that the other input to the add is a dot product
# between 2 matrices which will become a tensor3 and a
# matrix if pushed outside of the scan. Also make sure
# that the output of the Dot is ONLY used by the 'add'
# otherwise doing a Dot in the outer graph will only
# duplicate computation.
sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[sitsot_idx])
# 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
dot_in_idx = 1 - sitsot_in_idx
dot_input = nd.inputs[dot_in_idx]
assert dot_input.owner is not None and isinstance(
dot_input.owner.op, Dot
)
if (
dot_input.owner is not None
and isinstance(dot_input.owner.op, Dot)
and len(clients[dot_input]) == 1
and dot_input.owner.inputs[0].ndim == 2
and dot_input.owner.inputs[1].ndim == 2
and get_outer_ndim(dot_input.owner.inputs[0], args) == 3
get_outer_ndim(dot_input.owner.inputs[0], args) == 3
and get_outer_ndim(dot_input.owner.inputs[1], args) == 3
):
                            # The optimization can be applied in this case.
......
......@@ -59,7 +59,7 @@ import time
import numpy as np
from pytensor.graph.traversal import io_toposort
from pytensor.graph.traversal import toposort
from pytensor.tensor.rewriting.basic import register_specialize
......@@ -460,6 +460,9 @@ class GemmOptimizer(GraphRewriter):
callbacks_before = fgraph.execute_callbacks_times.copy()
callback_before = fgraph.execute_callbacks_time
nodelist = list(toposort(fgraph.outputs))
nodelist.reverse()
def on_import(new_node):
if new_node is not node:
nodelist.append(new_node)
......@@ -471,10 +474,8 @@ class GemmOptimizer(GraphRewriter):
while did_something:
nb_iter += 1
t0 = time.perf_counter()
nodelist = io_toposort(fgraph.inputs, fgraph.outputs)
time_toposort += time.perf_counter() - t0
did_something = False
nodelist.reverse()
for node in nodelist:
if not (
isinstance(node.op, Elemwise)
......
......@@ -50,22 +50,13 @@ class TestProfiling:
the_string = buf.getvalue()
lines1 = [l for l in the_string.split("\n") if "Max if linker" in l]
lines2 = [l for l in the_string.split("\n") if "Minimum peak" in l]
if config.device == "cpu":
assert "CPU: 4112KB (4104KB)" in the_string, (lines1, lines2)
assert "CPU: 8204KB (8196KB)" in the_string, (lines1, lines2)
            # NOTE: The specific numbers can change for distinct (but correct) toposort orderings
# Update the test values if a different algorithm is used
assert "CPU: 4112KB (4112KB)" in the_string, (lines1, lines2)
assert "CPU: 8204KB (8204KB)" in the_string, (lines1, lines2)
assert "CPU: 8208KB" in the_string, (lines1, lines2)
assert (
"Minimum peak from all valid apply node order is 4104KB"
in the_string
), (lines1, lines2)
else:
assert "CPU: 16KB (16KB)" in the_string, (lines1, lines2)
assert "GPU: 8204KB (8204KB)" in the_string, (lines1, lines2)
assert "GPU: 12300KB (12300KB)" in the_string, (lines1, lines2)
assert "GPU: 8212KB" in the_string, (lines1, lines2)
assert (
"Minimum peak from all valid apply node order is 4116KB"
in the_string
"Minimum peak from all valid apply node order is 4104KB" in the_string
), (lines1, lines2)
finally:
......
......@@ -160,7 +160,7 @@ def test_KanrenRelationSub_dot():
assert expr_opt.owner.op == pt.add
assert isinstance(expr_opt.owner.inputs[0].owner.op, Dot)
assert fgraph_opt.inputs[0] is A_pt
assert fgraph_opt.inputs[-1] is A_pt
assert expr_opt.owner.inputs[0].owner.inputs[0].name == "A"
assert expr_opt.owner.inputs[1].owner.op == pt.add
assert isinstance(expr_opt.owner.inputs[1].owner.inputs[0].owner.op, Dot)
......
......@@ -56,7 +56,7 @@ class TestFunctionGraph:
with pytest.raises(TypeError, match="'Variable' object is not iterable"):
FunctionGraph(var1, [var2])
with pytest.raises(TypeError, match="'Variable' object is not reversible"):
with pytest.raises(TypeError, match="'Variable' object is not iterable"):
FunctionGraph([var1], var2)
with pytest.raises(
......
......@@ -28,7 +28,7 @@ class TestCloneReplace:
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, rebuild_strict=True, copy_inputs_over=True)
f2_inp = graph_inputs([f2])
f2_inp = tuple(graph_inputs([f2]))
assert z in f2_inp
assert x in f2_inp
......@@ -65,7 +65,7 @@ class TestCloneReplace:
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=True, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
f2_inp = tuple(graph_inputs([f2]))
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
......@@ -83,7 +83,7 @@ class TestCloneReplace:
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=False, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
f2_inp = tuple(graph_inputs([f2]))
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
......
......@@ -4,13 +4,17 @@ from pytensor import Variable, shared
from pytensor import tensor as pt
from pytensor.graph import Apply, ancestors, graph_inputs
from pytensor.graph.traversal import (
apply_ancestors,
apply_depends_on,
explicit_graph_inputs,
general_toposort,
get_var_by_name,
io_toposort,
orphans_between,
toposort,
toposort_with_orderings,
truncated_graph_inputs,
variable_ancestors,
variable_depends_on,
vars_between,
walk,
......@@ -36,23 +40,17 @@ class TestToposort:
o2 = MyOp(o, r5)
o2.name = "o2"
clients = {}
res = general_toposort([o2], self.prenode, clients=clients)
assert clients == {
o2.owner: [o2],
o: [o2.owner],
r5: [o2.owner],
o.owner: [o],
r1: [o.owner],
r2: [o.owner],
}
res = general_toposort([o2], self.prenode)
assert res == [r5, r2, r1, o.owner, o, o2.owner, o2]
with pytest.raises(ValueError):
general_toposort(
[o2], self.prenode, compute_deps_cache=lambda x: None, deps_cache=None
)
def circular_dependency(obj):
if obj is o:
# o2 depends on o, so o cannot depend on o2
return [o2, *self.prenode(obj)]
return self.prenode(obj)
with pytest.raises(ValueError, match="graph contains cycles"):
general_toposort([o2], circular_dependency)
res = io_toposort([r5], [o2])
assert res == [o.owner, o2.owner]
......@@ -181,16 +179,16 @@ def test_ancestors():
res = ancestors([o2], blockers=None)
res_list = list(res)
assert res_list == [o2, r3, o1, r1, r2]
assert res_list == [o2, o1, r2, r1, r3]
res = ancestors([o2], blockers=None)
assert r3 in res
assert o1 in res
res_list = list(res)
assert res_list == [o1, r1, r2]
assert res_list == [r2, r1, r3]
res = ancestors([o2], blockers=[o1])
res_list = list(res)
assert res_list == [o2, r3, o1]
assert res_list == [o2, o1, r3]
def test_graph_inputs():
......@@ -202,7 +200,7 @@ def test_graph_inputs():
res = graph_inputs([o2], blockers=None)
res_list = list(res)
assert res_list == [r3, r1, r2]
assert res_list == [r2, r1, r3]
def test_explicit_graph_inputs():
......@@ -231,7 +229,7 @@ def test_variables_and_orphans():
vars_res_list = list(vars_res)
orphans_res_list = list(orphans_res)
assert vars_res_list == [o2, o1, r3, r2, r1]
assert vars_res_list == [o2, o1, r2, r1, r3]
assert orphans_res_list == [r3]
......@@ -408,3 +406,37 @@ def test_get_var_by_name():
exp_res = igo.fgraph.outputs[0]
assert res == exp_res
@pytest.mark.parametrize(
    "func",
    [
        lambda x: all(variable_ancestors([x])),
        lambda x: all(variable_ancestors([x], blockers=[x.clone()])),
        lambda x: all(apply_ancestors([x])),
        lambda x: all(apply_ancestors([x], blockers=[x.clone()])),
        lambda x: all(toposort([x])),
        lambda x: all(toposort([x], blockers=[x.clone()])),
        lambda x: all(toposort_with_orderings([x], orderings={x: []})),
        lambda x: all(
            toposort_with_orderings([x], blockers=[x.clone()], orderings={x: []})
        ),
    ],
    ids=[
        "variable_ancestors",
        "variable_ancestors_with_blockers",
        "apply_ancestors",
        # Fixed stray ")" that used to terminate this id string
        "apply_ancestors_with_blockers",
        "toposort",
        "toposort_with_blockers",
        "toposort_with_orderings",
        "toposort_with_orderings_and_blockers",
    ],
)
def test_traversal_benchmark(func, benchmark):
    """Benchmark each graph-traversal helper on a deep graph.

    Parameters
    ----------
    func
        One of the traversal wrappers from the ``parametrize`` list above;
        each fully consumes the traversal it builds over the output variable.
    benchmark
        The pytest-benchmark fixture; times repeated calls of ``func``.
    """
    r1 = MyVariable(1)
    out = r1
    # Feed `out` into both inputs of every new node: the resulting graph has
    # exponentially many paths, so traversal helpers must deduplicate visits.
    for _ in range(50):
        out = MyOp(out, out)
    benchmark(func, out)
from itertools import chain
import numpy as np
import pytest
......@@ -490,6 +492,7 @@ def test_inplace_taps(n_steps_constant):
if isinstance(node.op, Scan)
]
# Collect inner inputs we expect to be destroyed by the step function
# Scan reorders inputs internally, so we need to check its ordering
inner_inps = scan_op.fgraph.inputs
mit_sot_inps = scan_op.inner_mitsot(inner_inps)
......@@ -501,28 +504,22 @@ def test_inplace_taps(n_steps_constant):
]
[sit_sot_inp] = scan_op.inner_sitsot(inner_inps)
inner_outs = scan_op.fgraph.outputs
mit_sot_outs = scan_op.inner_mitsot_outs(inner_outs)
[sit_sot_out] = scan_op.inner_sitsot_outs(inner_outs)
[nit_sot_out] = scan_op.inner_nitsot_outs(inner_outs)
destroyed_inputs = []
for inner_out in scan_op.fgraph.outputs:
node = inner_out.owner
dm = node.op.destroy_map
if dm:
destroyed_inputs.extend(
node.inputs[idx] for idx in chain.from_iterable(dm.values())
)
if n_steps_constant:
assert mit_sot_outs[0].owner.op.destroy_map == {
0: [mit_sot_outs[0].owner.inputs.index(oldest_mit_sot_inps[0])]
}
assert mit_sot_outs[1].owner.op.destroy_map == {
0: [mit_sot_outs[1].owner.inputs.index(oldest_mit_sot_inps[1])]
}
assert sit_sot_out.owner.op.destroy_map == {
0: [sit_sot_out.owner.inputs.index(sit_sot_inp)]
}
assert len(destroyed_inputs) == 3
assert set(destroyed_inputs) == {*oldest_mit_sot_inps, sit_sot_inp}
else:
# This is not a feature, but a current limitation
# https://github.com/pymc-devs/pytensor/issues/1283
assert mit_sot_outs[0].owner.op.destroy_map == {}
assert mit_sot_outs[1].owner.op.destroy_map == {}
assert sit_sot_out.owner.op.destroy_map == {}
assert nit_sot_out.owner.op.destroy_map == {}
assert not destroyed_inputs
@pytest.mark.parametrize(
......
......@@ -1170,8 +1170,8 @@ class TestHyp2F1Grad:
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, ScalarLoop)
]
assert scalar_loop_op1.nin == 10 + 3 * 2 # wrt=[0, 1]
assert scalar_loop_op2.nin == 10 + 3 * 1 # wrt=[2]
assert scalar_loop_op1.nin == 10 + 3 * 1 # wrt=[2]
assert scalar_loop_op2.nin == 10 + 3 * 2 # wrt=[0, 1]
else:
[scalar_loop_op] = [
node.op.scalar_op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论