提交 880a57aa authored 作者: Dhruvanshu-Joshi's avatar Dhruvanshu-Joshi 提交者: Ricardo Vieira

remove list_of_nodes in favor of similar applys_between

上级 8c157a25
......@@ -1789,38 +1789,6 @@ def view_roots(node: Variable) -> list[Variable]:
return [node]
def list_of_nodes(
inputs: Collection[Variable], outputs: Iterable[Variable]
) -> list[Apply]:
r"""Return the `Apply` nodes of the graph between `inputs` and `outputs`.
Parameters
----------
inputs : list of Variable
Input `Variable`\s.
outputs : list of Variable
Output `Variable`\s.
"""
def expand(o: Apply) -> list[Apply]:
return [
inp.owner
for inp in o.inputs
if inp.owner and not any(i in inp.owner.outputs for i in inputs)
]
return list(
cast(
Iterable[Apply],
walk(
[o.owner for o in outputs if o.owner],
expand,
),
)
)
def apply_depends_on(apply: Apply, depends_on: Apply | Collection[Apply]) -> bool:
"""Determine if any `depends_on` is in the graph given by ``apply``.
......
......@@ -24,7 +24,7 @@ import pytensor
from pytensor import printing
from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.graph.basic import Apply, Constant, Variable, clone, list_of_nodes
from pytensor.graph.basic import Apply, Constant, Variable, applys_between, clone
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import HasInnerGraph
from pytensor.graph.rewriting.basic import MergeOptimizer
......@@ -4125,7 +4125,7 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
def prepare_node(self, node, storage_map, compute_map, impl):
if impl not in self.prepare_node_called:
for n in list_of_nodes(self.inputs, self.outputs):
for n in applys_between(self.inputs, self.outputs):
n.op.prepare_node(n, None, None, impl)
self.prepare_node_called.add(impl)
......
......@@ -23,7 +23,6 @@ from pytensor.graph.basic import (
get_var_by_name,
graph_inputs,
io_toposort,
list_of_nodes,
orphans_between,
truncated_graph_inputs,
variable_depends_on,
......@@ -567,17 +566,6 @@ def test_ops():
assert res_list == [o3.owner, o2.owner, o1.owner]
def test_list_of_nodes():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2)
o1.name = "o1"
o2 = MyOp(r3, o1)
o2.name = "o2"
res = list_of_nodes([r1, r2], [o2])
assert res == [o2.owner, o1.owner]
def test_apply_depends_on():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论