提交 dd4825e6 authored 作者: Frederic Bastien's avatar Frederic Bastien

Move scan_utils.find_up to gof.graph and optimize it. In one scanmerge case, it…

Move scan_utils.find_up to gof.graph and optimize it. In one scanmerge case, it use 3.6s instead of 15s.
上级 c95f7e8b
......@@ -1375,3 +1375,26 @@ def list_of_nodes(inputs, outputs):
lambda o: [inp.owner for inp in o.inputs
if inp.owner and
not any(i in inp.owner.outputs for i in inputs)])
def is_in_ancestors(l_node, f_node):
r"""
Goes up in the graph and returns True if the apply node f_node is found.
Use a stack implementation as the vm algo.
"""
computed = set()
todo = [l_node]
while todo:
cur = todo.pop()
# We suppose that all outputs are always computed
if cur.outputs[0] in computed:
continue
if all([i in computed or i.owner is None for i in cur.inputs]):
computed.update(cur.outputs)
if cur is f_node:
return True
else:
todo.append(cur)
todo.extend(i.owner for i in cur.inputs if i.owner)
return False
......@@ -26,7 +26,6 @@ from six import iteritems
from six.moves import xrange
from theano.compile import optdb
from theano.tensor import opt
from theano.scan_module.scan_utils import find_up
from theano.scan_module.scan_utils import clone
......@@ -578,7 +577,7 @@ class CondMerge(gof.Optimizer):
merging_node = cond_nodes[0]
for proposal in cond_nodes[1:]:
if (proposal.inputs[0] == merging_node.inputs[0] and
not find_up(proposal, merging_node)):
not gof.graph.is_in_ancestors(proposal, merging_node)):
# Create a list of replacements for proposal
mn_ts = merging_node.inputs[1:][:merging_node.op.n_outs]
mn_fs = merging_node.inputs[1:][merging_node.op.n_outs:]
......@@ -683,8 +682,8 @@ def cond_merge_random_op(main_node):
merging_node = cond_nodes[0]
for proposal in cond_nodes[1:]:
if (proposal.inputs[0] == merging_node.inputs[0] and
not find_up(proposal, merging_node) and
not find_up(merging_node, proposal)):
not gof.graph.is_in_ancestors(proposal, merging_node) and
not gof.graph.is_in_ancestors(merging_node, proposal)):
# Create a list of replacements for proposal
mn_ts = merging_node.inputs[1:][:merging_node.op.n_outs]
mn_fs = merging_node.inputs[1:][merging_node.op.n_outs:]
......
......@@ -70,7 +70,7 @@ from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
from theano.scan_module import scan_op
from theano.scan_module import scan_utils
from theano.scan_module.scan_utils import equal_computations, find_up, scan_args
from theano.scan_module.scan_utils import equal_computations, scan_args
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
......@@ -1605,7 +1605,7 @@ class ScanSaveMem(gof.Optimizer):
nw_pos = compress_map[idx]
old_new += [(o, new_outs[nw_pos])]
# Check if the new outputs depend on the old scan node
old_scan_is_used = [scan_utils.find_up(new.owner, node)
old_scan_is_used = [gof.graph.is_in_ancestors(new.owner, node)
for old, new in old_new]
if any(old_scan_is_used):
return False
......@@ -1831,7 +1831,7 @@ class ScanMerge(gof.Optimizer):
# Check to see if it is an input of a different node
for nd in set_nodes:
if find_up(node, nd) or find_up(nd, node):
if gof.graph.is_in_ancestors(node, nd) or gof.graph.is_in_ancestors(nd, node):
return False
if not node.op.as_while:
......
......@@ -1086,20 +1086,6 @@ def compress_outs(op, not_required, inputs):
return (op_inputs, op_outputs, info, node_inputs, map_old_new)
def find_up(l_node, f_node):
r"""
Goes up in the graph and returns True if a node in nodes is found.
"""
if isinstance(l_node, gof.Apply):
l_outs = l_node.outputs
else:
l_outs = l_node
l_ins = gof.graph.inputs(l_outs)
nodes = gof.graph.io_toposort(l_ins, l_outs)
return f_node in nodes
def reconstruct_graph(inputs, outputs, tag=None):
"""
Different interface to clone, that allows you to pass inputs.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论