提交 af03b72f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename theano.gof.stack_search to walk

上级 3e5047ea
......@@ -19,8 +19,8 @@ from theano.gof.graph import (
list_of_nodes,
ops,
orphans,
stack_search,
variables,
walk,
)
from theano.gof.op import Op
from theano.gof.type import Type
......@@ -331,7 +331,7 @@ def test_equal_computations():
assert equal_computations(max_argmax1, max_argmax2)
def test_stack_search():
def test_walk():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2)
......@@ -343,15 +343,15 @@ def test_stack_search():
if r.owner:
return r.owner.inputs
res = stack_search([o2], expand, bfs=True, return_children=False)
res = walk([o2], expand, bfs=True, return_children=False)
res_list = list(res)
assert res_list == [o2, r3, o1, r1, r2]
res = stack_search([o2], expand, bfs=False, return_children=False)
res = walk([o2], expand, bfs=False, return_children=False)
res_list = list(res)
assert res_list == [o2, o1, r2, r1, r3]
res = stack_search([o2], expand, bfs=True, return_children=True)
res = walk([o2], expand, bfs=True, return_children=True)
res_list = list(res)
assert res_list == [
(o2, [r3, o1]),
......
......@@ -661,7 +661,7 @@ class Constant(Variable):
# index is not defined, because the `owner` attribute must necessarily be None
def stack_search(
def walk(
nodes: Iterable[T],
expand: Callable[[T], Optional[Sequence[T]]],
bfs: bool = True,
......@@ -754,7 +754,7 @@ def ancestors(
if r.owner and (not blockers or r not in blockers):
return reversed(r.owner.inputs)
yield from stack_search(graphs, expand, False)
yield from walk(graphs, expand, False)
def inputs(
......@@ -807,7 +807,7 @@ def variables(
if r.owner and r not in ins:
return reversed(r.owner.inputs + r.owner.outputs)
yield from stack_search(outs, expand)
yield from walk(outs, expand)
def orphans(
......@@ -1027,7 +1027,7 @@ def general_toposort(
raise ValueError("deps_cache cannot be None")
search_res: List[T, Optional[List[T]]] = list(
stack_search(outputs, compute_deps_cache, bfs=False, return_children=True)
walk(outputs, compute_deps_cache, bfs=False, return_children=True)
)
_clients: Dict[T, List[T]] = {}
......@@ -1364,7 +1364,7 @@ def list_of_nodes(
"""
return list(
stack_search(
walk(
[o.owner for o in outputs],
lambda o: [
inp.owner
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论