Unverified 提交 e9a7d7ce authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: GitHub

Speedup `truncated_graph_inputs` (#394)

* add pytest-mock dependency * rename node to variable
上级 673c1acc
......@@ -139,7 +139,7 @@ jobs:
- name: Install dependencies
shell: bash -l {0}
run: |
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
# numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.9, but
# not numpy, even though scipy 1.7 requires numpy<1.23. When installing
# PyTensor next, pip installs a lower version of numpy via the PyPI.
......
......@@ -31,6 +31,7 @@ dependencies:
- pytest-cov
- pytest-xdist
- pytest-benchmark
- pytest-mock
# For building docs
- sphinx>=5.1.0,<6
- sphinx_rtd_theme
......
......@@ -86,6 +86,7 @@ tests = [
"pytest-cov>=2.6.1",
"coverage>=5.1",
"pytest-benchmark",
"pytest-mock",
]
rtd = [
"sphinx>=5.1.0,<6",
......
......@@ -1003,14 +1003,14 @@ def applys_between(
def truncated_graph_inputs(
outputs: Sequence[Variable],
ancestors_to_include: Optional[Collection[Variable]] = None,
) -> List[Variable]:
) -> list[Variable]:
"""Get the truncate graph inputs.
Unlike :func:`graph_inputs` this function will return
the closest nodes to outputs that do not depend on
the closest variables to outputs that do not depend on
``ancestors_to_include``. So given all the returned
variables provided there is no missing node to
compute the output and all nodes are independent
variables provided there is no missing variable to
compute the output and all variables are independent
from each other.
Parameters
......@@ -1027,7 +1027,7 @@ def truncated_graph_inputs(
Examples
--------
The returned nodes marked in (parenthesis), ancestors nodes are ``c``, output nodes are ``o``
The returned variables are marked in (parentheses); ancestor variables are ``c`` and output variables are ``o``
* No ancestors to include
......@@ -1047,7 +1047,7 @@ def truncated_graph_inputs(
(c) - (c) - o
* Additional nodes are present
* Additional variables are present
.. code-block::
......@@ -1076,58 +1076,60 @@ def truncated_graph_inputs(
"""
# simple case, no additional ancestors to include
truncated_inputs = list()
# blockers have known independent nodes and ancestors to include
truncated_inputs: list[Variable] = list()
# blockers have known independent variables and ancestors to include
candidates = list(outputs)
if not ancestors_to_include: # None or empty
# just filter out unique variables
for node in candidates:
if node not in truncated_inputs:
truncated_inputs.append(node)
for variable in candidates:
if variable not in truncated_inputs:
truncated_inputs.append(variable)
# no more actions are needed
return truncated_inputs
blockers: Set[Variable] = set(ancestors_to_include)
# enforce O(1) check for node in ancestors to include
blockers: set[Variable] = set(ancestors_to_include)
# variables that go here are under check already, do not repeat the loop for them
seen: set[Variable] = set()
# enforce O(1) check for variable in ancestors to include
ancestors_to_include = blockers.copy()
while candidates:
# on any new candidate
node = candidates.pop()
# There was a repeated reference to this node, we have already investigated it
if node in truncated_inputs:
variable = candidates.pop()
# we've looked into this variable already
if variable in seen:
continue
# check if the node is independent, never go above blockers;
# blockers are independent nodes and ancestors to include
if node in ancestors_to_include:
# The case where node is in ancestors to include so we check if it depends on others
# check if the variable is independent, never go above blockers;
# blockers are independent variables and ancestors to include
elif variable in ancestors_to_include:
# The case where variable is in ancestors to include so we check if it depends on others
# it should be removed from the blockers to check against the rest
dependent = variable_depends_on(node, ancestors_to_include - {node})
dependent = variable_depends_on(variable, ancestors_to_include - {variable})
# ancestors to include that are present in the graph (not disconnected)
# should be added to truncated_inputs
truncated_inputs.append(node)
truncated_inputs.append(variable)
if dependent:
# if the ancestors to include is still dependent we need to go above, the search is not yet finished
# owner can never be None for a dependent node
candidates.extend(node.owner.inputs)
# owner can never be None for a dependent variable
candidates.extend(n for n in variable.owner.inputs if n not in seen)
else:
# A regular node to check
dependent = variable_depends_on(node, blockers)
# all regular nodes fall to blockers
# A regular variable to check
dependent = variable_depends_on(variable, blockers)
# all regular variables fall to blockers
# 1. it is dependent - further search irrelevant
# 2. it is independent - the search node is inside the closure
blockers.add(node)
# if we've found an independent node and it is not in blockers so far
# it is a new independent node not present in ancestors to include
# 2. it is independent - the search variable is inside the closure
blockers.add(variable)
# if we've found an independent variable and it is not in blockers so far
# it is a new independent variable not present in ancestors to include
if dependent:
# populate search if it's not an independent node
# owner can never be None for a dependent node
candidates.extend(node.owner.inputs)
# populate search if it's not an independent variable
# owner can never be None for a dependent variable
candidates.extend(n for n in variable.owner.inputs if n not in seen)
else:
# otherwise, do not search beyond
truncated_inputs.append(node)
truncated_inputs.append(variable)
# add variable to seen, no point in checking it once more
seen.add(variable)
return truncated_inputs
......
......@@ -795,3 +795,18 @@ class TestTruncatedGraphInputs:
o2.name = "o2"
assert truncated_graph_inputs([o2], [trunc_inp]) == [trunc_inp, x]
def test_single_pass_per_node(self, mocker):
    """Check that ``variable_depends_on`` is called at most once per variable."""
    import pytensor.graph.basic

    # spy on the dependency check so that every invocation gets recorded
    depends_on_spy = mocker.spy(pytensor.graph.basic, "variable_depends_on")
    x = at.dmatrix("x")
    # ``x`` is used by ``m``, ``f`` and ``w``, so it is reachable from ``w``
    # through more than one path
    m = x.shape[0][None, None]
    f = x / m
    w = x / m - f
    truncated_graph_inputs([w], [x])
    recorded_calls = depends_on_spy.call_args_list
    # each call is made as ``variable_depends_on(variable, blockers)``; collect
    # the distinct variables that were inspected
    checked_variables = {
        checked for ((checked, _blockers), _kwargs) in recorded_calls
    }
    # as many calls as distinct variables means no variable was checked twice
    assert len(recorded_calls) == len(checked_variables)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论