提交 c6dae89f authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Move view_roots to the only file where it is used

上级 d8beb7e4
...@@ -1787,25 +1787,6 @@ def as_string( ...@@ -1787,25 +1787,6 @@ def as_string(
return [describe(output) for output in outputs] return [describe(output) for output in outputs]
def view_roots(node: Variable) -> list[Variable]:
"""Return the leaves from a search through consecutive view-maps."""
owner = node.owner
if owner is not None:
try:
vars_to_views = {owner.outputs[o]: i for o, i in owner.op.view_map.items()}
except AttributeError:
return [node]
if node in vars_to_views:
answer = []
for i in vars_to_views[node]:
answer += view_roots(owner.inputs[i])
return answer
else:
return [node]
else:
return [node]
def apply_depends_on(apply: Apply, depends_on: Apply | Collection[Apply]) -> bool: def apply_depends_on(apply: Apply, depends_on: Apply | Collection[Apply]) -> bool:
"""Determine if any `depends_on` is in the graph given by ``apply``. """Determine if any `depends_on` is in the graph given by ``apply``.
......
...@@ -85,7 +85,7 @@ from pathlib import Path ...@@ -85,7 +85,7 @@ from pathlib import Path
import numpy as np import numpy as np
from scipy.linalg import get_blas_funcs from scipy.linalg import get_blas_funcs
from pytensor.graph import vectorize_graph from pytensor.graph import Variable, vectorize_graph
from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.npy_2_compat import normalize_axis_tuple
...@@ -97,7 +97,7 @@ except ImportError: ...@@ -97,7 +97,7 @@ except ImportError:
import pytensor.scalar import pytensor.scalar
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, view_roots from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
...@@ -114,6 +114,25 @@ from pytensor.tensor.type import DenseTensorType, tensor ...@@ -114,6 +114,25 @@ from pytensor.tensor.type import DenseTensorType, tensor
_logger = logging.getLogger("pytensor.tensor.blas") _logger = logging.getLogger("pytensor.tensor.blas")
def view_roots(node: Variable) -> list[Variable]:
"""Return the leaves from a search through consecutive view-maps."""
owner = node.owner
if owner is not None:
try:
vars_to_views = {owner.outputs[o]: i for o, i in owner.op.view_map.items()}
except AttributeError:
return [node]
if node in vars_to_views:
answer = []
for i in vars_to_views[node]:
answer += view_roots(owner.inputs[i])
return answer
else:
return [node]
else:
return [node]
def must_initialize_y_gemv(): def must_initialize_y_gemv():
# Check whether Scipy GEMV could output nan if y in not initialized # Check whether Scipy GEMV could output nan if y in not initialized
from scipy.linalg.blas import get_blas_funcs from scipy.linalg.blas import get_blas_funcs
......
...@@ -589,11 +589,6 @@ def test_io_connection_pattern(): ...@@ -589,11 +589,6 @@ def test_io_connection_pattern():
raise AssertionError() raise AssertionError()
@pytest.mark.xfail(reason="Not implemented")
def test_view_roots():
raise AssertionError()
def test_get_var_by_name(): def test_get_var_by_name():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2) o1 = MyOp(r1, r2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论