提交 c6dae89f authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Move view_roots to the only file where it is used

上级 d8beb7e4
......@@ -1787,25 +1787,6 @@ def as_string(
return [describe(output) for output in outputs]
def view_roots(node: Variable) -> list[Variable]:
"""Return the leaves from a search through consecutive view-maps."""
owner = node.owner
if owner is not None:
try:
vars_to_views = {owner.outputs[o]: i for o, i in owner.op.view_map.items()}
except AttributeError:
return [node]
if node in vars_to_views:
answer = []
for i in vars_to_views[node]:
answer += view_roots(owner.inputs[i])
return answer
else:
return [node]
else:
return [node]
def apply_depends_on(apply: Apply, depends_on: Apply | Collection[Apply]) -> bool:
"""Determine if any `depends_on` is in the graph given by ``apply``.
......
......@@ -85,7 +85,7 @@ from pathlib import Path
import numpy as np
from scipy.linalg import get_blas_funcs
from pytensor.graph import vectorize_graph
from pytensor.graph import Variable, vectorize_graph
from pytensor.npy_2_compat import normalize_axis_tuple
......@@ -97,7 +97,7 @@ except ImportError:
import pytensor.scalar
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, view_roots
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
from pytensor.link.c.op import COp
......@@ -114,6 +114,25 @@ from pytensor.tensor.type import DenseTensorType, tensor
_logger = logging.getLogger("pytensor.tensor.blas")
def view_roots(node: Variable) -> list[Variable]:
"""Return the leaves from a search through consecutive view-maps."""
owner = node.owner
if owner is not None:
try:
vars_to_views = {owner.outputs[o]: i for o, i in owner.op.view_map.items()}
except AttributeError:
return [node]
if node in vars_to_views:
answer = []
for i in vars_to_views[node]:
answer += view_roots(owner.inputs[i])
return answer
else:
return [node]
else:
return [node]
def must_initialize_y_gemv():
# Check whether Scipy GEMV could output nan if y in not initialized
from scipy.linalg.blas import get_blas_funcs
......
......@@ -589,11 +589,6 @@ def test_io_connection_pattern():
raise AssertionError()
@pytest.mark.xfail(reason="Not implemented")
def test_view_roots():
raise AssertionError()
def test_get_var_by_name():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论