提交 21d723bd authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in truncated_graph_inputs

It could return duplicated truncated inputs before the changes, as well as return wrong outputs based on the nodes input order
上级 12ca8bd9
...@@ -1056,6 +1056,7 @@ def truncated_graph_inputs( ...@@ -1056,6 +1056,7 @@ def truncated_graph_inputs(
truncated_inputs.append(node) truncated_inputs.append(node)
# no more actions are needed # no more actions are needed
return truncated_inputs return truncated_inputs
blockers: Set[Variable] = set(ancestors_to_include) blockers: Set[Variable] = set(ancestors_to_include)
# enforce O(1) check for node in ancestors to include # enforce O(1) check for node in ancestors to include
ancestors_to_include = blockers.copy() ancestors_to_include = blockers.copy()
...@@ -1063,40 +1064,40 @@ def truncated_graph_inputs( ...@@ -1063,40 +1064,40 @@ def truncated_graph_inputs(
while candidates: while candidates:
# on any new candidate # on any new candidate
node = candidates.pop() node = candidates.pop()
# check if the node is independent, never go above blockers
# There was a repeated reference to this node, we have already investigated it
if node in truncated_inputs:
continue
# check if the node is independent, never go above blockers;
# blockers are independent nodes and ancestors to include # blockers are independent nodes and ancestors to include
if node in ancestors_to_include: if node in ancestors_to_include:
# The case where node is in ancestors to include so we check if it depends on others # The case where node is in ancestors to include so we check if it depends on others
# it should be removed from the blockers to check against the rest # it should be removed from the blockers to check against the rest
dependent = variable_depends_on(node, blockers - {node}) dependent = variable_depends_on(node, ancestors_to_include - {node})
# ancestors to include that are present in the graph (not disconnected) # ancestors to include that are present in the graph (not disconnected)
# should be added to truncated_inputs # should be added to truncated_inputs
truncated_inputs.append(node) truncated_inputs.append(node)
if dependent: if dependent:
# if the ancestors to include is still dependent we need to go above, # if the ancestors to include is still dependent we need to go above, the search is not yet finished
# the search is not yet finished
# the node _has_ to have owner to be dependent
# so we do not check it
# and populate search to go above
# owner can never be None for a dependent node # owner can never be None for a dependent node
candidates.extend(node.owner.inputs) candidates.extend(node.owner.inputs)
else: else:
# A regular node to check # A regular node to check
dependent = variable_depends_on(node, blockers) dependent = variable_depends_on(node, blockers)
# all regular nodes fall to blockes # all regular nodes fall to blockers
# 1. it is dependent - further search irrelevant # 1. it is dependent - further search irrelevant
# 2. it is independent - the search node is inside the closure # 2. it is independent - the search node is inside the closure
blockers.add(node) blockers.add(node)
# if we've found an independent node and it is not in blockers so far # if we've found an independent node and it is not in blockers so far
# it is a new indepenent node not present in ancestors to include # it is a new independent node not present in ancestors to include
if not dependent: if dependent:
# we've found an independent node # populate search if it's not an independent node
# do not search beyond
truncated_inputs.append(node)
else:
# populate search otherwise
# owner can never be None for a dependent node # owner can never be None for a dependent node
candidates.extend(node.owner.inputs) candidates.extend(node.owner.inputs)
else:
# otherwise, do not search beyond
truncated_inputs.append(node)
return truncated_inputs return truncated_inputs
......
...@@ -697,55 +697,95 @@ def test_variable_depends_on(): ...@@ -697,55 +697,95 @@ def test_variable_depends_on():
assert variable_depends_on(y, [y]) assert variable_depends_on(y, [y])
def test_truncated_graph_inputs(): class TestTruncatedGraphInputs:
""" def test_basic(self):
* No conditions """
n - n - (o) * No conditions
n - n - (o)
* One condition
n - (c) - o * One condition
n - (c) - o
* Two conditions where on depends on another, both returned
(c) - (c) - o * Two conditions where on depends on another, both returned
(c) - (c) - o
* Additional nodes are present
(c) - n - o * Additional nodes are present
n - (n) -' (c) - n - o
n - (n) -'
* Disconnected condition not returned
(c) - n - o
c
* Disconnected output is present and returned
(c) - (c) - o
(o)
* Condition on itself adds itself
n - (c) - (o/c)
"""
x = MyVariable(1)
x.name = "x"
y = MyVariable(1)
y.name = "y"
z = MyVariable(1)
z.name = "z"
x2 = MyOp(x)
x2.name = "x2"
y2 = MyOp(y, x2)
y2.name = "y2"
o = MyOp(y2)
o2 = MyOp(o)
# No conditions
assert truncated_graph_inputs([o]) == [o]
# One condition
assert truncated_graph_inputs([o2], [y2]) == [y2]
# Condition on itself adds itself
assert truncated_graph_inputs([o], [y2, o]) == [o, y2]
# Two conditions where on depends on another, both returned
assert truncated_graph_inputs([o2], [y2, o]) == [o, y2]
# Additional nodes are present
assert truncated_graph_inputs([o], [y]) == [x2, y]
# Disconnected condition
assert truncated_graph_inputs([o2], [y2, z]) == [y2]
# Disconnected output is present
assert truncated_graph_inputs([o2, z], [y2]) == [z, y2]
def test_repeated_input(self):
"""Test that truncated_graph_inputs does not return repeated inputs."""
x = MyVariable(1)
x.name = "x"
y = MyVariable(1)
y.name = "y"
trunc_inp1 = MyOp(x, y)
trunc_inp1.name = "trunc_inp1"
trunc_inp2 = MyOp(x, y)
trunc_inp2.name = "trunc_inp2"
o = MyOp(trunc_inp1, trunc_inp1, trunc_inp2, trunc_inp2)
o.name = "o"
assert truncated_graph_inputs([o], [trunc_inp1]) == [trunc_inp2, trunc_inp1]
def test_repeated_nested_input(self):
"""Test that truncated_graph_inputs does not return repeated inputs."""
x = MyVariable(1)
x.name = "x"
y = MyVariable(1)
y.name = "y"
trunc_inp = MyOp(x, y)
trunc_inp.name = "trunc_inp"
o1 = MyOp(trunc_inp, trunc_inp, x, x)
o1.name = "o1"
* Disconnected condition not returned assert truncated_graph_inputs([o1], [trunc_inp]) == [x, trunc_inp]
(c) - n - o
c
* Disconnected output is present and returned # Reverse order of inputs
(c) - (c) - o o2 = MyOp(x, x, trunc_inp, trunc_inp)
(o) o2.name = "o2"
* Condition on itself adds itself assert truncated_graph_inputs([o2], [trunc_inp]) == [trunc_inp, x]
n - (c) - (o/c)
"""
x = MyVariable(1)
x.name = "x"
y = MyVariable(1)
y.name = "y"
z = MyVariable(1)
z.name = "z"
x2 = MyOp(x)
x2.name = "x2"
y2 = MyOp(y, x2)
y2.name = "y2"
o = MyOp(y2)
o2 = MyOp(o)
# No conditions
assert truncated_graph_inputs([o]) == [o]
# One condition
assert truncated_graph_inputs([o2], [y2]) == [y2]
# Condition on itself adds itself
assert truncated_graph_inputs([o], [y2, o]) == [o, y2]
# Two conditions where on depends on another, both returned
assert truncated_graph_inputs([o2], [y2, o]) == [o, y2]
# Additional nodes are present
assert truncated_graph_inputs([o], [y]) == [x2, y]
# Disconnected condition
assert truncated_graph_inputs([o2], [y2, z]) == [y2]
# Disconnected output is present
assert truncated_graph_inputs([o2, z], [y2]) == [z, y2]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论