Unverified 提交 044910bf authored 作者: Dhruvanshu-Joshi's avatar Dhruvanshu-Joshi 提交者: GitHub

Add helper `explicit_graph_inputs` (#712)

上级 27bd9aaf
...@@ -936,6 +936,55 @@ def graph_inputs( ...@@ -936,6 +936,55 @@ def graph_inputs(
yield from (r for r in ancestors(graphs, blockers) if r.owner is None) yield from (r for r in ancestors(graphs, blockers) if r.owner is None)
def explicit_graph_inputs(
graph: Variable | Iterable[Variable],
) -> Generator[Variable, None, None]:
"""
Get the root variables needed as inputs to a function that computes `graph`
Parameters
----------
graph : TensorVariable
Output `Variable` instances for which to search backward through
owners.
Returns
-------
iterable
Generator of root Variables (without owner) needed to compile a function that evaluates `graphs`.
Examples
--------
.. code-block:: python
import pytensor
import pytensor.tensor as pt
from pytensor.graph.basic import explicit_graph_inputs
x = pt.vector('x')
y = pt.constant(2)
z = pt.mul(x*y)
inputs = list(explicit_graph_inputs(z))
f = pytensor.function(inputs, z)
eval = f([1, 2, 3])
print(eval)
# [2. 4. 6.]
"""
from pytensor.compile.sharedvalue import SharedVariable
if isinstance(graph, Variable):
graph = [graph]
return (
v
for v in graph_inputs(graph)
if isinstance(v, Variable) and not isinstance(v, Constant | SharedVariable)
)
def vars_between( def vars_between(
ins: Collection[Variable], outs: Iterable[Variable] ins: Collection[Variable], outs: Iterable[Variable]
) -> Generator[Variable, None, None]: ) -> Generator[Variable, None, None]:
......
...@@ -18,6 +18,7 @@ from pytensor.graph.basic import ( ...@@ -18,6 +18,7 @@ from pytensor.graph.basic import (
clone, clone,
clone_get_equiv, clone_get_equiv,
equal_computations, equal_computations,
explicit_graph_inputs,
general_toposort, general_toposort,
get_var_by_name, get_var_by_name,
graph_inputs, graph_inputs,
...@@ -522,6 +523,20 @@ def test_graph_inputs(): ...@@ -522,6 +523,20 @@ def test_graph_inputs():
assert res_list == [r3, r1, r2] assert res_list == [r3, r1, r2]
def test_explicit_graph_inputs():
x = pt.fscalar()
y = pt.constant(2)
z = shared(1)
a = pt.sum(x + y + z)
b = pt.true_div(x, y)
res = list(explicit_graph_inputs([a]))
res1 = list(explicit_graph_inputs(b))
assert res == [x]
assert res1 == [x]
def test_variables_and_orphans(): def test_variables_and_orphans():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2) o1 = MyOp(r1, r2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论