提交 996d737e authored 作者: Frederic Bastien's avatar Frederic Bastien

speed up scan forced_replace by not recursing and making less objects.

上级 dde50054
...@@ -1329,6 +1329,9 @@ def forced_replace(out, x, y): ...@@ -1329,6 +1329,9 @@ def forced_replace(out, x, y):
x := sigmoid(wu) x := sigmoid(wu)
forced_replace(out, x, y) := y*(1-y) forced_replace(out, x, y) := y*(1-y)
Note
----
When it find a match, it don't continue on the corresponding inputs.
""" """
if out is None: if out is None:
return None return None
...@@ -1336,18 +1339,18 @@ def forced_replace(out, x, y): ...@@ -1336,18 +1339,18 @@ def forced_replace(out, x, y):
# ``visited`` is a set of nodes that are already known and don't need to be # ``visited`` is a set of nodes that are already known and don't need to be
# checked again, speeding up the traversal of multiply-connected graphs. # checked again, speeding up the traversal of multiply-connected graphs.
visited = set() visited = set()
def local_traverse(graph, x): from collections import deque
q = deque()
q.append(out)
to_replace = []
while q:
graph = q.popleft()
if graph in visited: if graph in visited:
return [] continue
visited.add(graph) visited.add(graph)
if equal_computations([graph], [x]): if equal_computations([graph], [x]):
return [graph] to_replace.append((graph, y))
elif not graph.owner: elif graph.owner:
return [] q.extendleft(graph.owner.inputs)
else:
rval = [] return clone(out, replace=to_replace)
for inp in graph.owner.inputs:
rval += local_traverse(inp, x)
return rval
to_replace = local_traverse(out, x)
return clone(out, replace=OrderedDict((v, y) for v in to_replace))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论