提交 8ad33179 authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: Ricardo Vieira

refactor is_in_ancestors to support multiple inputs

上级 38731adb
...@@ -1568,15 +1568,15 @@ def list_of_nodes( ...@@ -1568,15 +1568,15 @@ def list_of_nodes(
) )
def is_in_ancestors(l_apply: Apply, f_apply: Apply) -> bool: def apply_depends_on(apply: Apply, depends_on: Union[Apply, Collection[Apply]]) -> bool:
"""Determine if `f_apply` is in the graph given by `l_apply`. """Determine if any `depends_on` is in the graph given by ``apply``.
Parameters Parameters
---------- ----------
l_apply : Apply apply : Apply
The node to walk. The Apply node to check.
f_apply : Apply depends_on : Union[Apply, Collection[Apply]]
The node to find in `l_apply`. Apply nodes to check dependency on
Returns Returns
------- -------
...@@ -1584,14 +1584,18 @@ def is_in_ancestors(l_apply: Apply, f_apply: Apply) -> bool: ...@@ -1584,14 +1584,18 @@ def is_in_ancestors(l_apply: Apply, f_apply: Apply) -> bool:
""" """
computed = set() computed = set()
todo = [l_apply] todo = [apply]
if not isinstance(depends_on, Collection):
depends_on = {depends_on}
else:
depends_on = set(depends_on)
while todo: while todo:
cur = todo.pop() cur = todo.pop()
if cur.outputs[0] in computed: if cur.outputs[0] in computed:
continue continue
if all(i in computed or i.owner is None for i in cur.inputs): if all(i in computed or i.owner is None for i in cur.inputs):
computed.update(cur.outputs) computed.update(cur.outputs)
if cur is f_apply: if cur in depends_on:
return True return True
else: else:
todo.append(cur) todo.append(cur)
......
...@@ -20,7 +20,7 @@ import pytensor.tensor as at ...@@ -20,7 +20,7 @@ import pytensor.tensor as at
from pytensor import as_symbolic from pytensor import as_symbolic
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable, is_in_ancestors from pytensor.graph.basic import Apply, Variable, apply_depends_on
from pytensor.graph.op import _NoPythonOp from pytensor.graph.op import _NoPythonOp
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
...@@ -604,7 +604,7 @@ class CondMerge(GraphRewriter): ...@@ -604,7 +604,7 @@ class CondMerge(GraphRewriter):
return False return False
merging_node = cond_nodes[0] merging_node = cond_nodes[0]
for proposal in cond_nodes[1:]: for proposal in cond_nodes[1:]:
if proposal.inputs[0] == merging_node.inputs[0] and not is_in_ancestors( if proposal.inputs[0] == merging_node.inputs[0] and not apply_depends_on(
proposal, merging_node proposal, merging_node
): ):
# Create a list of replacements for proposal # Create a list of replacements for proposal
...@@ -704,8 +704,8 @@ def cond_merge_random_op(fgraph, main_node): ...@@ -704,8 +704,8 @@ def cond_merge_random_op(fgraph, main_node):
for proposal in cond_nodes[1:]: for proposal in cond_nodes[1:]:
if ( if (
proposal.inputs[0] == merging_node.inputs[0] proposal.inputs[0] == merging_node.inputs[0]
and not is_in_ancestors(proposal, merging_node) and not apply_depends_on(proposal, merging_node)
and not is_in_ancestors(merging_node, proposal) and not apply_depends_on(merging_node, proposal)
): ):
# Create a list of replacements for proposal # Create a list of replacements for proposal
mn_ts = merging_node.inputs[1:][: merging_node.op.n_outs] mn_ts = merging_node.inputs[1:][: merging_node.op.n_outs]
......
...@@ -18,10 +18,10 @@ from pytensor.graph.basic import ( ...@@ -18,10 +18,10 @@ from pytensor.graph.basic import (
Apply, Apply,
Constant, Constant,
Variable, Variable,
apply_depends_on,
equal_computations, equal_computations,
graph_inputs, graph_inputs,
io_toposort, io_toposort,
is_in_ancestors,
) )
from pytensor.graph.destroyhandler import DestroyHandler from pytensor.graph.destroyhandler import DestroyHandler
from pytensor.graph.features import ReplaceValidate from pytensor.graph.features import ReplaceValidate
...@@ -1642,7 +1642,7 @@ def save_mem_new_scan(fgraph, node): ...@@ -1642,7 +1642,7 @@ def save_mem_new_scan(fgraph, node):
old_new += [(o, new_outs[nw_pos])] old_new += [(o, new_outs[nw_pos])]
# Check if the new outputs depend on the old scan node # Check if the new outputs depend on the old scan node
old_scan_is_used = [ old_scan_is_used = [
is_in_ancestors(new.owner, node) for old, new in old_new apply_depends_on(new.owner, node) for old, new in old_new
] ]
if any(old_scan_is_used): if any(old_scan_is_used):
return False return False
...@@ -1877,7 +1877,7 @@ class ScanMerge(GraphRewriter): ...@@ -1877,7 +1877,7 @@ class ScanMerge(GraphRewriter):
# Check to see if it is an input of a different node # Check to see if it is an input of a different node
for nd in set_nodes: for nd in set_nodes:
if is_in_ancestors(node, nd) or is_in_ancestors(nd, node): if apply_depends_on(node, nd) or apply_depends_on(nd, node):
return False return False
if not node.op.info.as_while: if not node.op.info.as_while:
......
...@@ -11,6 +11,7 @@ from pytensor.graph.basic import ( ...@@ -11,6 +11,7 @@ from pytensor.graph.basic import (
NominalVariable, NominalVariable,
Variable, Variable,
ancestors, ancestors,
apply_depends_on,
applys_between, applys_between,
as_string, as_string,
clone, clone,
...@@ -20,7 +21,6 @@ from pytensor.graph.basic import ( ...@@ -20,7 +21,6 @@ from pytensor.graph.basic import (
get_var_by_name, get_var_by_name,
graph_inputs, graph_inputs,
io_toposort, io_toposort,
is_in_ancestors,
list_of_nodes, list_of_nodes,
orphans_between, orphans_between,
vars_between, vars_between,
...@@ -491,15 +491,19 @@ def test_list_of_nodes(): ...@@ -491,15 +491,19 @@ def test_list_of_nodes():
assert res == [o2.owner, o1.owner] assert res == [o2.owner, o1.owner]
def test_is_in_ancestors(): def test_apply_depends_on():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2) o1 = MyOp(r1, r2)
o1.name = "o1" o1.name = "o1"
o2 = MyOp(r3, o1) o2 = MyOp(r1, o1)
o2.name = "o2" o2.name = "o2"
o3 = MyOp(r3, o1, o2)
o3.name = "o3"
assert is_in_ancestors(o2.owner, o1.owner) assert apply_depends_on(o2.owner, o1.owner)
assert apply_depends_on(o2.owner, o2.owner)
assert apply_depends_on(o3.owner, [o1.owner, o2.owner])
@pytest.mark.xfail(reason="Not implemented") @pytest.mark.xfail(reason="Not implemented")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论