Commit 066307f0 authored by Ricardo Vieira, committed by Ricardo Vieira

Faster graph traversal functions

* Avoid reversing inputs as we traverse graph * Simplify io_toposort without ordering (and refactor into its own function) * Remove client side-effect on previous toposort functions * Remove duplicated logic across methods
上级 f1a2ac66
......@@ -5,7 +5,6 @@ import warnings
from collections.abc import (
Hashable,
Iterable,
Reversible,
Sequence,
)
from copy import copy
......@@ -961,7 +960,7 @@ def clone_node_and_cache(
def clone_get_equiv(
inputs: Iterable[Variable],
outputs: Reversible[Variable],
outputs: Iterable[Variable],
copy_inputs: bool = True,
copy_orphans: bool = True,
memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]
......@@ -1002,7 +1001,7 @@ def clone_get_equiv(
Keywords passed to `Apply.clone_with_new_inputs`.
"""
from pytensor.graph.traversal import io_toposort
from pytensor.graph.traversal import toposort
if memo is None:
memo = {}
......@@ -1018,7 +1017,7 @@ def clone_get_equiv(
memo.setdefault(input, input)
# go through the inputs -> outputs graph cloning as we go
for apply in io_toposort(inputs, outputs):
for apply in toposort(outputs, blockers=inputs):
for input in apply.inputs:
if input not in memo:
if not isinstance(input, Constant) and copy_orphans:
......
......@@ -10,7 +10,7 @@ import numpy as np
import pytensor
from pytensor.configdefaults import config
from pytensor.graph.basic import Variable
from pytensor.graph.traversal import io_toposort
from pytensor.graph.traversal import toposort
from pytensor.graph.utils import InconsistencyError
......@@ -340,11 +340,11 @@ class Feature:
class Bookkeeper(Feature):
def on_attach(self, fgraph):
for node in io_toposort(fgraph.inputs, fgraph.outputs):
for node in toposort(fgraph.outputs):
self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph):
for node in io_toposort(fgraph.inputs, fgraph.outputs):
for node in toposort(fgraph.outputs):
self.on_prune(fgraph, node, "Bookkeeper.detach")
......
......@@ -19,7 +19,8 @@ from pytensor.graph.op import Op
from pytensor.graph.traversal import (
applys_between,
graph_inputs,
io_toposort,
toposort,
toposort_with_orderings,
vars_between,
)
from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError
......@@ -366,7 +367,7 @@ class FunctionGraph(MetaObject):
# new nodes, so we use all variables we know of as if they were the
# input set. (The functions in the graph module only use the input set
# to know where to stop going down.)
new_nodes = io_toposort(self.variables, apply_node.outputs)
new_nodes = tuple(toposort(apply_node.outputs, blockers=self.variables))
if check:
for node in new_nodes:
......@@ -759,7 +760,7 @@ class FunctionGraph(MetaObject):
# No sorting is necessary
return list(self.apply_nodes)
return io_toposort(self.inputs, self.outputs, self.orderings())
return list(toposort_with_orderings(self.outputs, orderings=self.orderings()))
def orderings(self) -> dict[Apply, list[Apply]]:
"""Return a map of node to node evaluation dependencies.
......
......@@ -10,7 +10,10 @@ from pytensor.graph.basic import (
)
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.traversal import io_toposort, truncated_graph_inputs
from pytensor.graph.traversal import (
toposort,
truncated_graph_inputs,
)
ReplaceTypes = Iterable[tuple[Variable, Variable]] | dict[Variable, Variable]
......@@ -295,7 +298,7 @@ def vectorize_graph(
new_inputs = [replace.get(inp, inp) for inp in inputs]
vect_vars = dict(zip(inputs, new_inputs, strict=True))
for node in io_toposort(inputs, seq_outputs):
for node in toposort(seq_outputs, blockers=inputs):
vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
vect_node = vectorize_node(node, *vect_inputs)
for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True):
......
......@@ -27,7 +27,7 @@ from pytensor.graph.features import AlreadyThere, Feature
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars
from pytensor.graph.traversal import applys_between, io_toposort, vars_between
from pytensor.graph.traversal import applys_between, toposort, vars_between
from pytensor.graph.utils import AssocList, InconsistencyError
from pytensor.misc.ordered_set import OrderedSet
from pytensor.utils import flatten
......@@ -2010,7 +2010,7 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
callback_before = fgraph.execute_callbacks_time
nb_nodes_start = len(fgraph.apply_nodes)
t0 = time.perf_counter()
q = deque(io_toposort(fgraph.inputs, start_from))
q = deque(toposort(start_from))
io_t = time.perf_counter() - t0
def importer(node):
......@@ -2341,7 +2341,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
changed |= apply_cleanup(iter_cleanup_sub_profs)
topo_t0 = time.perf_counter()
q = deque(io_toposort(fgraph.inputs, start_from))
q = deque(toposort(start_from))
io_toposort_timing.append(time.perf_counter() - topo_t0)
nb_nodes.append(len(q))
......
......@@ -21,7 +21,7 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import HasInnerGraph, Op, StorageMapType
from pytensor.graph.traversal import graph_inputs, io_toposort
from pytensor.graph.traversal import graph_inputs, toposort
from pytensor.graph.utils import Scratchpad
......@@ -1103,7 +1103,7 @@ class PPrinter(Printer):
)
inv_updates = {b: a for (a, b) in updates.items()}
i = 1
for node in io_toposort([*inputs, *updates], [*outputs, *updates.values()]):
for node in toposort([*outputs, *updates.values()], [*inputs, *updates]):
for output in node.outputs:
if output in inv_updates:
name = str(inv_updates[output])
......
......@@ -13,7 +13,6 @@ from pytensor import tensor as pt
from pytensor.compile import optdb
from pytensor.compile.function.types import deep_copy_op
from pytensor.configdefaults import config
from pytensor.graph import ancestors, graph_inputs
from pytensor.graph.basic import (
Apply,
Constant,
......@@ -35,7 +34,11 @@ from pytensor.graph.rewriting.basic import (
)
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
from pytensor.graph.rewriting.utils import get_clients_at_depth
from pytensor.graph.traversal import apply_depends_on, io_toposort
from pytensor.graph.traversal import (
ancestors,
apply_depends_on,
graph_inputs,
)
from pytensor.graph.type import HasShape
from pytensor.graph.utils import InconsistencyError
from pytensor.raise_op import Assert
......@@ -220,7 +223,7 @@ def scan_push_out_non_seq(fgraph, node):
"""
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_topo = node.op.fgraph.toposort()
local_fgraph_outs_set = set(node_outputs)
local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
......@@ -427,7 +430,7 @@ def scan_push_out_seq(fgraph, node):
"""
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_topo = node.op.fgraph.toposort()
local_fgraph_outs_set = set(node_outputs)
local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
......@@ -840,22 +843,42 @@ def scan_push_out_add(fgraph, node):
# apply_ancestors(args.inner_outputs)
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use
args = ScanArgs(
node.inputs, node.outputs, op.inner_inputs, op.inner_outputs, op.info
)
add_of_dot_nodes = [
n
for n in op.fgraph.apply_nodes
if
(
# We have an Add
isinstance(n.op, Elemwise)
and isinstance(n.op.scalar_op, ps.Add)
and any(
(
# With a Dot input that's only used in the Add
n_inp.owner is not None
and isinstance(n_inp.owner.op, Dot)
and len(op.fgraph.clients[n_inp]) == 1
)
for n_inp in n.inputs
)
)
]
clients = {}
local_fgraph_topo = io_toposort(
args.inner_inputs, args.inner_outputs, clients=clients
if not add_of_dot_nodes:
return False
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of access
args = ScanArgs(
node.inputs,
node.outputs,
op.inner_inputs,
op.inner_outputs,
op.info,
clone=False,
)
for nd in local_fgraph_topo:
for nd in add_of_dot_nodes:
if (
isinstance(nd.op, Elemwise)
and isinstance(nd.op.scalar_op, ps.Add)
and nd.out in args.inner_out_sit_sot
nd.out in args.inner_out_sit_sot
# FIXME: This function doesn't handle `sitsot_out[1:][-1]` pattern
and inner_sitsot_only_last_step_used(fgraph, nd.out, args)
):
......@@ -863,27 +886,17 @@ def scan_push_out_add(fgraph, node):
# the add from a previous iteration of the inner function
sitsot_idx = args.inner_out_sit_sot.index(nd.out)
if args.inner_in_sit_sot[sitsot_idx] in nd.inputs:
# Ensure that the other input to the add is a dot product
# between 2 matrices which will become a tensor3 and a
# matrix if pushed outside of the scan. Also make sure
# that the output of the Dot is ONLY used by the 'add'
# otherwise doing a Dot in the outer graph will only
# duplicate computation.
sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[sitsot_idx])
# 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
dot_in_idx = 1 - sitsot_in_idx
dot_input = nd.inputs[dot_in_idx]
assert dot_input.owner is not None and isinstance(
dot_input.owner.op, Dot
)
if (
dot_input.owner is not None
and isinstance(dot_input.owner.op, Dot)
and len(clients[dot_input]) == 1
and dot_input.owner.inputs[0].ndim == 2
and dot_input.owner.inputs[1].ndim == 2
and get_outer_ndim(dot_input.owner.inputs[0], args) == 3
get_outer_ndim(dot_input.owner.inputs[0], args) == 3
and get_outer_ndim(dot_input.owner.inputs[1], args) == 3
):
# The optimization can be applied in this case.
......
......@@ -59,7 +59,7 @@ import time
import numpy as np
from pytensor.graph.traversal import io_toposort
from pytensor.graph.traversal import toposort
from pytensor.tensor.rewriting.basic import register_specialize
......@@ -460,6 +460,9 @@ class GemmOptimizer(GraphRewriter):
callbacks_before = fgraph.execute_callbacks_times.copy()
callback_before = fgraph.execute_callbacks_time
nodelist = list(toposort(fgraph.outputs))
nodelist.reverse()
def on_import(new_node):
if new_node is not node:
nodelist.append(new_node)
......@@ -471,10 +474,8 @@ class GemmOptimizer(GraphRewriter):
while did_something:
nb_iter += 1
t0 = time.perf_counter()
nodelist = io_toposort(fgraph.inputs, fgraph.outputs)
time_toposort += time.perf_counter() - t0
did_something = False
nodelist.reverse()
for node in nodelist:
if not (
isinstance(node.op, Elemwise)
......
......@@ -50,23 +50,14 @@ class TestProfiling:
the_string = buf.getvalue()
lines1 = [l for l in the_string.split("\n") if "Max if linker" in l]
lines2 = [l for l in the_string.split("\n") if "Minimum peak" in l]
if config.device == "cpu":
assert "CPU: 4112KB (4104KB)" in the_string, (lines1, lines2)
assert "CPU: 8204KB (8196KB)" in the_string, (lines1, lines2)
assert "CPU: 8208KB" in the_string, (lines1, lines2)
assert (
"Minimum peak from all valid apply node order is 4104KB"
in the_string
), (lines1, lines2)
else:
assert "CPU: 16KB (16KB)" in the_string, (lines1, lines2)
assert "GPU: 8204KB (8204KB)" in the_string, (lines1, lines2)
assert "GPU: 12300KB (12300KB)" in the_string, (lines1, lines2)
assert "GPU: 8212KB" in the_string, (lines1, lines2)
assert (
"Minimum peak from all valid apply node order is 4116KB"
in the_string
), (lines1, lines2)
# NOTE: The specific numbers can change for distinct (but correct) toposort orderings
# Update the test values if a different algorithm is used
assert "CPU: 4112KB (4112KB)" in the_string, (lines1, lines2)
assert "CPU: 8204KB (8204KB)" in the_string, (lines1, lines2)
assert "CPU: 8208KB" in the_string, (lines1, lines2)
assert (
"Minimum peak from all valid apply node order is 4104KB" in the_string
), (lines1, lines2)
finally:
config.profile = config1
......
......@@ -160,7 +160,7 @@ def test_KanrenRelationSub_dot():
assert expr_opt.owner.op == pt.add
assert isinstance(expr_opt.owner.inputs[0].owner.op, Dot)
assert fgraph_opt.inputs[0] is A_pt
assert fgraph_opt.inputs[-1] is A_pt
assert expr_opt.owner.inputs[0].owner.inputs[0].name == "A"
assert expr_opt.owner.inputs[1].owner.op == pt.add
assert isinstance(expr_opt.owner.inputs[1].owner.inputs[0].owner.op, Dot)
......
......@@ -56,7 +56,7 @@ class TestFunctionGraph:
with pytest.raises(TypeError, match="'Variable' object is not iterable"):
FunctionGraph(var1, [var2])
with pytest.raises(TypeError, match="'Variable' object is not reversible"):
with pytest.raises(TypeError, match="'Variable' object is not iterable"):
FunctionGraph([var1], var2)
with pytest.raises(
......
......@@ -28,7 +28,7 @@ class TestCloneReplace:
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, rebuild_strict=True, copy_inputs_over=True)
f2_inp = graph_inputs([f2])
f2_inp = tuple(graph_inputs([f2]))
assert z in f2_inp
assert x in f2_inp
......@@ -65,7 +65,7 @@ class TestCloneReplace:
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=True, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
f2_inp = tuple(graph_inputs([f2]))
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
......@@ -83,7 +83,7 @@ class TestCloneReplace:
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=False, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
f2_inp = tuple(graph_inputs([f2]))
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
......
......@@ -4,13 +4,17 @@ from pytensor import Variable, shared
from pytensor import tensor as pt
from pytensor.graph import Apply, ancestors, graph_inputs
from pytensor.graph.traversal import (
apply_ancestors,
apply_depends_on,
explicit_graph_inputs,
general_toposort,
get_var_by_name,
io_toposort,
orphans_between,
toposort,
toposort_with_orderings,
truncated_graph_inputs,
variable_ancestors,
variable_depends_on,
vars_between,
walk,
......@@ -36,23 +40,17 @@ class TestToposort:
o2 = MyOp(o, r5)
o2.name = "o2"
clients = {}
res = general_toposort([o2], self.prenode, clients=clients)
assert clients == {
o2.owner: [o2],
o: [o2.owner],
r5: [o2.owner],
o.owner: [o],
r1: [o.owner],
r2: [o.owner],
}
res = general_toposort([o2], self.prenode)
assert res == [r5, r2, r1, o.owner, o, o2.owner, o2]
with pytest.raises(ValueError):
general_toposort(
[o2], self.prenode, compute_deps_cache=lambda x: None, deps_cache=None
)
def circular_dependency(obj):
if obj is o:
# o2 depends on o, so o cannot depend on o2
return [o2, *self.prenode(obj)]
return self.prenode(obj)
with pytest.raises(ValueError, match="graph contains cycles"):
general_toposort([o2], circular_dependency)
res = io_toposort([r5], [o2])
assert res == [o.owner, o2.owner]
......@@ -181,16 +179,16 @@ def test_ancestors():
res = ancestors([o2], blockers=None)
res_list = list(res)
assert res_list == [o2, r3, o1, r1, r2]
assert res_list == [o2, o1, r2, r1, r3]
res = ancestors([o2], blockers=None)
assert r3 in res
assert o1 in res
res_list = list(res)
assert res_list == [o1, r1, r2]
assert res_list == [r2, r1, r3]
res = ancestors([o2], blockers=[o1])
res_list = list(res)
assert res_list == [o2, r3, o1]
assert res_list == [o2, o1, r3]
def test_graph_inputs():
......@@ -202,7 +200,7 @@ def test_graph_inputs():
res = graph_inputs([o2], blockers=None)
res_list = list(res)
assert res_list == [r3, r1, r2]
assert res_list == [r2, r1, r3]
def test_explicit_graph_inputs():
......@@ -231,7 +229,7 @@ def test_variables_and_orphans():
vars_res_list = list(vars_res)
orphans_res_list = list(orphans_res)
assert vars_res_list == [o2, o1, r3, r2, r1]
assert vars_res_list == [o2, o1, r2, r1, r3]
assert orphans_res_list == [r3]
......@@ -408,3 +406,37 @@ def test_get_var_by_name():
exp_res = igo.fgraph.outputs[0]
assert res == exp_res
@pytest.mark.parametrize(
"func",
[
lambda x: all(variable_ancestors([x])),
lambda x: all(variable_ancestors([x], blockers=[x.clone()])),
lambda x: all(apply_ancestors([x])),
lambda x: all(apply_ancestors([x], blockers=[x.clone()])),
lambda x: all(toposort([x])),
lambda x: all(toposort([x], blockers=[x.clone()])),
lambda x: all(toposort_with_orderings([x], orderings={x: []})),
lambda x: all(
toposort_with_orderings([x], blockers=[x.clone()], orderings={x: []})
),
],
ids=[
"variable_ancestors",
"variable_ancestors_with_blockers",
"apply_ancestors",
"apply_ancestors_with_blockers)",
"toposort",
"toposort_with_blockers",
"toposort_with_orderings",
"toposort_with_orderings_and_blockers",
],
)
def test_traversal_benchmark(func, benchmark):
r1 = MyVariable(1)
out = r1
for i in range(50):
out = MyOp(out, out)
benchmark(func, out)
from itertools import chain
import numpy as np
import pytest
......@@ -490,6 +492,7 @@ def test_inplace_taps(n_steps_constant):
if isinstance(node.op, Scan)
]
# Collect inner inputs we expect to be destroyed by the step function
# Scan reorders inputs internally, so we need to check its ordering
inner_inps = scan_op.fgraph.inputs
mit_sot_inps = scan_op.inner_mitsot(inner_inps)
......@@ -501,28 +504,22 @@ def test_inplace_taps(n_steps_constant):
]
[sit_sot_inp] = scan_op.inner_sitsot(inner_inps)
inner_outs = scan_op.fgraph.outputs
mit_sot_outs = scan_op.inner_mitsot_outs(inner_outs)
[sit_sot_out] = scan_op.inner_sitsot_outs(inner_outs)
[nit_sot_out] = scan_op.inner_nitsot_outs(inner_outs)
destroyed_inputs = []
for inner_out in scan_op.fgraph.outputs:
node = inner_out.owner
dm = node.op.destroy_map
if dm:
destroyed_inputs.extend(
node.inputs[idx] for idx in chain.from_iterable(dm.values())
)
if n_steps_constant:
assert mit_sot_outs[0].owner.op.destroy_map == {
0: [mit_sot_outs[0].owner.inputs.index(oldest_mit_sot_inps[0])]
}
assert mit_sot_outs[1].owner.op.destroy_map == {
0: [mit_sot_outs[1].owner.inputs.index(oldest_mit_sot_inps[1])]
}
assert sit_sot_out.owner.op.destroy_map == {
0: [sit_sot_out.owner.inputs.index(sit_sot_inp)]
}
assert len(destroyed_inputs) == 3
assert set(destroyed_inputs) == {*oldest_mit_sot_inps, sit_sot_inp}
else:
# This is not a feature, but a current limitation
# https://github.com/pymc-devs/pytensor/issues/1283
assert mit_sot_outs[0].owner.op.destroy_map == {}
assert mit_sot_outs[1].owner.op.destroy_map == {}
assert sit_sot_out.owner.op.destroy_map == {}
assert nit_sot_out.owner.op.destroy_map == {}
assert not destroyed_inputs
@pytest.mark.parametrize(
......
......@@ -1170,8 +1170,8 @@ class TestHyp2F1Grad:
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, ScalarLoop)
]
assert scalar_loop_op1.nin == 10 + 3 * 2 # wrt=[0, 1]
assert scalar_loop_op2.nin == 10 + 3 * 1 # wrt=[2]
assert scalar_loop_op1.nin == 10 + 3 * 1 # wrt=[2]
assert scalar_loop_op2.nin == 10 + 3 * 2 # wrt=[0, 1]
else:
[scalar_loop_op] = [
node.op.scalar_op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论