提交 dde50054 authored 作者: Frederic Bastien's avatar Frederic Bastien

Move all cheap condition at the start to do less frequently the slow conditino.

上级 5bb6abe5
...@@ -1818,7 +1818,12 @@ class ScanMerge(gof.Optimizer): ...@@ -1818,7 +1818,12 @@ class ScanMerge(gof.Optimizer):
""" """
rep = set_nodes[0] rep = set_nodes[0]
if rep.op.as_while != node.op.as_while: if (rep.op.as_while != node.op.as_while or
len(rep.inputs) != len(node.inputs) or
len(rep.outputs) != len(node.outputs) or
node.op.truncate_gradient != rep.op.truncate_gradient or
node.op.mode != rep.op.mode or
rep.op.as_while != node.op.as_while):
return False return False
nsteps = node.inputs[0] nsteps = node.inputs[0]
...@@ -1834,22 +1839,18 @@ class ScanMerge(gof.Optimizer): ...@@ -1834,22 +1839,18 @@ class ScanMerge(gof.Optimizer):
pass pass
# Check to see if it is an input of a different node # Check to see if it is an input of a different node
can_add = True
for nd in set_nodes: for nd in set_nodes:
if find_up(node, nd) or find_up(nd, node): if find_up(node, nd) or find_up(nd, node):
can_add = False return False
can_add = can_add and (node.op.truncate_gradient ==
rep.op.truncate_gradient)
can_add = can_add and (node.op.mode == rep.op.mode)
if not node.op.as_while: if not node.op.as_while:
return nsteps == rep_nsteps and can_add return nsteps == rep_nsteps
cond = node.op.outputs[-1] cond = node.op.outputs[-1]
rep_cond = rep.op.outputs[-1] rep_cond = rep.op.outputs[-1]
same_cond = scan_utils.equal_computations([cond], [rep_cond], same_cond = scan_utils.equal_computations([cond], [rep_cond],
node.op.inputs, node.op.inputs,
rep.op.inputs) rep.op.inputs)
return same_cond and (nsteps == rep_nsteps) and can_add return same_cond and (nsteps == rep_nsteps)
def apply(self, fgraph): def apply(self, fgraph):
# Collect all scan nodes ordered according to toposort # Collect all scan nodes ordered according to toposort
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论