Unverified 提交 e9a7d7ce authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: GitHub

Speedup `truncated_graph_inputs` (#394)

* add pytest-mock dependency * rename node to variable
上级 673c1acc
...@@ -139,7 +139,7 @@ jobs: ...@@ -139,7 +139,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
shell: bash -l {0} shell: bash -l {0}
run: | run: |
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
# numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.9, but # numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.9, but
# not numpy, even though scipy 1.7 requires numpy<1.23. When installing # not numpy, even though scipy 1.7 requires numpy<1.23. When installing
# PyTensor next, pip installs a lower version of numpy via the PyPI. # PyTensor next, pip installs a lower version of numpy via the PyPI.
......
...@@ -31,6 +31,7 @@ dependencies: ...@@ -31,6 +31,7 @@ dependencies:
- pytest-cov - pytest-cov
- pytest-xdist - pytest-xdist
- pytest-benchmark - pytest-benchmark
- pytest-mock
# For building docs # For building docs
- sphinx>=5.1.0,<6 - sphinx>=5.1.0,<6
- sphinx_rtd_theme - sphinx_rtd_theme
......
...@@ -86,6 +86,7 @@ tests = [ ...@@ -86,6 +86,7 @@ tests = [
"pytest-cov>=2.6.1", "pytest-cov>=2.6.1",
"coverage>=5.1", "coverage>=5.1",
"pytest-benchmark", "pytest-benchmark",
"pytest-mock",
] ]
rtd = [ rtd = [
"sphinx>=5.1.0,<6", "sphinx>=5.1.0,<6",
......
...@@ -1003,14 +1003,14 @@ def applys_between( ...@@ -1003,14 +1003,14 @@ def applys_between(
def truncated_graph_inputs( def truncated_graph_inputs(
outputs: Sequence[Variable], outputs: Sequence[Variable],
ancestors_to_include: Optional[Collection[Variable]] = None, ancestors_to_include: Optional[Collection[Variable]] = None,
) -> List[Variable]: ) -> list[Variable]:
"""Get the truncate graph inputs. """Get the truncate graph inputs.
Unlike :func:`graph_inputs` this function will return Unlike :func:`graph_inputs` this function will return
the closest nodes to outputs that do not depend on the closest variables to outputs that do not depend on
``ancestors_to_include``. So given all the returned ``ancestors_to_include``. So given all the returned
variables provided there is no missing node to variables provided there is no missing variable to
compute the output and all nodes are independent compute the output and all variables are independent
from each other. from each other.
Parameters Parameters
...@@ -1027,7 +1027,7 @@ def truncated_graph_inputs( ...@@ -1027,7 +1027,7 @@ def truncated_graph_inputs(
Examples Examples
-------- --------
The returned nodes marked in (parenthesis), ancestors nodes are ``c``, output nodes are ``o`` The returned variables marked in (parenthesis), ancestors variables are ``c``, output variables are ``o``
* No ancestors to include * No ancestors to include
...@@ -1047,7 +1047,7 @@ def truncated_graph_inputs( ...@@ -1047,7 +1047,7 @@ def truncated_graph_inputs(
(c) - (c) - o (c) - (c) - o
* Additional nodes are present * Additional variables are present
.. code-block:: .. code-block::
...@@ -1076,58 +1076,60 @@ def truncated_graph_inputs( ...@@ -1076,58 +1076,60 @@ def truncated_graph_inputs(
""" """
# simple case, no additional ancestors to include # simple case, no additional ancestors to include
truncated_inputs = list() truncated_inputs: list[Variable] = list()
# blockers have known independent nodes and ancestors to include # blockers have known independent variables and ancestors to include
candidates = list(outputs) candidates = list(outputs)
if not ancestors_to_include: # None or empty if not ancestors_to_include: # None or empty
# just filter out unique variables # just filter out unique variables
for node in candidates: for variable in candidates:
if node not in truncated_inputs: if variable not in truncated_inputs:
truncated_inputs.append(node) truncated_inputs.append(variable)
# no more actions are needed # no more actions are needed
return truncated_inputs return truncated_inputs
blockers: Set[Variable] = set(ancestors_to_include) blockers: set[Variable] = set(ancestors_to_include)
# enforce O(1) check for node in ancestors to include # variables that go here are under check already, do not repeat the loop for them
seen: set[Variable] = set()
# enforce O(1) check for variable in ancestors to include
ancestors_to_include = blockers.copy() ancestors_to_include = blockers.copy()
while candidates: while candidates:
# on any new candidate # on any new candidate
node = candidates.pop() variable = candidates.pop()
# we've looked into this variable already
# There was a repeated reference to this node, we have already investigated it if variable in seen:
if node in truncated_inputs:
continue continue
# check if the variable is independent, never go above blockers;
# check if the node is independent, never go above blockers; # blockers are independent variables and ancestors to include
# blockers are independent nodes and ancestors to include elif variable in ancestors_to_include:
if node in ancestors_to_include: # The case where variable is in ancestors to include so we check if it depends on others
# The case where node is in ancestors to include so we check if it depends on others
# it should be removed from the blockers to check against the rest # it should be removed from the blockers to check against the rest
dependent = variable_depends_on(node, ancestors_to_include - {node}) dependent = variable_depends_on(variable, ancestors_to_include - {variable})
# ancestors to include that are present in the graph (not disconnected) # ancestors to include that are present in the graph (not disconnected)
# should be added to truncated_inputs # should be added to truncated_inputs
truncated_inputs.append(node) truncated_inputs.append(variable)
if dependent: if dependent:
# if the ancestors to include is still dependent we need to go above, the search is not yet finished # if the ancestors to include is still dependent we need to go above, the search is not yet finished
# owner can never be None for a dependent node # owner can never be None for a dependent variable
candidates.extend(node.owner.inputs) candidates.extend(n for n in variable.owner.inputs if n not in seen)
else: else:
# A regular node to check # A regular variable to check
dependent = variable_depends_on(node, blockers) dependent = variable_depends_on(variable, blockers)
# all regular nodes fall to blockers # all regular variables fall to blockers
# 1. it is dependent - further search irrelevant # 1. it is dependent - further search irrelevant
# 2. it is independent - the search node is inside the closure # 2. it is independent - the search variable is inside the closure
blockers.add(node) blockers.add(variable)
# if we've found an independent node and it is not in blockers so far # if we've found an independent variable and it is not in blockers so far
# it is a new independent node not present in ancestors to include # it is a new independent variable not present in ancestors to include
if dependent: if dependent:
# populate search if it's not an independent node # populate search if it's not an independent variable
# owner can never be None for a dependent node # owner can never be None for a dependent variable
candidates.extend(node.owner.inputs) candidates.extend(n for n in variable.owner.inputs if n not in seen)
else: else:
# otherwise, do not search beyond # otherwise, do not search beyond
truncated_inputs.append(node) truncated_inputs.append(variable)
# add variable to seen, no point in checking it once more
seen.add(variable)
return truncated_inputs return truncated_inputs
......
...@@ -795,3 +795,18 @@ class TestTruncatedGraphInputs: ...@@ -795,3 +795,18 @@ class TestTruncatedGraphInputs:
o2.name = "o2" o2.name = "o2"
assert truncated_graph_inputs([o2], [trunc_inp]) == [trunc_inp, x] assert truncated_graph_inputs([o2], [trunc_inp]) == [trunc_inp, x]
def test_single_pass_per_node(self, mocker):
    """Check that ``truncated_graph_inputs`` never re-checks the same variable.

    ``variable_depends_on`` is spied on while a small graph is traversed;
    the number of recorded calls must equal the number of distinct first
    arguments those calls received.
    """
    import pytensor.graph.basic

    # wrap the dependency check so that every invocation gets recorded
    spy = mocker.spy(pytensor.graph.basic, "variable_depends_on")

    x = at.dmatrix("x")
    m = x.shape[0][None, None]
    f = x / m
    # ``x / m`` is built a second time here, next to the already existing ``f``
    w = x / m - f

    truncated_graph_inputs([w], [x])

    recorded_calls = spy.call_args_list
    # each recorded call unpacks into (two positional arguments, keyword arguments);
    # collect the first positional argument of every call
    checked_variables = {
        first for ((first, _second), _kwargs) in recorded_calls
    }
    # one call per distinct variable means nothing was inspected twice
    assert len(recorded_calls) == len(checked_variables)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论