提交 996d737e authored 作者: Frederic Bastien's avatar Frederic Bastien

speed up scan forced_replace by not recursing and making less objects.

上级 dde50054
......@@ -1329,6 +1329,9 @@ def forced_replace(out, x, y):
x := sigmoid(wu)
forced_replace(out, x, y) := y*(1-y)
Note
----
When it find a match, it don't continue on the corresponding inputs.
"""
if out is None:
return None
......@@ -1336,18 +1339,18 @@ def forced_replace(out, x, y):
# ``visited`` is a set of nodes that are already known and don't need to be
# checked again, speeding up the traversal of multiply-connected graphs.
visited = set()
def local_traverse(graph, x):
from collections import deque
q = deque()
q.append(out)
to_replace = []
while q:
graph = q.popleft()
if graph in visited:
return []
continue
visited.add(graph)
if equal_computations([graph], [x]):
return [graph]
elif not graph.owner:
return []
else:
rval = []
for inp in graph.owner.inputs:
rval += local_traverse(inp, x)
return rval
to_replace = local_traverse(out, x)
return clone(out, replace=OrderedDict((v, y) for v in to_replace))
to_replace.append((graph, y))
elif graph.owner:
q.extendleft(graph.owner.inputs)
return clone(out, replace=to_replace)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论