提交 9d29cb8e authored 作者: Frederic Bastien's avatar Frederic Bastien

Refactor code to do check test first. Save 1.5s on ~3s.

上级 dd4825e6
...@@ -1829,19 +1829,21 @@ class ScanMerge(gof.Optimizer): ...@@ -1829,19 +1829,21 @@ class ScanMerge(gof.Optimizer):
except tensor.NotScalarConstantError: except tensor.NotScalarConstantError:
pass pass
if nsteps != rep_nsteps:
return False
# Check to see if it is an input of a different node # Check to see if it is an input of a different node
for nd in set_nodes: for nd in set_nodes:
if gof.graph.is_in_ancestors(node, nd) or gof.graph.is_in_ancestors(nd, node): if gof.graph.is_in_ancestors(node, nd) or gof.graph.is_in_ancestors(nd, node):
return False return False
if not node.op.as_while: if not node.op.as_while:
return nsteps == rep_nsteps return True
cond = node.op.outputs[-1] cond = node.op.outputs[-1]
rep_cond = rep.op.outputs[-1] rep_cond = rep.op.outputs[-1]
same_cond = scan_utils.equal_computations([cond], [rep_cond], return scan_utils.equal_computations([cond], [rep_cond],
node.op.inputs, node.op.inputs,
rep.op.inputs) rep.op.inputs)
return same_cond and (nsteps == rep_nsteps)
def apply(self, fgraph): def apply(self, fgraph):
# Collect all scan nodes ordered according to toposort # Collect all scan nodes ordered according to toposort
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论