提交 9f1aaacb authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5522 from nouiz/scan_opt

Speed up one scan opt
...@@ -1818,7 +1818,12 @@ class ScanMerge(gof.Optimizer): ...@@ -1818,7 +1818,12 @@ class ScanMerge(gof.Optimizer):
""" """
rep = set_nodes[0] rep = set_nodes[0]
if rep.op.as_while != node.op.as_while: if (rep.op.as_while != node.op.as_while or
len(rep.inputs) != len(node.inputs) or
len(rep.outputs) != len(node.outputs) or
node.op.truncate_gradient != rep.op.truncate_gradient or
node.op.mode != rep.op.mode or
rep.op.as_while != node.op.as_while):
return False return False
nsteps = node.inputs[0] nsteps = node.inputs[0]
...@@ -1834,22 +1839,18 @@ class ScanMerge(gof.Optimizer): ...@@ -1834,22 +1839,18 @@ class ScanMerge(gof.Optimizer):
pass pass
# Check to see if it is an input of a different node # Check to see if it is an input of a different node
can_add = True
for nd in set_nodes: for nd in set_nodes:
if find_up(node, nd) or find_up(nd, node): if find_up(node, nd) or find_up(nd, node):
can_add = False return False
can_add = can_add and (node.op.truncate_gradient ==
rep.op.truncate_gradient)
can_add = can_add and (node.op.mode == rep.op.mode)
if not node.op.as_while: if not node.op.as_while:
return nsteps == rep_nsteps and can_add return nsteps == rep_nsteps
cond = node.op.outputs[-1] cond = node.op.outputs[-1]
rep_cond = rep.op.outputs[-1] rep_cond = rep.op.outputs[-1]
same_cond = scan_utils.equal_computations([cond], [rep_cond], same_cond = scan_utils.equal_computations([cond], [rep_cond],
node.op.inputs, node.op.inputs,
rep.op.inputs) rep.op.inputs)
return same_cond and (nsteps == rep_nsteps) and can_add return same_cond and (nsteps == rep_nsteps)
def apply(self, fgraph): def apply(self, fgraph):
# Collect all scan nodes ordered according to toposort # Collect all scan nodes ordered according to toposort
......
...@@ -1329,6 +1329,9 @@ def forced_replace(out, x, y): ...@@ -1329,6 +1329,9 @@ def forced_replace(out, x, y):
x := sigmoid(wu) x := sigmoid(wu)
forced_replace(out, x, y) := y*(1-y) forced_replace(out, x, y) := y*(1-y)
Note
----
When it finds a match, it does not continue the search through the inputs of the matched variable.
""" """
if out is None: if out is None:
return None return None
...@@ -1336,18 +1339,18 @@ def forced_replace(out, x, y): ...@@ -1336,18 +1339,18 @@ def forced_replace(out, x, y):
# ``visited`` is a set of nodes that are already known and don't need to be # ``visited`` is a set of nodes that are already known and don't need to be
# checked again, speeding up the traversal of multiply-connected graphs. # checked again, speeding up the traversal of multiply-connected graphs.
visited = set() visited = set()
def local_traverse(graph, x): from collections import deque
q = deque()
q.append(out)
to_replace = []
while q:
graph = q.popleft()
if graph in visited: if graph in visited:
return [] continue
visited.add(graph) visited.add(graph)
if equal_computations([graph], [x]): if equal_computations([graph], [x]):
return [graph] to_replace.append((graph, y))
elif not graph.owner: elif graph.owner:
return [] q.extendleft(graph.owner.inputs)
else:
rval = [] return clone(out, replace=to_replace)
for inp in graph.owner.inputs:
rval += local_traverse(inp, x)
return rval
to_replace = local_traverse(out, x)
return clone(out, replace=OrderedDict((v, y) for v in to_replace))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论