Commit 066307f0 authored by Ricardo Vieira, committed by Ricardo Vieira

Faster graph traversal functions

* Avoid reversing inputs as we traverse the graph
* Simplify io_toposort without ordering (and refactor it into its own function)
* Remove the client-side effect on the previous toposort functions
* Remove duplicated logic across methods
Parent f1a2ac66
...@@ -5,7 +5,6 @@ import warnings ...@@ -5,7 +5,6 @@ import warnings
from collections.abc import ( from collections.abc import (
Hashable, Hashable,
Iterable, Iterable,
Reversible,
Sequence, Sequence,
) )
from copy import copy from copy import copy
...@@ -961,7 +960,7 @@ def clone_node_and_cache( ...@@ -961,7 +960,7 @@ def clone_node_and_cache(
def clone_get_equiv( def clone_get_equiv(
inputs: Iterable[Variable], inputs: Iterable[Variable],
outputs: Reversible[Variable], outputs: Iterable[Variable],
copy_inputs: bool = True, copy_inputs: bool = True,
copy_orphans: bool = True, copy_orphans: bool = True,
memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]] memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]
...@@ -1002,7 +1001,7 @@ def clone_get_equiv( ...@@ -1002,7 +1001,7 @@ def clone_get_equiv(
Keywords passed to `Apply.clone_with_new_inputs`. Keywords passed to `Apply.clone_with_new_inputs`.
""" """
from pytensor.graph.traversal import io_toposort from pytensor.graph.traversal import toposort
if memo is None: if memo is None:
memo = {} memo = {}
...@@ -1018,7 +1017,7 @@ def clone_get_equiv( ...@@ -1018,7 +1017,7 @@ def clone_get_equiv(
memo.setdefault(input, input) memo.setdefault(input, input)
# go through the inputs -> outputs graph cloning as we go # go through the inputs -> outputs graph cloning as we go
for apply in io_toposort(inputs, outputs): for apply in toposort(outputs, blockers=inputs):
for input in apply.inputs: for input in apply.inputs:
if input not in memo: if input not in memo:
if not isinstance(input, Constant) and copy_orphans: if not isinstance(input, Constant) and copy_orphans:
......
...@@ -10,7 +10,7 @@ import numpy as np ...@@ -10,7 +10,7 @@ import numpy as np
import pytensor import pytensor
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Variable from pytensor.graph.basic import Variable
from pytensor.graph.traversal import io_toposort from pytensor.graph.traversal import toposort
from pytensor.graph.utils import InconsistencyError from pytensor.graph.utils import InconsistencyError
...@@ -340,11 +340,11 @@ class Feature: ...@@ -340,11 +340,11 @@ class Feature:
class Bookkeeper(Feature): class Bookkeeper(Feature):
def on_attach(self, fgraph): def on_attach(self, fgraph):
for node in io_toposort(fgraph.inputs, fgraph.outputs): for node in toposort(fgraph.outputs):
self.on_import(fgraph, node, "on_attach") self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph): def on_detach(self, fgraph):
for node in io_toposort(fgraph.inputs, fgraph.outputs): for node in toposort(fgraph.outputs):
self.on_prune(fgraph, node, "Bookkeeper.detach") self.on_prune(fgraph, node, "Bookkeeper.detach")
......
...@@ -19,7 +19,8 @@ from pytensor.graph.op import Op ...@@ -19,7 +19,8 @@ from pytensor.graph.op import Op
from pytensor.graph.traversal import ( from pytensor.graph.traversal import (
applys_between, applys_between,
graph_inputs, graph_inputs,
io_toposort, toposort,
toposort_with_orderings,
vars_between, vars_between,
) )
from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError
...@@ -366,7 +367,7 @@ class FunctionGraph(MetaObject): ...@@ -366,7 +367,7 @@ class FunctionGraph(MetaObject):
# new nodes, so we use all variables we know of as if they were the # new nodes, so we use all variables we know of as if they were the
# input set. (The functions in the graph module only use the input set # input set. (The functions in the graph module only use the input set
# to know where to stop going down.) # to know where to stop going down.)
new_nodes = io_toposort(self.variables, apply_node.outputs) new_nodes = tuple(toposort(apply_node.outputs, blockers=self.variables))
if check: if check:
for node in new_nodes: for node in new_nodes:
...@@ -759,7 +760,7 @@ class FunctionGraph(MetaObject): ...@@ -759,7 +760,7 @@ class FunctionGraph(MetaObject):
# No sorting is necessary # No sorting is necessary
return list(self.apply_nodes) return list(self.apply_nodes)
return io_toposort(self.inputs, self.outputs, self.orderings()) return list(toposort_with_orderings(self.outputs, orderings=self.orderings()))
def orderings(self) -> dict[Apply, list[Apply]]: def orderings(self) -> dict[Apply, list[Apply]]:
"""Return a map of node to node evaluation dependencies. """Return a map of node to node evaluation dependencies.
......
...@@ -10,7 +10,10 @@ from pytensor.graph.basic import ( ...@@ -10,7 +10,10 @@ from pytensor.graph.basic import (
) )
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.traversal import io_toposort, truncated_graph_inputs from pytensor.graph.traversal import (
toposort,
truncated_graph_inputs,
)
ReplaceTypes = Iterable[tuple[Variable, Variable]] | dict[Variable, Variable] ReplaceTypes = Iterable[tuple[Variable, Variable]] | dict[Variable, Variable]
...@@ -295,7 +298,7 @@ def vectorize_graph( ...@@ -295,7 +298,7 @@ def vectorize_graph(
new_inputs = [replace.get(inp, inp) for inp in inputs] new_inputs = [replace.get(inp, inp) for inp in inputs]
vect_vars = dict(zip(inputs, new_inputs, strict=True)) vect_vars = dict(zip(inputs, new_inputs, strict=True))
for node in io_toposort(inputs, seq_outputs): for node in toposort(seq_outputs, blockers=inputs):
vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs] vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
vect_node = vectorize_node(node, *vect_inputs) vect_node = vectorize_node(node, *vect_inputs)
for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True): for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True):
......
...@@ -27,7 +27,7 @@ from pytensor.graph.features import AlreadyThere, Feature ...@@ -27,7 +27,7 @@ from pytensor.graph.features import AlreadyThere, Feature
from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars
from pytensor.graph.traversal import applys_between, io_toposort, vars_between from pytensor.graph.traversal import applys_between, toposort, vars_between
from pytensor.graph.utils import AssocList, InconsistencyError from pytensor.graph.utils import AssocList, InconsistencyError
from pytensor.misc.ordered_set import OrderedSet from pytensor.misc.ordered_set import OrderedSet
from pytensor.utils import flatten from pytensor.utils import flatten
...@@ -2010,7 +2010,7 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2010,7 +2010,7 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
callback_before = fgraph.execute_callbacks_time callback_before = fgraph.execute_callbacks_time
nb_nodes_start = len(fgraph.apply_nodes) nb_nodes_start = len(fgraph.apply_nodes)
t0 = time.perf_counter() t0 = time.perf_counter()
q = deque(io_toposort(fgraph.inputs, start_from)) q = deque(toposort(start_from))
io_t = time.perf_counter() - t0 io_t = time.perf_counter() - t0
def importer(node): def importer(node):
...@@ -2341,7 +2341,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2341,7 +2341,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
changed |= apply_cleanup(iter_cleanup_sub_profs) changed |= apply_cleanup(iter_cleanup_sub_profs)
topo_t0 = time.perf_counter() topo_t0 = time.perf_counter()
q = deque(io_toposort(fgraph.inputs, start_from)) q = deque(toposort(start_from))
io_toposort_timing.append(time.perf_counter() - topo_t0) io_toposort_timing.append(time.perf_counter() - topo_t0)
nb_nodes.append(len(q)) nb_nodes.append(len(q))
......
from collections import deque from collections import deque
from collections.abc import ( from collections.abc import (
Callable, Callable,
Collection,
Generator, Generator,
Iterable, Iterable,
Iterator,
Reversible, Reversible,
Sequence, Sequence,
) )
from typing import ( from typing import (
Literal,
TypeVar, TypeVar,
cast,
overload, overload,
) )
from pytensor.graph.basic import Apply, Constant, Node, Variable from pytensor.graph.basic import Apply, Constant, Node, Variable
from pytensor.misc.ordered_set import OrderedSet
T = TypeVar("T", bound=Node) T = TypeVar("T", bound=Node)
NodeAndChildren = tuple[T, Iterable[T] | None] NodeAndChildren = tuple[T, Iterable[T] | None]
@overload
def walk(
    nodes: Iterable[T],
    expand: Callable[[T], Iterable[T] | None],
    bfs: bool = True,
    return_children: Literal[False] = False,
) -> Generator[T, None, None]: ...


@overload
def walk(
    nodes: Iterable[T],
    expand: Callable[[T], Iterable[T] | None],
    bfs: bool,
    return_children: Literal[True],
) -> Generator[NodeAndChildren, None, None]: ...


def walk(
    nodes: Iterable[T],
    expand: Callable[[T], Iterable[T] | None],
    bfs: bool = True,
    return_children: bool = False,
) -> Generator[T | NodeAndChildren, None, None]:
    r"""Walk through a graph, either breadth- or depth-first.

    Parameters
    ----------
    nodes
        The nodes from which the walk starts.
    expand
        Callable applied to each node; it must return the node's successors,
        or ``None`` (or an empty iterable) when there are none.
    bfs
        If ``True``, walk breadth-first; otherwise depth-first.
    return_children
        If ``True``, each output node is accompanied by the output of
        `expand` (i.e. the corresponding child nodes).

    Notes
    -----
    Each node is yielded at most once; deduplication relies on node
    hashing/equality. The overall order is determined by the walk strategy
    and by the order in which `expand` returns children.
    """
    pending = deque(nodes)
    # popleft drains FIFO (breadth-first); pop drains LIFO (depth-first).
    pop_next: Callable[[], T] = pending.popleft if bfs else pending.pop

    visited: set[T] = set()
    children: Iterable[T] | None

    if return_children:
        while pending:
            node = pop_next()
            if node in visited:
                continue
            children = expand(node)
            yield node, children
            visited.add(node)
            if children:
                pending.extend(children)
    else:
        while pending:
            node = pop_next()
            if node in visited:
                continue
            # Yield before expanding, so a consumer driving the generator
            # sees the node before its children are computed.
            yield node
            visited.add(node)
            children = expand(node)
            if children:
                pending.extend(children)
def ancestors(
    graphs: Iterable[Variable],
    blockers: Iterable[Variable] | None = None,
) -> Generator[Variable, None, None]:
    r"""Yield the variables that contribute to those in `graphs` (inclusive), stopping at `blockers`.

    Parameters
    ----------
    graphs
        The variables whose ancestors are sought.
    blockers
        Variables whose own ancestors must not be explored; a reached
        blocker is still yielded itself.

    Yields
    ------
    `Variable`\s
        Each ancestor variable exactly once, in the order found by a
        right-recursive depth-first search started at the variables in
        `graphs`.
    """
    blocked = frozenset(blockers) if blockers else frozenset()
    emitted: set[Variable] = set()
    stack = list(graphs)
    while stack:
        var = stack.pop()
        if var in emitted:
            continue
        yield var
        emitted.add(var)
        owner = var.owner
        # Do not expand past blockers or ownerless (root) variables.
        if owner is not None and var not in blocked:
            stack.extend(owner.inputs)


# Backwards-compatible alias.
variable_ancestors = ancestors
def apply_ancestors(
    graphs: Iterable[Variable],
    blockers: Iterable[Variable] | None = None,
) -> Generator[Apply, None, None]:
    """Yield the `Apply` nodes that contribute to the variables in `graphs` (inclusive).

    Each `Apply` is yielded at most once, even when several of its outputs
    appear among the ancestors; ownerless (root) variables are skipped.
    """
    # Seeding with `None` makes ownerless variables fail the test below.
    emitted: set[Apply | None] = {None}
    for variable in ancestors(graphs, blockers):
        owner = variable.owner
        if owner not in emitted:
            emitted.add(owner)
            yield owner
def graph_inputs(
    graphs: Iterable[Variable], blockers: Iterable[Variable] | None = None
) -> Generator[Variable, None, None]:
    r"""Yield the inputs required to compute the given `Variable`\s.

    Parameters
    ----------
    graphs
        The variables whose inputs are sought.
    blockers
        Variables at which the upward search stops.

    Yields
    ------
    Input variables (those with no owner), in the order produced by the
    underlying `ancestors` traversal.
    """
    for var in ancestors(graphs, blockers):
        if var.owner is None:
            yield var
def explicit_graph_inputs( def explicit_graph_inputs(
...@@ -177,12 +218,12 @@ def explicit_graph_inputs( ...@@ -177,12 +218,12 @@ def explicit_graph_inputs(
from pytensor.compile.sharedvalue import SharedVariable from pytensor.compile.sharedvalue import SharedVariable
if isinstance(graph, Variable): if isinstance(graph, Variable):
graph = [graph] graph = (graph,)
return ( return (
v var
for v in graph_inputs(graph) for var in ancestors(graph)
if isinstance(v, Variable) and not isinstance(v, Constant | SharedVariable) if var.owner is None and not isinstance(var, Constant | SharedVariable)
) )
...@@ -191,6 +232,11 @@ def vars_between( ...@@ -191,6 +232,11 @@ def vars_between(
) -> Generator[Variable, None, None]: ) -> Generator[Variable, None, None]:
r"""Extract the `Variable`\s within the sub-graph between input and output nodes. r"""Extract the `Variable`\s within the sub-graph between input and output nodes.
Notes
-----
This function is like ancestors(outs, blockers=ins),
except it can also yield disconnected output variables from multi-output apply nodes.
Parameters Parameters
---------- ----------
ins ins
...@@ -207,20 +253,19 @@ def vars_between( ...@@ -207,20 +253,19 @@ def vars_between(
""" """
ins = set(ins) def expand(var: Variable, ins=frozenset(ins)) -> Iterable[Variable] | None:
if var.owner is not None and var not in ins:
def expand(r: Variable) -> Iterable[Variable] | None: return (*var.owner.inputs, *var.owner.outputs)
if r.owner and r not in ins:
return reversed(r.owner.inputs + r.owner.outputs)
return None return None
yield from cast(Generator[Variable, None, None], walk(outs, expand)) # With bfs = False, it iterates similarly to ancestors
yield from walk(outs, expand, bfs=False)
def orphans_between(
    ins: Iterable[Variable], outs: Iterable[Variable]
) -> Generator[Variable, None, None]:
    r"""Yield the root `Variable`\s of the sub-graph between `ins` and `outs` that are not in `ins`.

    Parameters
    ----------
    ins
        "Input" variables delimiting the sub-graph.
    outs
        "Output" variables delimiting the sub-graph.

    Yields
    ------
    Ownerless variables reachable from `outs` (without crossing `ins`) that
    are not themselves members of `ins`.
    """
    known_inputs = frozenset(ins)
    for var in vars_between(known_inputs, outs):
        if var.owner is None and var not in known_inputs:
            yield var
def applys_between(
    ins: Iterable[Variable], outs: Iterable[Variable]
) -> Generator[Apply, None, None]:
    r"""Extract the `Apply`\s contained in the sub-graph between `ins` and `outs`.

    Notes
    -----
    This is identical to ``apply_ancestors(outs, blockers=ins)``.

    Parameters
    ----------
    ins
        Input variables; the search never goes above them.
    outs
        Output variables from which the search starts.

    Yields
    ------
    The `Apply` nodes between `ins` and `outs`, each exactly once.
    """
    return apply_ancestors(outs, blockers=ins)
def apply_depends_on(apply: Apply, depends_on: Apply | Iterable[Apply]) -> bool:
    """Determine whether any node in `depends_on` is in the graph given by `apply`.

    Parameters
    ----------
    apply
        The node whose ancestry is searched (inclusive).
    depends_on
        A single `Apply` or an iterable of `Apply` nodes to look for.

    Returns
    -------
    bool
        ``True`` when `apply` itself, or any `Apply` ancestor of its inputs,
        is a member of `depends_on`.
    """
    if isinstance(depends_on, Apply):
        targets = frozenset((depends_on,))
    else:
        targets = frozenset(depends_on)
    if apply in targets:
        return True
    # Use a distinct loop variable: shadowing `apply` inside the generator
    # expression (as the previous version did) obscured that `apply.inputs`
    # refers to the parameter, not the loop variable.
    return any(ancestor in targets for ancestor in apply_ancestors(apply.inputs))
def variable_depends_on(
    variable: Variable, depends_on: Variable | Iterable[Variable]
) -> bool:
    r"""Determine whether any variable in `depends_on` is in the graph given by `variable`.

    Notes
    -----
    Dependency is interpreted at the variable level: a variable may depend on
    some output variables of a multi-output `Apply` node but not on others.

    Parameters
    ----------
    variable
        Variable whose ancestry is searched (inclusive).
    depends_on
        A single `Variable` or an iterable of `Variable`\s to look for.

    Returns
    -------
    bool
    """
    targets = (
        frozenset((depends_on,))
        if isinstance(depends_on, Variable)
        else frozenset(depends_on)
    )
    return any(anc in targets for anc in variable_ancestors([variable]))
def truncated_graph_inputs(
    outputs: Sequence[Variable],
    ancestors_to_include: Iterable[Variable] | None = None,
) -> list[Variable]:
    """Get the truncated graph inputs.

    Returns the deduplicated `outputs` when no extra ancestors are requested;
    otherwise returns the variables at which the graph of `outputs` can be
    truncated so that every requested ancestor (that is actually connected)
    appears among the returned inputs.

    Parameters
    ----------
    outputs : Iterable[Variable]
        Variables to get truncated inputs for.
    ancestors_to_include : Optional[Iterable[Variable]]
        Additional ancestors to assume, by default None.

    Returns
    -------
    list[Variable]
        The truncated inputs, in discovery order.
    """
    truncated_inputs: list[Variable] = []
    seen: set[Variable] = set()

    # Simple case: no additional ancestors to include — just deduplicate.
    if not ancestors_to_include:
        for variable in outputs:
            if variable not in seen:
                seen.add(variable)
                truncated_inputs.append(variable)
        return truncated_inputs

    # Blockers hold known independent variables and the requested ancestors.
    blockers: set[Variable] = set(ancestors_to_include)
    # Frozen copy for O(1) membership tests of the requested ancestors only.
    ancestors_to_include = blockers.copy()

    candidates = list(outputs)
    while candidates:
        variable = candidates.pop()
        if variable in seen:
            continue
        seen.add(variable)

        if variable in ancestors_to_include:
            # Requested ancestors present in the graph (not disconnected)
            # are truncated inputs themselves.
            truncated_inputs.append(variable)
            # If this ancestor still depends on other requested ancestors we
            # must keep searching above it.
            # FIXME(review, carried over from upstream): this special case may
            # be wrong — the other ancestors above are either redundant given
            # this variable, or reachable by another path; expanding here can
            # make unrelated variables look "independent".
            if variable_depends_on(variable, ancestors_to_include - {variable}):
                # owner can never be None for a dependent variable
                candidates.extend(
                    inp for inp in variable.owner.inputs if inp not in seen
                )
        else:
            # A regular variable to check.
            if variable_depends_on(variable, blockers):
                # Dependent: continue the search through its inputs.
                candidates.extend(variable.owner.inputs)
            else:
                # Independent: it is itself a truncated input.
                truncated_inputs.append(variable)
            # Either way the search can stop at this variable from now on:
            # 1. dependent  — its inputs were already expanded above;
            # 2. independent — it is a truncated input.
            blockers.add(variable)

    return truncated_inputs
def walk_toposort(
    graphs: Iterable[T],
    deps: Callable[[T], Iterable[T] | None],
) -> Generator[T, None, None]:
    """Topologically sort all nodes reachable from `graphs` through `deps`.

    Parameters
    ----------
    graphs
        Nodes from which to start the topological sort.
    deps : callable
        Function that takes a node and returns its dependencies (or ``None``).

    Yields
    ------
    Every reachable node, each after all of its dependencies.

    Raises
    ------
    ValueError
        If the graph contains cycles (only detected when fully consumed).

    Notes
    -----
    ``deps(i)`` should behave like a pure function (no funny business with
    internal state); its results are cached during the traversal. The output
    order is determined by the order of nodes returned by `deps`.
    """
    # Cache each node's dependency list as `walk` discovers it.
    deps_cache: dict[T, list[T]] = {}

    def cached_deps(obj, deps_cache=deps_cache):
        if obj in deps_cache:
            return deps_cache[obj]
        d = deps_cache[obj] = deps(obj) or []
        return d

    clients: dict[T, list[T]] = {}
    sources: deque[T] = deque()
    total_nodes = 0
    for node, children in walk(graphs, cached_deps, bfs=False, return_children=True):
        total_nodes += 1
        # `cached_deps` never returns None (it substitutes []), so `children`
        # is always iterable here.
        for child in children:  # type: ignore
            clients.setdefault(child, []).append(node)
        if not deps_cache[node]:
            # Nodes without dependencies seed the ready queue.
            sources.append(node)

    emitted: set[T] = set()
    while sources:
        node = sources.popleft()
        if node in emitted:
            continue
        yield node
        total_nodes -= 1
        emitted.add(node)
        # Every client (dependent) of this node has one fewer pending dep.
        for client in clients.get(node, []):
            remaining = deps_cache[client] = [
                dep for dep in deps_cache[client] if dep is not node
            ]
            if not remaining:
                sources.append(client)

    if total_nodes != 0:
        raise ValueError("graph contains cycles")
def general_toposort(
    outputs: Iterable[T],
    deps: Callable[[T], Iterable[T] | None],
    compute_deps_cache: Callable[[T], Iterable[T] | None] | None = None,
    deps_cache: dict[T, list[T]] | None = None,
    clients: dict[T, list[T]] | None = None,
) -> list[T]:
    """Perform a general topological sort (thin wrapper around `walk_toposort`).

    Kept for backwards compatibility; the `compute_deps_cache`, `deps_cache`
    and `clients` parameters are no longer supported and must be ``None``.

    Notes
    -----
    ``deps(i)`` should behave like a pure function (no funny business with
    internal state). The order of the returned list is determined by the
    order of nodes returned by the `deps` function.
    """
    # TODO: Deprecate me later
    for arg_name, arg_value in (
        ("compute_deps_cache", compute_deps_cache),
        ("deps_cache", deps_cache),
        ("clients", clients),
    ):
        if arg_value is not None:
            raise ValueError(f"{arg_name} is no longer supported")
    return list(walk_toposort(outputs, deps))
def toposort(
    graphs: Iterable[Variable],
    blockers: Iterable[Variable] | None = None,
) -> Generator[Apply, None, None]:
    """Topologically sort the `Apply` nodes between `graphs` (outputs) and `blockers` (inputs).

    This is a streamlined alternative to the ordering-aware toposort for the
    common case where no additional ordering constraints are needed.

    NOTE(review): like its predecessor, this traversal assumes an acyclic
    graph; it does not detect cycles.
    """
    # Blocker variables can be seeded as "computed" because only Apply nodes
    # are ever yielded.
    computed: set[Variable] = set(blockers) if blockers is not None else set()
    stack = list(graphs)
    while stack:
        var = stack.pop()
        if var in computed:
            continue
        apply = var.owner
        if apply is None:
            continue
        pending = [
            inp
            for inp in apply.inputs
            if inp not in computed and inp.owner is not None
        ]
        if pending:
            # Revisit `var` once its inputs have been computed.
            stack.append(var)
            stack.extend(pending)
        else:
            yield apply
            computed.update(apply.outputs)
_compute_deps_cache = _compute_deps_cache_
def toposort_with_orderings(
    graphs: Iterable[Variable],
    *,
    blockers: Iterable[Variable] | None = None,
    orderings: dict[Apply, list[Apply]] | None = None,
) -> Generator[Apply, None, None]:
    """Topologically sort the nodes between `blockers` (inputs) and `graphs` (outputs) with extra orderings.

    Extra orderings can be used to force a sorting between variables that are
    not naturally related in the graph. This is used by inplace rewrites to
    ensure a variable is only destroyed after all its other uses; those other
    uses show up as dependencies of the destroying node in `orderings`.

    Parameters
    ----------
    graphs
        Output variables from which the sort starts.
    blockers
        Variables at which the upward traversal stops.
    orderings : dict
        Maps an `Apply` node to the list of `Apply` nodes that must be
        scheduled before it, in addition to its graph dependencies.
    """
    if not orderings:
        # Fast path: no extra constraints.
        yield from toposort(graphs, blockers=blockers)
        return

    # The dependency callbacks bind their captured values as defaults, which
    # is faster than free-variable lookups in the hot traversal loop.
    if blockers:

        def compute_deps(obj, blocker_set=frozenset(blockers), orderings=orderings):
            if obj in blocker_set:
                return None
            if isinstance(obj, Apply):
                return [*obj.inputs, *orderings.get(obj, [])]
            owner = obj.owner
            if owner is None:
                return orderings.get(obj, [])
            return [owner, *orderings.get(owner, [])]

    else:
        # mypy dislikes conditional functions with differing signatures.
        def compute_deps(obj, orderings=orderings):  # type: ignore[misc]
            if isinstance(obj, Apply):
                return [*obj.inputs, *orderings.get(obj, [])]
            owner = obj.owner
            if owner is None:
                return orderings.get(obj, [])
            return [owner, *orderings.get(owner, [])]

    # The walk yields both Apply nodes and Variables; only Apply nodes are
    # part of the result.
    for node in walk_toposort(graphs, deps=compute_deps):
        if isinstance(node, Apply):  # type: ignore
            yield node
def io_toposort( def io_toposort(
...@@ -594,7 +714,11 @@ def io_toposort( ...@@ -594,7 +714,11 @@ def io_toposort(
orderings: dict[Apply, list[Apply]] | None = None, orderings: dict[Apply, list[Apply]] | None = None,
clients: dict[Variable, list[Variable]] | None = None, clients: dict[Variable, list[Variable]] | None = None,
) -> list[Apply]: ) -> list[Apply]:
"""Perform topological sort from input and output nodes. """Perform topological of nodes between input and output variables.
Notes
-----
This is just a wrapper around `toposort_with_extra_orderings` for backwards compatibility
Parameters Parameters
---------- ----------
...@@ -604,96 +728,16 @@ def io_toposort( ...@@ -604,96 +728,16 @@ def io_toposort(
Graph outputs. Graph outputs.
orderings : dict orderings : dict
Keys are `Apply` instances, values are lists of `Apply` instances. Keys are `Apply` instances, values are lists of `Apply` instances.
clients : dict
If provided, it will be filled with mappings of nodes-to-clients for
each node in the subgraph that is sorted.
""" """
if not orderings and clients is None: # ordering can be None or empty dict # TODO: Deprecate me later
# Specialized function that is faster when more then ~10 nodes if clients is not None:
# when no ordering. raise ValueError("clients is no longer supported")
# Do a new stack implementation with the vm algo.
# This will change the order returned.
computed = set(inputs)
todo = [o.owner for o in reversed(outputs) if o.owner]
order = []
while todo:
cur = todo.pop()
if all(out in computed for out in cur.outputs):
continue
if all(i in computed or i.owner is None for i in cur.inputs):
computed.update(cur.outputs)
order.append(cur)
else:
todo.append(cur)
todo.extend(
i.owner for i in cur.inputs if (i.owner and i not in computed)
)
return order
iset = set(inputs)
if not orderings: # ordering can be None or empty dict return list(toposort_with_orderings(outputs, blockers=inputs, orderings=orderings))
# Specialized function that is faster when no ordering.
# Also include the cache in the function itself for speed up.
deps_cache: dict = {}
def compute_deps_cache(obj):
if obj in deps_cache:
return deps_cache[obj]
rval = []
if obj not in iset:
if isinstance(obj, Variable):
if obj.owner:
rval = [obj.owner]
elif isinstance(obj, Apply):
rval = list(obj.inputs)
if rval:
deps_cache[obj] = list(rval)
else:
deps_cache[obj] = rval
else:
deps_cache[obj] = rval
return rval
topo = general_toposort(
outputs,
deps=None,
compute_deps_cache=compute_deps_cache,
deps_cache=deps_cache,
clients=clients,
)
else:
# the inputs are used only here in the function that decides what
# 'predecessors' to explore
def compute_deps(obj):
rval = []
if obj not in iset:
if isinstance(obj, Variable):
if obj.owner:
rval = [obj.owner]
elif isinstance(obj, Apply):
rval = list(obj.inputs)
rval.extend(orderings.get(obj, []))
else:
assert not orderings.get(obj, None)
return rval
topo = general_toposort(
outputs,
deps=compute_deps,
compute_deps_cache=None,
deps_cache=None,
clients=clients,
)
return [o for o in topo if isinstance(o, Apply)]
def get_var_by_name( def get_var_by_name(
graphs: Iterable[Variable], target_var_id: str, ids: str = "CHAR" graphs: Iterable[Variable], target_var_id: str
) -> tuple[Variable, ...]: ) -> tuple[Variable, ...]:
r"""Get variables in a graph using their names. r"""Get variables in a graph using their names.
...@@ -712,21 +756,18 @@ def get_var_by_name( ...@@ -712,21 +756,18 @@ def get_var_by_name(
""" """
from pytensor.graph.op import HasInnerGraph from pytensor.graph.op import HasInnerGraph
def expand(r) -> list[Variable] | None: def expand(r: Variable) -> list[Variable] | None:
if not r.owner: if (apply := r.owner) is not None:
if isinstance(apply.op, HasInnerGraph):
return [*apply.inputs, *apply.op.inner_outputs]
else:
# Mypy doesn't know these will never be None
return apply.inputs # type: ignore
else:
return None return None
res = list(r.owner.inputs) return tuple(
var
if isinstance(r.owner.op, HasInnerGraph): for var in walk(graphs, expand)
res.extend(r.owner.op.inner_outputs) if (target_var_id == var.name or target_var_id == var.auto_name)
)
return res
results: tuple[Variable, ...] = ()
for var in walk(graphs, expand, False):
var = cast(Variable, var)
if target_var_id == var.name or target_var_id == var.auto_name:
results += (var,)
return results
...@@ -21,7 +21,7 @@ from pytensor.configdefaults import config ...@@ -21,7 +21,7 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import HasInnerGraph, Op, StorageMapType from pytensor.graph.op import HasInnerGraph, Op, StorageMapType
from pytensor.graph.traversal import graph_inputs, io_toposort from pytensor.graph.traversal import graph_inputs, toposort
from pytensor.graph.utils import Scratchpad from pytensor.graph.utils import Scratchpad
...@@ -1103,7 +1103,7 @@ class PPrinter(Printer): ...@@ -1103,7 +1103,7 @@ class PPrinter(Printer):
) )
inv_updates = {b: a for (a, b) in updates.items()} inv_updates = {b: a for (a, b) in updates.items()}
i = 1 i = 1
for node in io_toposort([*inputs, *updates], [*outputs, *updates.values()]): for node in toposort([*outputs, *updates.values()], [*inputs, *updates]):
for output in node.outputs: for output in node.outputs:
if output in inv_updates: if output in inv_updates:
name = str(inv_updates[output]) name = str(inv_updates[output])
......
...@@ -13,7 +13,6 @@ from pytensor import tensor as pt ...@@ -13,7 +13,6 @@ from pytensor import tensor as pt
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.compile.function.types import deep_copy_op from pytensor.compile.function.types import deep_copy_op
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import ancestors, graph_inputs
from pytensor.graph.basic import ( from pytensor.graph.basic import (
Apply, Apply,
Constant, Constant,
...@@ -35,7 +34,11 @@ from pytensor.graph.rewriting.basic import ( ...@@ -35,7 +34,11 @@ from pytensor.graph.rewriting.basic import (
) )
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
from pytensor.graph.rewriting.utils import get_clients_at_depth from pytensor.graph.rewriting.utils import get_clients_at_depth
from pytensor.graph.traversal import apply_depends_on, io_toposort from pytensor.graph.traversal import (
ancestors,
apply_depends_on,
graph_inputs,
)
from pytensor.graph.type import HasShape from pytensor.graph.type import HasShape
from pytensor.graph.utils import InconsistencyError from pytensor.graph.utils import InconsistencyError
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
...@@ -220,7 +223,7 @@ def scan_push_out_non_seq(fgraph, node): ...@@ -220,7 +223,7 @@ def scan_push_out_non_seq(fgraph, node):
""" """
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
local_fgraph_topo = io_toposort(node_inputs, node_outputs) local_fgraph_topo = node.op.fgraph.toposort()
local_fgraph_outs_set = set(node_outputs) local_fgraph_outs_set = set(node_outputs)
local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)} local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
...@@ -427,7 +430,7 @@ def scan_push_out_seq(fgraph, node): ...@@ -427,7 +430,7 @@ def scan_push_out_seq(fgraph, node):
""" """
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
local_fgraph_topo = io_toposort(node_inputs, node_outputs) local_fgraph_topo = node.op.fgraph.toposort()
local_fgraph_outs_set = set(node_outputs) local_fgraph_outs_set = set(node_outputs)
local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)} local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
...@@ -840,22 +843,42 @@ def scan_push_out_add(fgraph, node): ...@@ -840,22 +843,42 @@ def scan_push_out_add(fgraph, node):
# apply_ancestors(args.inner_outputs) # apply_ancestors(args.inner_outputs)
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of add_of_dot_nodes = [
# use n
args = ScanArgs( for n in op.fgraph.apply_nodes
node.inputs, node.outputs, op.inner_inputs, op.inner_outputs, op.info if
(
# We have an Add
isinstance(n.op, Elemwise)
and isinstance(n.op.scalar_op, ps.Add)
and any(
(
# With a Dot input that's only used in the Add
n_inp.owner is not None
and isinstance(n_inp.owner.op, Dot)
and len(op.fgraph.clients[n_inp]) == 1
)
for n_inp in n.inputs
) )
)
]
if not add_of_dot_nodes:
return False
clients = {} # Use `ScanArgs` to parse the inputs and outputs of scan for ease of access
local_fgraph_topo = io_toposort( args = ScanArgs(
args.inner_inputs, args.inner_outputs, clients=clients node.inputs,
node.outputs,
op.inner_inputs,
op.inner_outputs,
op.info,
clone=False,
) )
for nd in local_fgraph_topo: for nd in add_of_dot_nodes:
if ( if (
isinstance(nd.op, Elemwise) nd.out in args.inner_out_sit_sot
and isinstance(nd.op.scalar_op, ps.Add)
and nd.out in args.inner_out_sit_sot
# FIXME: This function doesn't handle `sitsot_out[1:][-1]` pattern # FIXME: This function doesn't handle `sitsot_out[1:][-1]` pattern
and inner_sitsot_only_last_step_used(fgraph, nd.out, args) and inner_sitsot_only_last_step_used(fgraph, nd.out, args)
): ):
...@@ -863,27 +886,17 @@ def scan_push_out_add(fgraph, node): ...@@ -863,27 +886,17 @@ def scan_push_out_add(fgraph, node):
# the add from a previous iteration of the inner function # the add from a previous iteration of the inner function
sitsot_idx = args.inner_out_sit_sot.index(nd.out) sitsot_idx = args.inner_out_sit_sot.index(nd.out)
if args.inner_in_sit_sot[sitsot_idx] in nd.inputs: if args.inner_in_sit_sot[sitsot_idx] in nd.inputs:
# Ensure that the other input to the add is a dot product
# between 2 matrices which will become a tensor3 and a
# matrix if pushed outside of the scan. Also make sure
# that the output of the Dot is ONLY used by the 'add'
# otherwise doing a Dot in the outer graph will only
# duplicate computation.
sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[sitsot_idx]) sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[sitsot_idx])
# 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0 # 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
dot_in_idx = 1 - sitsot_in_idx dot_in_idx = 1 - sitsot_in_idx
dot_input = nd.inputs[dot_in_idx] dot_input = nd.inputs[dot_in_idx]
assert dot_input.owner is not None and isinstance(
dot_input.owner.op, Dot
)
if ( if (
dot_input.owner is not None get_outer_ndim(dot_input.owner.inputs[0], args) == 3
and isinstance(dot_input.owner.op, Dot)
and len(clients[dot_input]) == 1
and dot_input.owner.inputs[0].ndim == 2
and dot_input.owner.inputs[1].ndim == 2
and get_outer_ndim(dot_input.owner.inputs[0], args) == 3
and get_outer_ndim(dot_input.owner.inputs[1], args) == 3 and get_outer_ndim(dot_input.owner.inputs[1], args) == 3
): ):
# The optimization can be be applied in this case. # The optimization can be be applied in this case.
......
...@@ -59,7 +59,7 @@ import time ...@@ -59,7 +59,7 @@ import time
import numpy as np import numpy as np
from pytensor.graph.traversal import io_toposort from pytensor.graph.traversal import toposort
from pytensor.tensor.rewriting.basic import register_specialize from pytensor.tensor.rewriting.basic import register_specialize
...@@ -460,6 +460,9 @@ class GemmOptimizer(GraphRewriter): ...@@ -460,6 +460,9 @@ class GemmOptimizer(GraphRewriter):
callbacks_before = fgraph.execute_callbacks_times.copy() callbacks_before = fgraph.execute_callbacks_times.copy()
callback_before = fgraph.execute_callbacks_time callback_before = fgraph.execute_callbacks_time
nodelist = list(toposort(fgraph.outputs))
nodelist.reverse()
def on_import(new_node): def on_import(new_node):
if new_node is not node: if new_node is not node:
nodelist.append(new_node) nodelist.append(new_node)
...@@ -471,10 +474,8 @@ class GemmOptimizer(GraphRewriter): ...@@ -471,10 +474,8 @@ class GemmOptimizer(GraphRewriter):
while did_something: while did_something:
nb_iter += 1 nb_iter += 1
t0 = time.perf_counter() t0 = time.perf_counter()
nodelist = io_toposort(fgraph.inputs, fgraph.outputs)
time_toposort += time.perf_counter() - t0 time_toposort += time.perf_counter() - t0
did_something = False did_something = False
nodelist.reverse()
for node in nodelist: for node in nodelist:
if not ( if not (
isinstance(node.op, Elemwise) isinstance(node.op, Elemwise)
......
...@@ -50,22 +50,13 @@ class TestProfiling: ...@@ -50,22 +50,13 @@ class TestProfiling:
the_string = buf.getvalue() the_string = buf.getvalue()
lines1 = [l for l in the_string.split("\n") if "Max if linker" in l] lines1 = [l for l in the_string.split("\n") if "Max if linker" in l]
lines2 = [l for l in the_string.split("\n") if "Minimum peak" in l] lines2 = [l for l in the_string.split("\n") if "Minimum peak" in l]
if config.device == "cpu": # NODE: The specific numbers can change for distinct (but correct) toposort orderings
assert "CPU: 4112KB (4104KB)" in the_string, (lines1, lines2) # Update the test values if a different algorithm is used
assert "CPU: 8204KB (8196KB)" in the_string, (lines1, lines2) assert "CPU: 4112KB (4112KB)" in the_string, (lines1, lines2)
assert "CPU: 8204KB (8204KB)" in the_string, (lines1, lines2)
assert "CPU: 8208KB" in the_string, (lines1, lines2) assert "CPU: 8208KB" in the_string, (lines1, lines2)
assert ( assert (
"Minimum peak from all valid apply node order is 4104KB" "Minimum peak from all valid apply node order is 4104KB" in the_string
in the_string
), (lines1, lines2)
else:
assert "CPU: 16KB (16KB)" in the_string, (lines1, lines2)
assert "GPU: 8204KB (8204KB)" in the_string, (lines1, lines2)
assert "GPU: 12300KB (12300KB)" in the_string, (lines1, lines2)
assert "GPU: 8212KB" in the_string, (lines1, lines2)
assert (
"Minimum peak from all valid apply node order is 4116KB"
in the_string
), (lines1, lines2) ), (lines1, lines2)
finally: finally:
......
...@@ -160,7 +160,7 @@ def test_KanrenRelationSub_dot(): ...@@ -160,7 +160,7 @@ def test_KanrenRelationSub_dot():
assert expr_opt.owner.op == pt.add assert expr_opt.owner.op == pt.add
assert isinstance(expr_opt.owner.inputs[0].owner.op, Dot) assert isinstance(expr_opt.owner.inputs[0].owner.op, Dot)
assert fgraph_opt.inputs[0] is A_pt assert fgraph_opt.inputs[-1] is A_pt
assert expr_opt.owner.inputs[0].owner.inputs[0].name == "A" assert expr_opt.owner.inputs[0].owner.inputs[0].name == "A"
assert expr_opt.owner.inputs[1].owner.op == pt.add assert expr_opt.owner.inputs[1].owner.op == pt.add
assert isinstance(expr_opt.owner.inputs[1].owner.inputs[0].owner.op, Dot) assert isinstance(expr_opt.owner.inputs[1].owner.inputs[0].owner.op, Dot)
......
...@@ -56,7 +56,7 @@ class TestFunctionGraph: ...@@ -56,7 +56,7 @@ class TestFunctionGraph:
with pytest.raises(TypeError, match="'Variable' object is not iterable"): with pytest.raises(TypeError, match="'Variable' object is not iterable"):
FunctionGraph(var1, [var2]) FunctionGraph(var1, [var2])
with pytest.raises(TypeError, match="'Variable' object is not reversible"): with pytest.raises(TypeError, match="'Variable' object is not iterable"):
FunctionGraph([var1], var2) FunctionGraph([var1], var2)
with pytest.raises( with pytest.raises(
......
...@@ -28,7 +28,7 @@ class TestCloneReplace: ...@@ -28,7 +28,7 @@ class TestCloneReplace:
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, rebuild_strict=True, copy_inputs_over=True) f2 = clone_replace(f1, replace=None, rebuild_strict=True, copy_inputs_over=True)
f2_inp = graph_inputs([f2]) f2_inp = tuple(graph_inputs([f2]))
assert z in f2_inp assert z in f2_inp
assert x in f2_inp assert x in f2_inp
...@@ -65,7 +65,7 @@ class TestCloneReplace: ...@@ -65,7 +65,7 @@ class TestCloneReplace:
f2 = clone_replace( f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=True, copy_inputs_over=True f1, replace={y: y2}, rebuild_strict=True, copy_inputs_over=True
) )
f2_inp = graph_inputs([f2]) f2_inp = tuple(graph_inputs([f2]))
assert z in f2_inp assert z in f2_inp
assert x in f2_inp assert x in f2_inp
assert y2 in f2_inp assert y2 in f2_inp
...@@ -83,7 +83,7 @@ class TestCloneReplace: ...@@ -83,7 +83,7 @@ class TestCloneReplace:
f2 = clone_replace( f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=False, copy_inputs_over=True f1, replace={y: y2}, rebuild_strict=False, copy_inputs_over=True
) )
f2_inp = graph_inputs([f2]) f2_inp = tuple(graph_inputs([f2]))
assert z in f2_inp assert z in f2_inp
assert x in f2_inp assert x in f2_inp
assert y2 in f2_inp assert y2 in f2_inp
......
...@@ -4,13 +4,17 @@ from pytensor import Variable, shared ...@@ -4,13 +4,17 @@ from pytensor import Variable, shared
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.graph import Apply, ancestors, graph_inputs from pytensor.graph import Apply, ancestors, graph_inputs
from pytensor.graph.traversal import ( from pytensor.graph.traversal import (
apply_ancestors,
apply_depends_on, apply_depends_on,
explicit_graph_inputs, explicit_graph_inputs,
general_toposort, general_toposort,
get_var_by_name, get_var_by_name,
io_toposort, io_toposort,
orphans_between, orphans_between,
toposort,
toposort_with_orderings,
truncated_graph_inputs, truncated_graph_inputs,
variable_ancestors,
variable_depends_on, variable_depends_on,
vars_between, vars_between,
walk, walk,
...@@ -36,23 +40,17 @@ class TestToposort: ...@@ -36,23 +40,17 @@ class TestToposort:
o2 = MyOp(o, r5) o2 = MyOp(o, r5)
o2.name = "o2" o2.name = "o2"
clients = {} res = general_toposort([o2], self.prenode)
res = general_toposort([o2], self.prenode, clients=clients)
assert clients == {
o2.owner: [o2],
o: [o2.owner],
r5: [o2.owner],
o.owner: [o],
r1: [o.owner],
r2: [o.owner],
}
assert res == [r5, r2, r1, o.owner, o, o2.owner, o2] assert res == [r5, r2, r1, o.owner, o, o2.owner, o2]
with pytest.raises(ValueError): def circular_dependency(obj):
general_toposort( if obj is o:
[o2], self.prenode, compute_deps_cache=lambda x: None, deps_cache=None # o2 depends on o, so o cannot depend on o2
) return [o2, *self.prenode(obj)]
return self.prenode(obj)
with pytest.raises(ValueError, match="graph contains cycles"):
general_toposort([o2], circular_dependency)
res = io_toposort([r5], [o2]) res = io_toposort([r5], [o2])
assert res == [o.owner, o2.owner] assert res == [o.owner, o2.owner]
...@@ -181,16 +179,16 @@ def test_ancestors(): ...@@ -181,16 +179,16 @@ def test_ancestors():
res = ancestors([o2], blockers=None) res = ancestors([o2], blockers=None)
res_list = list(res) res_list = list(res)
assert res_list == [o2, r3, o1, r1, r2] assert res_list == [o2, o1, r2, r1, r3]
res = ancestors([o2], blockers=None) res = ancestors([o2], blockers=None)
assert r3 in res assert o1 in res
res_list = list(res) res_list = list(res)
assert res_list == [o1, r1, r2] assert res_list == [r2, r1, r3]
res = ancestors([o2], blockers=[o1]) res = ancestors([o2], blockers=[o1])
res_list = list(res) res_list = list(res)
assert res_list == [o2, r3, o1] assert res_list == [o2, o1, r3]
def test_graph_inputs(): def test_graph_inputs():
...@@ -202,7 +200,7 @@ def test_graph_inputs(): ...@@ -202,7 +200,7 @@ def test_graph_inputs():
res = graph_inputs([o2], blockers=None) res = graph_inputs([o2], blockers=None)
res_list = list(res) res_list = list(res)
assert res_list == [r3, r1, r2] assert res_list == [r2, r1, r3]
def test_explicit_graph_inputs(): def test_explicit_graph_inputs():
...@@ -231,7 +229,7 @@ def test_variables_and_orphans(): ...@@ -231,7 +229,7 @@ def test_variables_and_orphans():
vars_res_list = list(vars_res) vars_res_list = list(vars_res)
orphans_res_list = list(orphans_res) orphans_res_list = list(orphans_res)
assert vars_res_list == [o2, o1, r3, r2, r1] assert vars_res_list == [o2, o1, r2, r1, r3]
assert orphans_res_list == [r3] assert orphans_res_list == [r3]
...@@ -408,3 +406,37 @@ def test_get_var_by_name(): ...@@ -408,3 +406,37 @@ def test_get_var_by_name():
exp_res = igo.fgraph.outputs[0] exp_res = igo.fgraph.outputs[0]
assert res == exp_res assert res == exp_res
@pytest.mark.parametrize(
"func",
[
lambda x: all(variable_ancestors([x])),
lambda x: all(variable_ancestors([x], blockers=[x.clone()])),
lambda x: all(apply_ancestors([x])),
lambda x: all(apply_ancestors([x], blockers=[x.clone()])),
lambda x: all(toposort([x])),
lambda x: all(toposort([x], blockers=[x.clone()])),
lambda x: all(toposort_with_orderings([x], orderings={x: []})),
lambda x: all(
toposort_with_orderings([x], blockers=[x.clone()], orderings={x: []})
),
],
ids=[
"variable_ancestors",
"variable_ancestors_with_blockers",
"apply_ancestors",
"apply_ancestors_with_blockers)",
"toposort",
"toposort_with_blockers",
"toposort_with_orderings",
"toposort_with_orderings_and_blockers",
],
)
def test_traversal_benchmark(func, benchmark):
r1 = MyVariable(1)
out = r1
for i in range(50):
out = MyOp(out, out)
benchmark(func, out)
from itertools import chain
import numpy as np import numpy as np
import pytest import pytest
...@@ -490,6 +492,7 @@ def test_inplace_taps(n_steps_constant): ...@@ -490,6 +492,7 @@ def test_inplace_taps(n_steps_constant):
if isinstance(node.op, Scan) if isinstance(node.op, Scan)
] ]
# Collect inner inputs we expect to be destroyed by the step function
# Scan reorders inputs internally, so we need to check its ordering # Scan reorders inputs internally, so we need to check its ordering
inner_inps = scan_op.fgraph.inputs inner_inps = scan_op.fgraph.inputs
mit_sot_inps = scan_op.inner_mitsot(inner_inps) mit_sot_inps = scan_op.inner_mitsot(inner_inps)
...@@ -501,28 +504,22 @@ def test_inplace_taps(n_steps_constant): ...@@ -501,28 +504,22 @@ def test_inplace_taps(n_steps_constant):
] ]
[sit_sot_inp] = scan_op.inner_sitsot(inner_inps) [sit_sot_inp] = scan_op.inner_sitsot(inner_inps)
inner_outs = scan_op.fgraph.outputs destroyed_inputs = []
mit_sot_outs = scan_op.inner_mitsot_outs(inner_outs) for inner_out in scan_op.fgraph.outputs:
[sit_sot_out] = scan_op.inner_sitsot_outs(inner_outs) node = inner_out.owner
[nit_sot_out] = scan_op.inner_nitsot_outs(inner_outs) dm = node.op.destroy_map
if dm:
destroyed_inputs.extend(
node.inputs[idx] for idx in chain.from_iterable(dm.values())
)
if n_steps_constant: if n_steps_constant:
assert mit_sot_outs[0].owner.op.destroy_map == { assert len(destroyed_inputs) == 3
0: [mit_sot_outs[0].owner.inputs.index(oldest_mit_sot_inps[0])] assert set(destroyed_inputs) == {*oldest_mit_sot_inps, sit_sot_inp}
}
assert mit_sot_outs[1].owner.op.destroy_map == {
0: [mit_sot_outs[1].owner.inputs.index(oldest_mit_sot_inps[1])]
}
assert sit_sot_out.owner.op.destroy_map == {
0: [sit_sot_out.owner.inputs.index(sit_sot_inp)]
}
else: else:
# This is not a feature, but a current limitation # This is not a feature, but a current limitation
# https://github.com/pymc-devs/pytensor/issues/1283 # https://github.com/pymc-devs/pytensor/issues/1283
assert mit_sot_outs[0].owner.op.destroy_map == {} assert not destroyed_inputs
assert mit_sot_outs[1].owner.op.destroy_map == {}
assert sit_sot_out.owner.op.destroy_map == {}
assert nit_sot_out.owner.op.destroy_map == {}
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -1170,8 +1170,8 @@ class TestHyp2F1Grad: ...@@ -1170,8 +1170,8 @@ class TestHyp2F1Grad:
if isinstance(node.op, Elemwise) if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, ScalarLoop) and isinstance(node.op.scalar_op, ScalarLoop)
] ]
assert scalar_loop_op1.nin == 10 + 3 * 2 # wrt=[0, 1] assert scalar_loop_op1.nin == 10 + 3 * 1 # wrt=[2]
assert scalar_loop_op2.nin == 10 + 3 * 1 # wrt=[2] assert scalar_loop_op2.nin == 10 + 3 * 2 # wrt=[0, 1]
else: else:
[scalar_loop_op] = [ [scalar_loop_op] = [
node.op.scalar_op node.op.scalar_op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论