提交 066307f0 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Faster graph traversal functions

* Avoid reversing inputs as we traverse graph * Simplify io_toposort without ordering (and refactor into its own function) * Removes client side-effect on previous toposort functions * Remove duplicated logic across methods
上级 f1a2ac66
...@@ -5,7 +5,6 @@ import warnings ...@@ -5,7 +5,6 @@ import warnings
from collections.abc import ( from collections.abc import (
Hashable, Hashable,
Iterable, Iterable,
Reversible,
Sequence, Sequence,
) )
from copy import copy from copy import copy
...@@ -961,7 +960,7 @@ def clone_node_and_cache( ...@@ -961,7 +960,7 @@ def clone_node_and_cache(
def clone_get_equiv( def clone_get_equiv(
inputs: Iterable[Variable], inputs: Iterable[Variable],
outputs: Reversible[Variable], outputs: Iterable[Variable],
copy_inputs: bool = True, copy_inputs: bool = True,
copy_orphans: bool = True, copy_orphans: bool = True,
memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]] memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]
...@@ -1002,7 +1001,7 @@ def clone_get_equiv( ...@@ -1002,7 +1001,7 @@ def clone_get_equiv(
Keywords passed to `Apply.clone_with_new_inputs`. Keywords passed to `Apply.clone_with_new_inputs`.
""" """
from pytensor.graph.traversal import io_toposort from pytensor.graph.traversal import toposort
if memo is None: if memo is None:
memo = {} memo = {}
...@@ -1018,7 +1017,7 @@ def clone_get_equiv( ...@@ -1018,7 +1017,7 @@ def clone_get_equiv(
memo.setdefault(input, input) memo.setdefault(input, input)
# go through the inputs -> outputs graph cloning as we go # go through the inputs -> outputs graph cloning as we go
for apply in io_toposort(inputs, outputs): for apply in toposort(outputs, blockers=inputs):
for input in apply.inputs: for input in apply.inputs:
if input not in memo: if input not in memo:
if not isinstance(input, Constant) and copy_orphans: if not isinstance(input, Constant) and copy_orphans:
......
...@@ -10,7 +10,7 @@ import numpy as np ...@@ -10,7 +10,7 @@ import numpy as np
import pytensor import pytensor
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Variable from pytensor.graph.basic import Variable
from pytensor.graph.traversal import io_toposort from pytensor.graph.traversal import toposort
from pytensor.graph.utils import InconsistencyError from pytensor.graph.utils import InconsistencyError
...@@ -340,11 +340,11 @@ class Feature: ...@@ -340,11 +340,11 @@ class Feature:
class Bookkeeper(Feature): class Bookkeeper(Feature):
def on_attach(self, fgraph): def on_attach(self, fgraph):
for node in io_toposort(fgraph.inputs, fgraph.outputs): for node in toposort(fgraph.outputs):
self.on_import(fgraph, node, "on_attach") self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph): def on_detach(self, fgraph):
for node in io_toposort(fgraph.inputs, fgraph.outputs): for node in toposort(fgraph.outputs):
self.on_prune(fgraph, node, "Bookkeeper.detach") self.on_prune(fgraph, node, "Bookkeeper.detach")
......
...@@ -19,7 +19,8 @@ from pytensor.graph.op import Op ...@@ -19,7 +19,8 @@ from pytensor.graph.op import Op
from pytensor.graph.traversal import ( from pytensor.graph.traversal import (
applys_between, applys_between,
graph_inputs, graph_inputs,
io_toposort, toposort,
toposort_with_orderings,
vars_between, vars_between,
) )
from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError
...@@ -366,7 +367,7 @@ class FunctionGraph(MetaObject): ...@@ -366,7 +367,7 @@ class FunctionGraph(MetaObject):
# new nodes, so we use all variables we know of as if they were the # new nodes, so we use all variables we know of as if they were the
# input set. (The functions in the graph module only use the input set # input set. (The functions in the graph module only use the input set
# to know where to stop going down.) # to know where to stop going down.)
new_nodes = io_toposort(self.variables, apply_node.outputs) new_nodes = tuple(toposort(apply_node.outputs, blockers=self.variables))
if check: if check:
for node in new_nodes: for node in new_nodes:
...@@ -759,7 +760,7 @@ class FunctionGraph(MetaObject): ...@@ -759,7 +760,7 @@ class FunctionGraph(MetaObject):
# No sorting is necessary # No sorting is necessary
return list(self.apply_nodes) return list(self.apply_nodes)
return io_toposort(self.inputs, self.outputs, self.orderings()) return list(toposort_with_orderings(self.outputs, orderings=self.orderings()))
def orderings(self) -> dict[Apply, list[Apply]]: def orderings(self) -> dict[Apply, list[Apply]]:
"""Return a map of node to node evaluation dependencies. """Return a map of node to node evaluation dependencies.
......
...@@ -10,7 +10,10 @@ from pytensor.graph.basic import ( ...@@ -10,7 +10,10 @@ from pytensor.graph.basic import (
) )
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.traversal import io_toposort, truncated_graph_inputs from pytensor.graph.traversal import (
toposort,
truncated_graph_inputs,
)
ReplaceTypes = Iterable[tuple[Variable, Variable]] | dict[Variable, Variable] ReplaceTypes = Iterable[tuple[Variable, Variable]] | dict[Variable, Variable]
...@@ -295,7 +298,7 @@ def vectorize_graph( ...@@ -295,7 +298,7 @@ def vectorize_graph(
new_inputs = [replace.get(inp, inp) for inp in inputs] new_inputs = [replace.get(inp, inp) for inp in inputs]
vect_vars = dict(zip(inputs, new_inputs, strict=True)) vect_vars = dict(zip(inputs, new_inputs, strict=True))
for node in io_toposort(inputs, seq_outputs): for node in toposort(seq_outputs, blockers=inputs):
vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs] vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
vect_node = vectorize_node(node, *vect_inputs) vect_node = vectorize_node(node, *vect_inputs)
for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True): for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True):
......
...@@ -27,7 +27,7 @@ from pytensor.graph.features import AlreadyThere, Feature ...@@ -27,7 +27,7 @@ from pytensor.graph.features import AlreadyThere, Feature
from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars
from pytensor.graph.traversal import applys_between, io_toposort, vars_between from pytensor.graph.traversal import applys_between, toposort, vars_between
from pytensor.graph.utils import AssocList, InconsistencyError from pytensor.graph.utils import AssocList, InconsistencyError
from pytensor.misc.ordered_set import OrderedSet from pytensor.misc.ordered_set import OrderedSet
from pytensor.utils import flatten from pytensor.utils import flatten
...@@ -2010,7 +2010,7 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2010,7 +2010,7 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
callback_before = fgraph.execute_callbacks_time callback_before = fgraph.execute_callbacks_time
nb_nodes_start = len(fgraph.apply_nodes) nb_nodes_start = len(fgraph.apply_nodes)
t0 = time.perf_counter() t0 = time.perf_counter()
q = deque(io_toposort(fgraph.inputs, start_from)) q = deque(toposort(start_from))
io_t = time.perf_counter() - t0 io_t = time.perf_counter() - t0
def importer(node): def importer(node):
...@@ -2341,7 +2341,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2341,7 +2341,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
changed |= apply_cleanup(iter_cleanup_sub_profs) changed |= apply_cleanup(iter_cleanup_sub_profs)
topo_t0 = time.perf_counter() topo_t0 = time.perf_counter()
q = deque(io_toposort(fgraph.inputs, start_from)) q = deque(toposort(start_from))
io_toposort_timing.append(time.perf_counter() - topo_t0) io_toposort_timing.append(time.perf_counter() - topo_t0)
nb_nodes.append(len(q)) nb_nodes.append(len(q))
......
...@@ -21,7 +21,7 @@ from pytensor.configdefaults import config ...@@ -21,7 +21,7 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import HasInnerGraph, Op, StorageMapType from pytensor.graph.op import HasInnerGraph, Op, StorageMapType
from pytensor.graph.traversal import graph_inputs, io_toposort from pytensor.graph.traversal import graph_inputs, toposort
from pytensor.graph.utils import Scratchpad from pytensor.graph.utils import Scratchpad
...@@ -1103,7 +1103,7 @@ class PPrinter(Printer): ...@@ -1103,7 +1103,7 @@ class PPrinter(Printer):
) )
inv_updates = {b: a for (a, b) in updates.items()} inv_updates = {b: a for (a, b) in updates.items()}
i = 1 i = 1
for node in io_toposort([*inputs, *updates], [*outputs, *updates.values()]): for node in toposort([*outputs, *updates.values()], [*inputs, *updates]):
for output in node.outputs: for output in node.outputs:
if output in inv_updates: if output in inv_updates:
name = str(inv_updates[output]) name = str(inv_updates[output])
......
...@@ -13,7 +13,6 @@ from pytensor import tensor as pt ...@@ -13,7 +13,6 @@ from pytensor import tensor as pt
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.compile.function.types import deep_copy_op from pytensor.compile.function.types import deep_copy_op
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import ancestors, graph_inputs
from pytensor.graph.basic import ( from pytensor.graph.basic import (
Apply, Apply,
Constant, Constant,
...@@ -35,7 +34,11 @@ from pytensor.graph.rewriting.basic import ( ...@@ -35,7 +34,11 @@ from pytensor.graph.rewriting.basic import (
) )
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
from pytensor.graph.rewriting.utils import get_clients_at_depth from pytensor.graph.rewriting.utils import get_clients_at_depth
from pytensor.graph.traversal import apply_depends_on, io_toposort from pytensor.graph.traversal import (
ancestors,
apply_depends_on,
graph_inputs,
)
from pytensor.graph.type import HasShape from pytensor.graph.type import HasShape
from pytensor.graph.utils import InconsistencyError from pytensor.graph.utils import InconsistencyError
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
...@@ -220,7 +223,7 @@ def scan_push_out_non_seq(fgraph, node): ...@@ -220,7 +223,7 @@ def scan_push_out_non_seq(fgraph, node):
""" """
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
local_fgraph_topo = io_toposort(node_inputs, node_outputs) local_fgraph_topo = node.op.fgraph.toposort()
local_fgraph_outs_set = set(node_outputs) local_fgraph_outs_set = set(node_outputs)
local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)} local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
...@@ -427,7 +430,7 @@ def scan_push_out_seq(fgraph, node): ...@@ -427,7 +430,7 @@ def scan_push_out_seq(fgraph, node):
""" """
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
local_fgraph_topo = io_toposort(node_inputs, node_outputs) local_fgraph_topo = node.op.fgraph.toposort()
local_fgraph_outs_set = set(node_outputs) local_fgraph_outs_set = set(node_outputs)
local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)} local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
...@@ -840,22 +843,42 @@ def scan_push_out_add(fgraph, node): ...@@ -840,22 +843,42 @@ def scan_push_out_add(fgraph, node):
# apply_ancestors(args.inner_outputs) # apply_ancestors(args.inner_outputs)
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of add_of_dot_nodes = [
# use n
args = ScanArgs( for n in op.fgraph.apply_nodes
node.inputs, node.outputs, op.inner_inputs, op.inner_outputs, op.info if
(
# We have an Add
isinstance(n.op, Elemwise)
and isinstance(n.op.scalar_op, ps.Add)
and any(
(
# With a Dot input that's only used in the Add
n_inp.owner is not None
and isinstance(n_inp.owner.op, Dot)
and len(op.fgraph.clients[n_inp]) == 1
)
for n_inp in n.inputs
) )
)
]
if not add_of_dot_nodes:
return False
clients = {} # Use `ScanArgs` to parse the inputs and outputs of scan for ease of access
local_fgraph_topo = io_toposort( args = ScanArgs(
args.inner_inputs, args.inner_outputs, clients=clients node.inputs,
node.outputs,
op.inner_inputs,
op.inner_outputs,
op.info,
clone=False,
) )
for nd in local_fgraph_topo: for nd in add_of_dot_nodes:
if ( if (
isinstance(nd.op, Elemwise) nd.out in args.inner_out_sit_sot
and isinstance(nd.op.scalar_op, ps.Add)
and nd.out in args.inner_out_sit_sot
# FIXME: This function doesn't handle `sitsot_out[1:][-1]` pattern # FIXME: This function doesn't handle `sitsot_out[1:][-1]` pattern
and inner_sitsot_only_last_step_used(fgraph, nd.out, args) and inner_sitsot_only_last_step_used(fgraph, nd.out, args)
): ):
...@@ -863,27 +886,17 @@ def scan_push_out_add(fgraph, node): ...@@ -863,27 +886,17 @@ def scan_push_out_add(fgraph, node):
# the add from a previous iteration of the inner function # the add from a previous iteration of the inner function
sitsot_idx = args.inner_out_sit_sot.index(nd.out) sitsot_idx = args.inner_out_sit_sot.index(nd.out)
if args.inner_in_sit_sot[sitsot_idx] in nd.inputs: if args.inner_in_sit_sot[sitsot_idx] in nd.inputs:
# Ensure that the other input to the add is a dot product
# between 2 matrices which will become a tensor3 and a
# matrix if pushed outside of the scan. Also make sure
# that the output of the Dot is ONLY used by the 'add'
# otherwise doing a Dot in the outer graph will only
# duplicate computation.
sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[sitsot_idx]) sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[sitsot_idx])
# 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0 # 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
dot_in_idx = 1 - sitsot_in_idx dot_in_idx = 1 - sitsot_in_idx
dot_input = nd.inputs[dot_in_idx] dot_input = nd.inputs[dot_in_idx]
assert dot_input.owner is not None and isinstance(
dot_input.owner.op, Dot
)
if ( if (
dot_input.owner is not None get_outer_ndim(dot_input.owner.inputs[0], args) == 3
and isinstance(dot_input.owner.op, Dot)
and len(clients[dot_input]) == 1
and dot_input.owner.inputs[0].ndim == 2
and dot_input.owner.inputs[1].ndim == 2
and get_outer_ndim(dot_input.owner.inputs[0], args) == 3
and get_outer_ndim(dot_input.owner.inputs[1], args) == 3 and get_outer_ndim(dot_input.owner.inputs[1], args) == 3
): ):
# The optimization can be applied in this case. # The optimization can be applied in this case.
......
...@@ -59,7 +59,7 @@ import time ...@@ -59,7 +59,7 @@ import time
import numpy as np import numpy as np
from pytensor.graph.traversal import io_toposort from pytensor.graph.traversal import toposort
from pytensor.tensor.rewriting.basic import register_specialize from pytensor.tensor.rewriting.basic import register_specialize
...@@ -460,6 +460,9 @@ class GemmOptimizer(GraphRewriter): ...@@ -460,6 +460,9 @@ class GemmOptimizer(GraphRewriter):
callbacks_before = fgraph.execute_callbacks_times.copy() callbacks_before = fgraph.execute_callbacks_times.copy()
callback_before = fgraph.execute_callbacks_time callback_before = fgraph.execute_callbacks_time
nodelist = list(toposort(fgraph.outputs))
nodelist.reverse()
def on_import(new_node): def on_import(new_node):
if new_node is not node: if new_node is not node:
nodelist.append(new_node) nodelist.append(new_node)
...@@ -471,10 +474,8 @@ class GemmOptimizer(GraphRewriter): ...@@ -471,10 +474,8 @@ class GemmOptimizer(GraphRewriter):
while did_something: while did_something:
nb_iter += 1 nb_iter += 1
t0 = time.perf_counter() t0 = time.perf_counter()
nodelist = io_toposort(fgraph.inputs, fgraph.outputs)
time_toposort += time.perf_counter() - t0 time_toposort += time.perf_counter() - t0
did_something = False did_something = False
nodelist.reverse()
for node in nodelist: for node in nodelist:
if not ( if not (
isinstance(node.op, Elemwise) isinstance(node.op, Elemwise)
......
...@@ -50,22 +50,13 @@ class TestProfiling: ...@@ -50,22 +50,13 @@ class TestProfiling:
the_string = buf.getvalue() the_string = buf.getvalue()
lines1 = [l for l in the_string.split("\n") if "Max if linker" in l] lines1 = [l for l in the_string.split("\n") if "Max if linker" in l]
lines2 = [l for l in the_string.split("\n") if "Minimum peak" in l] lines2 = [l for l in the_string.split("\n") if "Minimum peak" in l]
if config.device == "cpu": # NOTE: The specific numbers can change for distinct (but correct) toposort orderings
assert "CPU: 4112KB (4104KB)" in the_string, (lines1, lines2) # Update the test values if a different algorithm is used
assert "CPU: 8204KB (8196KB)" in the_string, (lines1, lines2) assert "CPU: 4112KB (4112KB)" in the_string, (lines1, lines2)
assert "CPU: 8204KB (8204KB)" in the_string, (lines1, lines2)
assert "CPU: 8208KB" in the_string, (lines1, lines2) assert "CPU: 8208KB" in the_string, (lines1, lines2)
assert ( assert (
"Minimum peak from all valid apply node order is 4104KB" "Minimum peak from all valid apply node order is 4104KB" in the_string
in the_string
), (lines1, lines2)
else:
assert "CPU: 16KB (16KB)" in the_string, (lines1, lines2)
assert "GPU: 8204KB (8204KB)" in the_string, (lines1, lines2)
assert "GPU: 12300KB (12300KB)" in the_string, (lines1, lines2)
assert "GPU: 8212KB" in the_string, (lines1, lines2)
assert (
"Minimum peak from all valid apply node order is 4116KB"
in the_string
), (lines1, lines2) ), (lines1, lines2)
finally: finally:
......
...@@ -160,7 +160,7 @@ def test_KanrenRelationSub_dot(): ...@@ -160,7 +160,7 @@ def test_KanrenRelationSub_dot():
assert expr_opt.owner.op == pt.add assert expr_opt.owner.op == pt.add
assert isinstance(expr_opt.owner.inputs[0].owner.op, Dot) assert isinstance(expr_opt.owner.inputs[0].owner.op, Dot)
assert fgraph_opt.inputs[0] is A_pt assert fgraph_opt.inputs[-1] is A_pt
assert expr_opt.owner.inputs[0].owner.inputs[0].name == "A" assert expr_opt.owner.inputs[0].owner.inputs[0].name == "A"
assert expr_opt.owner.inputs[1].owner.op == pt.add assert expr_opt.owner.inputs[1].owner.op == pt.add
assert isinstance(expr_opt.owner.inputs[1].owner.inputs[0].owner.op, Dot) assert isinstance(expr_opt.owner.inputs[1].owner.inputs[0].owner.op, Dot)
......
...@@ -56,7 +56,7 @@ class TestFunctionGraph: ...@@ -56,7 +56,7 @@ class TestFunctionGraph:
with pytest.raises(TypeError, match="'Variable' object is not iterable"): with pytest.raises(TypeError, match="'Variable' object is not iterable"):
FunctionGraph(var1, [var2]) FunctionGraph(var1, [var2])
with pytest.raises(TypeError, match="'Variable' object is not reversible"): with pytest.raises(TypeError, match="'Variable' object is not iterable"):
FunctionGraph([var1], var2) FunctionGraph([var1], var2)
with pytest.raises( with pytest.raises(
......
...@@ -28,7 +28,7 @@ class TestCloneReplace: ...@@ -28,7 +28,7 @@ class TestCloneReplace:
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, rebuild_strict=True, copy_inputs_over=True) f2 = clone_replace(f1, replace=None, rebuild_strict=True, copy_inputs_over=True)
f2_inp = graph_inputs([f2]) f2_inp = tuple(graph_inputs([f2]))
assert z in f2_inp assert z in f2_inp
assert x in f2_inp assert x in f2_inp
...@@ -65,7 +65,7 @@ class TestCloneReplace: ...@@ -65,7 +65,7 @@ class TestCloneReplace:
f2 = clone_replace( f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=True, copy_inputs_over=True f1, replace={y: y2}, rebuild_strict=True, copy_inputs_over=True
) )
f2_inp = graph_inputs([f2]) f2_inp = tuple(graph_inputs([f2]))
assert z in f2_inp assert z in f2_inp
assert x in f2_inp assert x in f2_inp
assert y2 in f2_inp assert y2 in f2_inp
...@@ -83,7 +83,7 @@ class TestCloneReplace: ...@@ -83,7 +83,7 @@ class TestCloneReplace:
f2 = clone_replace( f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=False, copy_inputs_over=True f1, replace={y: y2}, rebuild_strict=False, copy_inputs_over=True
) )
f2_inp = graph_inputs([f2]) f2_inp = tuple(graph_inputs([f2]))
assert z in f2_inp assert z in f2_inp
assert x in f2_inp assert x in f2_inp
assert y2 in f2_inp assert y2 in f2_inp
......
...@@ -4,13 +4,17 @@ from pytensor import Variable, shared ...@@ -4,13 +4,17 @@ from pytensor import Variable, shared
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.graph import Apply, ancestors, graph_inputs from pytensor.graph import Apply, ancestors, graph_inputs
from pytensor.graph.traversal import ( from pytensor.graph.traversal import (
apply_ancestors,
apply_depends_on, apply_depends_on,
explicit_graph_inputs, explicit_graph_inputs,
general_toposort, general_toposort,
get_var_by_name, get_var_by_name,
io_toposort, io_toposort,
orphans_between, orphans_between,
toposort,
toposort_with_orderings,
truncated_graph_inputs, truncated_graph_inputs,
variable_ancestors,
variable_depends_on, variable_depends_on,
vars_between, vars_between,
walk, walk,
...@@ -36,23 +40,17 @@ class TestToposort: ...@@ -36,23 +40,17 @@ class TestToposort:
o2 = MyOp(o, r5) o2 = MyOp(o, r5)
o2.name = "o2" o2.name = "o2"
clients = {} res = general_toposort([o2], self.prenode)
res = general_toposort([o2], self.prenode, clients=clients)
assert clients == {
o2.owner: [o2],
o: [o2.owner],
r5: [o2.owner],
o.owner: [o],
r1: [o.owner],
r2: [o.owner],
}
assert res == [r5, r2, r1, o.owner, o, o2.owner, o2] assert res == [r5, r2, r1, o.owner, o, o2.owner, o2]
with pytest.raises(ValueError): def circular_dependency(obj):
general_toposort( if obj is o:
[o2], self.prenode, compute_deps_cache=lambda x: None, deps_cache=None # o2 depends on o, so o cannot depend on o2
) return [o2, *self.prenode(obj)]
return self.prenode(obj)
with pytest.raises(ValueError, match="graph contains cycles"):
general_toposort([o2], circular_dependency)
res = io_toposort([r5], [o2]) res = io_toposort([r5], [o2])
assert res == [o.owner, o2.owner] assert res == [o.owner, o2.owner]
...@@ -181,16 +179,16 @@ def test_ancestors(): ...@@ -181,16 +179,16 @@ def test_ancestors():
res = ancestors([o2], blockers=None) res = ancestors([o2], blockers=None)
res_list = list(res) res_list = list(res)
assert res_list == [o2, r3, o1, r1, r2] assert res_list == [o2, o1, r2, r1, r3]
res = ancestors([o2], blockers=None) res = ancestors([o2], blockers=None)
assert r3 in res assert o1 in res
res_list = list(res) res_list = list(res)
assert res_list == [o1, r1, r2] assert res_list == [r2, r1, r3]
res = ancestors([o2], blockers=[o1]) res = ancestors([o2], blockers=[o1])
res_list = list(res) res_list = list(res)
assert res_list == [o2, r3, o1] assert res_list == [o2, o1, r3]
def test_graph_inputs(): def test_graph_inputs():
...@@ -202,7 +200,7 @@ def test_graph_inputs(): ...@@ -202,7 +200,7 @@ def test_graph_inputs():
res = graph_inputs([o2], blockers=None) res = graph_inputs([o2], blockers=None)
res_list = list(res) res_list = list(res)
assert res_list == [r3, r1, r2] assert res_list == [r2, r1, r3]
def test_explicit_graph_inputs(): def test_explicit_graph_inputs():
...@@ -231,7 +229,7 @@ def test_variables_and_orphans(): ...@@ -231,7 +229,7 @@ def test_variables_and_orphans():
vars_res_list = list(vars_res) vars_res_list = list(vars_res)
orphans_res_list = list(orphans_res) orphans_res_list = list(orphans_res)
assert vars_res_list == [o2, o1, r3, r2, r1] assert vars_res_list == [o2, o1, r2, r1, r3]
assert orphans_res_list == [r3] assert orphans_res_list == [r3]
...@@ -408,3 +406,37 @@ def test_get_var_by_name(): ...@@ -408,3 +406,37 @@ def test_get_var_by_name():
exp_res = igo.fgraph.outputs[0] exp_res = igo.fgraph.outputs[0]
assert res == exp_res assert res == exp_res
@pytest.mark.parametrize(
"func",
[
lambda x: all(variable_ancestors([x])),
lambda x: all(variable_ancestors([x], blockers=[x.clone()])),
lambda x: all(apply_ancestors([x])),
lambda x: all(apply_ancestors([x], blockers=[x.clone()])),
lambda x: all(toposort([x])),
lambda x: all(toposort([x], blockers=[x.clone()])),
lambda x: all(toposort_with_orderings([x], orderings={x: []})),
lambda x: all(
toposort_with_orderings([x], blockers=[x.clone()], orderings={x: []})
),
],
ids=[
"variable_ancestors",
"variable_ancestors_with_blockers",
"apply_ancestors",
"apply_ancestors_with_blockers",
"toposort",
"toposort_with_blockers",
"toposort_with_orderings",
"toposort_with_orderings_and_blockers",
],
)
def test_traversal_benchmark(func, benchmark):
r1 = MyVariable(1)
out = r1
for i in range(50):
out = MyOp(out, out)
benchmark(func, out)
from itertools import chain
import numpy as np import numpy as np
import pytest import pytest
...@@ -490,6 +492,7 @@ def test_inplace_taps(n_steps_constant): ...@@ -490,6 +492,7 @@ def test_inplace_taps(n_steps_constant):
if isinstance(node.op, Scan) if isinstance(node.op, Scan)
] ]
# Collect inner inputs we expect to be destroyed by the step function
# Scan reorders inputs internally, so we need to check its ordering # Scan reorders inputs internally, so we need to check its ordering
inner_inps = scan_op.fgraph.inputs inner_inps = scan_op.fgraph.inputs
mit_sot_inps = scan_op.inner_mitsot(inner_inps) mit_sot_inps = scan_op.inner_mitsot(inner_inps)
...@@ -501,28 +504,22 @@ def test_inplace_taps(n_steps_constant): ...@@ -501,28 +504,22 @@ def test_inplace_taps(n_steps_constant):
] ]
[sit_sot_inp] = scan_op.inner_sitsot(inner_inps) [sit_sot_inp] = scan_op.inner_sitsot(inner_inps)
inner_outs = scan_op.fgraph.outputs destroyed_inputs = []
mit_sot_outs = scan_op.inner_mitsot_outs(inner_outs) for inner_out in scan_op.fgraph.outputs:
[sit_sot_out] = scan_op.inner_sitsot_outs(inner_outs) node = inner_out.owner
[nit_sot_out] = scan_op.inner_nitsot_outs(inner_outs) dm = node.op.destroy_map
if dm:
destroyed_inputs.extend(
node.inputs[idx] for idx in chain.from_iterable(dm.values())
)
if n_steps_constant: if n_steps_constant:
assert mit_sot_outs[0].owner.op.destroy_map == { assert len(destroyed_inputs) == 3
0: [mit_sot_outs[0].owner.inputs.index(oldest_mit_sot_inps[0])] assert set(destroyed_inputs) == {*oldest_mit_sot_inps, sit_sot_inp}
}
assert mit_sot_outs[1].owner.op.destroy_map == {
0: [mit_sot_outs[1].owner.inputs.index(oldest_mit_sot_inps[1])]
}
assert sit_sot_out.owner.op.destroy_map == {
0: [sit_sot_out.owner.inputs.index(sit_sot_inp)]
}
else: else:
# This is not a feature, but a current limitation # This is not a feature, but a current limitation
# https://github.com/pymc-devs/pytensor/issues/1283 # https://github.com/pymc-devs/pytensor/issues/1283
assert mit_sot_outs[0].owner.op.destroy_map == {} assert not destroyed_inputs
assert mit_sot_outs[1].owner.op.destroy_map == {}
assert sit_sot_out.owner.op.destroy_map == {}
assert nit_sot_out.owner.op.destroy_map == {}
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -1170,8 +1170,8 @@ class TestHyp2F1Grad: ...@@ -1170,8 +1170,8 @@ class TestHyp2F1Grad:
if isinstance(node.op, Elemwise) if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, ScalarLoop) and isinstance(node.op.scalar_op, ScalarLoop)
] ]
assert scalar_loop_op1.nin == 10 + 3 * 2 # wrt=[0, 1] assert scalar_loop_op1.nin == 10 + 3 * 1 # wrt=[2]
assert scalar_loop_op2.nin == 10 + 3 * 1 # wrt=[2] assert scalar_loop_op2.nin == 10 + 3 * 2 # wrt=[0, 1]
else: else:
[scalar_loop_op] = [ [scalar_loop_op] = [
node.op.scalar_op node.op.scalar_op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论