提交 af03b72f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename theano.gof.stack_search to walk

上级 3e5047ea
...@@ -19,8 +19,8 @@ from theano.gof.graph import ( ...@@ -19,8 +19,8 @@ from theano.gof.graph import (
list_of_nodes, list_of_nodes,
ops, ops,
orphans, orphans,
stack_search,
variables, variables,
walk,
) )
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.type import Type from theano.gof.type import Type
...@@ -331,7 +331,7 @@ def test_equal_computations(): ...@@ -331,7 +331,7 @@ def test_equal_computations():
assert equal_computations(max_argmax1, max_argmax2) assert equal_computations(max_argmax1, max_argmax2)
def test_stack_search(): def test_walk():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2) o1 = MyOp(r1, r2)
...@@ -343,15 +343,15 @@ def test_stack_search(): ...@@ -343,15 +343,15 @@ def test_stack_search():
if r.owner: if r.owner:
return r.owner.inputs return r.owner.inputs
res = stack_search([o2], expand, bfs=True, return_children=False) res = walk([o2], expand, bfs=True, return_children=False)
res_list = list(res) res_list = list(res)
assert res_list == [o2, r3, o1, r1, r2] assert res_list == [o2, r3, o1, r1, r2]
res = stack_search([o2], expand, bfs=False, return_children=False) res = walk([o2], expand, bfs=False, return_children=False)
res_list = list(res) res_list = list(res)
assert res_list == [o2, o1, r2, r1, r3] assert res_list == [o2, o1, r2, r1, r3]
res = stack_search([o2], expand, bfs=True, return_children=True) res = walk([o2], expand, bfs=True, return_children=True)
res_list = list(res) res_list = list(res)
assert res_list == [ assert res_list == [
(o2, [r3, o1]), (o2, [r3, o1]),
......
...@@ -661,7 +661,7 @@ class Constant(Variable): ...@@ -661,7 +661,7 @@ class Constant(Variable):
# index is not defined, because the `owner` attribute must necessarily be None # index is not defined, because the `owner` attribute must necessarily be None
def stack_search( def walk(
nodes: Iterable[T], nodes: Iterable[T],
expand: Callable[[T], Optional[Sequence[T]]], expand: Callable[[T], Optional[Sequence[T]]],
bfs: bool = True, bfs: bool = True,
...@@ -754,7 +754,7 @@ def ancestors( ...@@ -754,7 +754,7 @@ def ancestors(
if r.owner and (not blockers or r not in blockers): if r.owner and (not blockers or r not in blockers):
return reversed(r.owner.inputs) return reversed(r.owner.inputs)
yield from stack_search(graphs, expand, False) yield from walk(graphs, expand, False)
def inputs( def inputs(
...@@ -807,7 +807,7 @@ def variables( ...@@ -807,7 +807,7 @@ def variables(
if r.owner and r not in ins: if r.owner and r not in ins:
return reversed(r.owner.inputs + r.owner.outputs) return reversed(r.owner.inputs + r.owner.outputs)
yield from stack_search(outs, expand) yield from walk(outs, expand)
def orphans( def orphans(
...@@ -1027,7 +1027,7 @@ def general_toposort( ...@@ -1027,7 +1027,7 @@ def general_toposort(
raise ValueError("deps_cache cannot be None") raise ValueError("deps_cache cannot be None")
search_res: List[T, Optional[List[T]]] = list( search_res: List[T, Optional[List[T]]] = list(
stack_search(outputs, compute_deps_cache, bfs=False, return_children=True) walk(outputs, compute_deps_cache, bfs=False, return_children=True)
) )
_clients: Dict[T, List[T]] = {} _clients: Dict[T, List[T]] = {}
...@@ -1364,7 +1364,7 @@ def list_of_nodes( ...@@ -1364,7 +1364,7 @@ def list_of_nodes(
""" """
return list( return list(
stack_search( walk(
[o.owner for o in outputs], [o.owner for o in outputs],
lambda o: [ lambda o: [
inp.owner inp.owner
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论