提交 1b673567 authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: Ricardo Vieira

add truncated_graph_inputs function

上级 5fdc130f
......@@ -970,6 +970,136 @@ def applys_between(
)
def truncated_graph_inputs(
    outputs: Sequence[Variable],
    ancestors_to_include: Optional[Collection[Variable]] = None,
) -> List[Variable]:
    """Get the truncated graph inputs.

    Unlike :func:`graph_inputs` this function will return
    the closest nodes to outputs that do not depend on
    ``ancestors_to_include``. So given all the returned
    variables provided there is no missing node to
    compute the output.

    Parameters
    ----------
    outputs : Sequence[Variable]
        Variables to get the truncated inputs for
    ancestors_to_include : Optional[Collection[Variable]]
        Additional ancestors to assume, by default None

    Returns
    -------
    List[Variable]
        Variables required to compute ``outputs``. Every variable is
        listed at most once, even when it is reachable through several
        paths of the graph.

    Examples
    --------
    The returned nodes marked in (parenthesis), ancestors nodes are ``c``, output nodes are ``o``

    * No ancestors to include

    .. code-block::

        n - n - (o)

    * One ancestors to include

    .. code-block::

        n - (c) - o

    * Two ancestors to include where one depends on another, both returned

    .. code-block::

        (c) - (c) - o

    * Additional nodes are present

    .. code-block::

        (c) - n - o
         n - (n) -'

    * Disconnected ancestors to include not returned

    .. code-block::

        (c) - n - o
         c

    * Disconnected output is present and returned

    .. code-block::

        (c) - (c) - o
        (o)

    * ancestors to include that include itself adds itself

    .. code-block::

        n - (c) - (o/c)

    """
    truncated_inputs: List[Variable] = []
    candidates = list(outputs)

    # simple case, no additional ancestors to include
    if not ancestors_to_include:  # None or empty
        # just filter out duplicates, keeping the first occurrence
        for node in candidates:
            if node not in truncated_inputs:
                truncated_inputs.append(node)
        return truncated_inputs

    # blockers hold the ancestors to include plus every regular node
    # that has been examined so far; the search never goes above them
    blockers: Set[Variable] = set(ancestors_to_include)
    # nodes that were already examined; a node reachable through several
    # paths (e.g. a diamond-shaped graph) must be handled only once,
    # otherwise it would be returned twice or, for a root variable,
    # wrongly treated as dependent on itself
    seen: Set[Variable] = set()
    # enforce O(1) check for node in ancestors to include
    ancestors_to_include = blockers.copy()

    while candidates:
        node = candidates.pop()
        if node in seen:
            continue
        seen.add(node)

        if node in ancestors_to_include:
            # An ancestor to include that is present in the graph
            # (not disconnected) is always returned.
            truncated_inputs.append(node)
            # It only needs to be expanded when it depends on *another*
            # ancestor to include; regular nodes found earlier are
            # irrelevant here, so they are not part of the check.
            dependent = variable_depends_on(node, ancestors_to_include - {node})
        else:
            # A regular node: check it against everything known so far
            dependent = variable_depends_on(node, blockers)
            # all regular nodes fall to blockers
            # 1. it is dependent - the search continues through its inputs
            # 2. it is independent - it closes this branch of the search
            blockers.add(node)
            if not dependent:
                # a new independent node, do not search beyond it
                truncated_inputs.append(node)

        if dependent:
            # owner can never be None for a dependent node that is not
            # itself one of the variables it was checked against
            candidates.extend(inp for inp in node.owner.inputs if inp not in seen)

    return truncated_inputs
def clone(
inputs: List[Variable],
outputs: List[Variable],
......
......@@ -23,6 +23,7 @@ from pytensor.graph.basic import (
io_toposort,
list_of_nodes,
orphans_between,
truncated_graph_inputs,
variable_depends_on,
vars_between,
walk,
......@@ -695,3 +696,56 @@ def test_variable_depends_on():
assert not variable_depends_on(y, [y2])
assert variable_depends_on(y, [y])
def test_truncated_graph_inputs():
    """Exercise ``truncated_graph_inputs`` on a small hand-built graph.

    In the sketches below the returned nodes are in (parentheses), ``c`` is
    a condition (an ancestor to include), ``o`` is an output and ``n`` is
    any other node.

    * No conditions
        n - n - (o)

    * One condition
        n - (c) - o

    * Two conditions where one depends on another, both returned
        (c) - (c) - o

    * Additional nodes are present
        (c) - n - o
         n - (n) -'

    * Disconnected condition not returned
        (c) - n - o
         c

    * Disconnected output is present and returned
        (c) - (c) - o
        (o)

    * Condition on itself adds itself
        n - (c) - (o/c)
    """
    # Three root variables; ``z`` is never wired into the graph below.
    x = MyVariable(1)
    y = MyVariable(1)
    z = MyVariable(1)
    for root, label in ((x, "x"), (y, "y"), (z, "z")):
        root.name = label

    # Graph: x -> x2, (y, x2) -> y2 -> o -> o2
    x2 = MyOp(x)
    x2.name = "x2"
    y2 = MyOp(y, x2)
    y2.name = "y2"
    o = MyOp(y2)
    o2 = MyOp(o)

    # Each entry: (positional arguments, expected result)
    cases = [
        # No conditions
        (([o],), [o]),
        # One condition
        (([o2], [y2]), [y2]),
        # Condition on itself adds itself
        (([o], [y2, o]), [o, y2]),
        # Two conditions where one depends on another, both returned
        (([o2], [y2, o]), [o, y2]),
        # Additional nodes are present
        (([o], [y]), [x2, y]),
        # Disconnected condition
        (([o2], [y2, z]), [y2]),
        # Disconnected output is present
        (([o2, z], [y2]), [z, y2]),
    ]
    for call_args, expected in cases:
        assert truncated_graph_inputs(*call_args) == expected
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论