提交 240635e0 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #5739 from nouiz/scan

Optimization speed up
......@@ -1375,3 +1375,27 @@ def list_of_nodes(inputs, outputs):
lambda o: [inp.owner for inp in o.inputs
if inp.owner and
not any(i in inp.owner.outputs for i in inputs)])
def is_in_ancestors(l_node, f_node):
r"""
Goes up in the graph and returns True if the apply node f_node is found.
Use a stack implementation as the vm algo.
We suppose all nodes are not lazy
(i.e. for IfElse we suppose all inputs are computed)
"""
computed = set()
todo = [l_node]
while todo:
cur = todo.pop()
if cur.outputs[0] in computed:
continue
if all([i in computed or i.owner is None for i in cur.inputs]):
computed.update(cur.outputs)
if cur is f_node:
return True
else:
todo.append(cur)
todo.extend(i.owner for i in cur.inputs if i.owner)
return False
......@@ -2089,13 +2089,7 @@ class TopoOptimizer(NavigatorOptimizer):
if node is not current_node:
q.append(node)
def pruner(node):
if node is not current_node:
try:
q.remove(node)
except ValueError:
pass
u = self.attach_updater(fgraph, importer, pruner,
u = self.attach_updater(fgraph, importer, None,
name=getattr(self, 'name', None))
nb = 0
try:
......@@ -2105,6 +2099,8 @@ class TopoOptimizer(NavigatorOptimizer):
node = q.pop()
else:
node = q.popleft()
if node not in fgraph.apply_nodes:
continue
current_node = node
nb += self.process_node(fgraph, node)
loop_t = time.time() - t0
......@@ -2217,17 +2213,13 @@ class OpKeyOptimizer(NavigatorOptimizer):
if node.op == op:
q.append(node)
def pruner(node):
if node is not current_node and node.op == op:
try:
q.remove(node)
except ValueError:
pass
u = self.attach_updater(fgraph, importer, pruner,
u = self.attach_updater(fgraph, importer, None,
name=getattr(self, 'name', None))
try:
while q:
node = q.pop()
if node not in fgraph.apply_nodes:
continue
current_node = node
self.process_node(fgraph, node)
finally:
......
......@@ -26,7 +26,6 @@ from six import iteritems
from six.moves import xrange
from theano.compile import optdb
from theano.tensor import opt
from theano.scan_module.scan_utils import find_up
from theano.scan_module.scan_utils import clone
......@@ -578,7 +577,7 @@ class CondMerge(gof.Optimizer):
merging_node = cond_nodes[0]
for proposal in cond_nodes[1:]:
if (proposal.inputs[0] == merging_node.inputs[0] and
not find_up(proposal, merging_node)):
not gof.graph.is_in_ancestors(proposal, merging_node)):
# Create a list of replacements for proposal
mn_ts = merging_node.inputs[1:][:merging_node.op.n_outs]
mn_fs = merging_node.inputs[1:][merging_node.op.n_outs:]
......@@ -683,8 +682,8 @@ def cond_merge_random_op(main_node):
merging_node = cond_nodes[0]
for proposal in cond_nodes[1:]:
if (proposal.inputs[0] == merging_node.inputs[0] and
not find_up(proposal, merging_node) and
not find_up(merging_node, proposal)):
not gof.graph.is_in_ancestors(proposal, merging_node) and
not gof.graph.is_in_ancestors(merging_node, proposal)):
# Create a list of replacements for proposal
mn_ts = merging_node.inputs[1:][:merging_node.op.n_outs]
mn_fs = merging_node.inputs[1:][merging_node.op.n_outs:]
......
......@@ -70,7 +70,7 @@ from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
from theano.scan_module import scan_op
from theano.scan_module import scan_utils
from theano.scan_module.scan_utils import equal_computations, find_up, scan_args
from theano.scan_module.scan_utils import equal_computations, scan_args
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
......@@ -1605,7 +1605,7 @@ class ScanSaveMem(gof.Optimizer):
nw_pos = compress_map[idx]
old_new += [(o, new_outs[nw_pos])]
# Check if the new outputs depend on the old scan node
old_scan_is_used = [scan_utils.find_up(new.owner, node)
old_scan_is_used = [gof.graph.is_in_ancestors(new.owner, node)
for old, new in old_new]
if any(old_scan_is_used):
return False
......@@ -1829,19 +1829,21 @@ class ScanMerge(gof.Optimizer):
except tensor.NotScalarConstantError:
pass
if nsteps != rep_nsteps:
return False
# Check to see if it is an input of a different node
for nd in set_nodes:
if find_up(node, nd) or find_up(nd, node):
if gof.graph.is_in_ancestors(node, nd) or gof.graph.is_in_ancestors(nd, node):
return False
if not node.op.as_while:
return nsteps == rep_nsteps
return True
cond = node.op.outputs[-1]
rep_cond = rep.op.outputs[-1]
same_cond = scan_utils.equal_computations([cond], [rep_cond],
node.op.inputs,
rep.op.inputs)
return same_cond and (nsteps == rep_nsteps)
return scan_utils.equal_computations([cond], [rep_cond],
node.op.inputs,
rep.op.inputs)
def apply(self, fgraph):
# Collect all scan nodes ordered according to toposort
......
......@@ -1113,20 +1113,6 @@ def compress_outs(op, not_required, inputs):
return (op_inputs, op_outputs, info, node_inputs, map_old_new)
def find_up(l_node, f_node):
r"""
Goes up in the graph and returns True if a node in nodes is found.
"""
if isinstance(l_node, gof.Apply):
l_outs = l_node.outputs
else:
l_outs = l_node
l_ins = gof.graph.inputs(l_outs)
nodes = gof.graph.io_toposort(l_ins, l_outs)
return f_node in nodes
def reconstruct_graph(inputs, outputs, tag=None):
"""
Different interface to clone, that allows you to pass inputs.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论