提交 76a8e1c0 authored 作者: lamblin's avatar lamblin

Merge pull request #1325 from f0k/fix-scanutils-traverse

scan_utils.py: Make traversal routines remember visited nodes
......@@ -105,7 +105,7 @@ class until(object):
assert self.condition.ndim == 0
def traverse(out, x, x_copy, d):
def traverse(out, x, x_copy, d, visited=None):
''' Function used by scan to parse the tree and figure out which nodes
it needs to replace. There are two options :
1) x and x_copy are on host, then you would replace x with x_copy
......@@ -114,6 +114,15 @@ def traverse(out, x, x_copy, d):
This happens because initially shared variables are on GPU .. which is
fine for the main computational graph but confuses things a bit for the
inner graph of scan '''
# ``visited`` is a set of nodes that are already known and don't need to be
# checked again, speeding up the traversal of multiply-connected graphs.
# If a ``visited`` set is given, it will be updated in-place so the caller
# knows which nodes we have seen.
if visited is None:
visited = set()
if out in visited:
return d
visited.add(out)
import theano.sandbox.cuda as cuda
if out == x:
d[out] = cuda.gpu_from_host(x_copy)
......@@ -127,7 +136,7 @@ def traverse(out, x, x_copy, d):
return d
else:
for inp in out.owner.inputs:
d = traverse(inp, x, x_copy, d)
d = traverse(inp, x, x_copy, d, visited)
return d
......@@ -988,7 +997,13 @@ def forced_replace(out, x, y):
if out is None:
return None
def traverse(graph, x):
# ``visited`` is a set of nodes that are already known and don't need to be
# checked again, speeding up the traversal of multiply-connected graphs.
visited = set()
def local_traverse(graph, x):
if graph in visited:
return []
visited.add(graph)
if equal_computations([graph], [x]):
return [graph]
elif not graph.owner:
......@@ -996,7 +1011,7 @@ def forced_replace(out, x, y):
else:
rval = []
for inp in graph.owner.inputs:
rval += traverse(inp, x)
rval += local_traverse(inp, x)
return rval
to_replace = traverse(out, x)
to_replace = local_traverse(out, x)
return clone(out, replace=OrderedDict((v, y) for v in to_replace))
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论