提交 3c818691 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Change merge feature to avoid looping through env.

Replacement for equivalent nodes and constants are now queued when the graph is modified, through the MergeFeature, and processed when the optimization is applied. This means that appling the MergeOptimization when there is nothing to be merged is now almost instantaneous, even for big graphs. It is also faster when lots of nodes have to be merged (empirical 40 % speed-up for a graph with ~ 5000 nodes).
上级 a96d5716
......@@ -204,7 +204,7 @@ optdb.register('merge1', gof.MergeOptimizer(),
0, 'fast_run', 'fast_compile')
optdb.register('canonicalize', gof.EquilibriumDB(), # rearranges elemwise expressions
1, 'fast_run', 'fast_compile')
optdb.register('merge1.2', gof.MergeOptimizer(skip_const_merge=False),
optdb.register('merge1.2', gof.MergeOptimizer(),
1.2, 'fast_run', 'fast_compile')
optdb.register('Print1.21', PrintCurrentEnv('Post-canonicalize'),
1.21,)# 'fast_run', 'fast_compile')
......
......@@ -230,6 +230,26 @@ class _metadict:
self.l[i] = (item, value)
return
self.l.append((item, value))
def __delitem__(self, item):
if item in self.d:
del self.d[item]
else:
for i, (key, val) in enumerate(self.l):
if key == item:
del self.l[i]
return
raise KeyError(item)
def discard(self, item):
if item in self.d:
del self.d[item]
else:
for i, (key, val) in enumerate(self.l):
if key == item:
del self.l[i]
return
def get(self, item, default):
try:
return self.d[item]
......@@ -252,80 +272,123 @@ class _metadict:
return "(%s, %s)" % (self.d, self.l)
class MergeOptimizer(Optimizer):
class MergeFeature(object):
"""
Merges parts of the graph that are identical and redundant.
Keeps track of variables in env that cannot be merged together.
The basic principle is that if two Applies have ops that compare equal, and
identical inputs, then they do not both need to be computed. The clients of
one are transferred to the other and one of them is removed from the graph.
This procedure is carried out in input->output order through the graph.
The first step of merging is constant-merging, so that all clients of an
int(1) for example, are transferred to a particular instance of int(1).
That way, the MergeOptimizer can remember the result of the last merge
pass on the env.
"""
def __init__(self, skip_const_merge=False):
self.skip_const_merge = skip_const_merge
def on_attach(self, env):
assert not hasattr(env, 'merge_feature')
env.merge_feature = self
## For constants
self.seen_constants = set()
# variable -> signature (for constants)
self.const_sig = _metadict()
# signature -> variable (for constants)
self.const_sig_inv = _metadict()
## For all variables
# Set of distinct (not mergeable) nodes
self.nodes_seen = set()
# Each element of scheduled is a list of list of (out, new_out) pairs.
# Each list of pairs represent the substitution needed to replace all
# the outputs of a node with the outputs of a replacement candidate.
# Each node can have several candidates. For instance, if "node" has
# 2 outputs, and there are 3 replacement candidates, we will have:
# shelf.scheduled = [
# [[(node.out1, cand1.out1), (node.out2, cand1.out2)],
# [(node.out1, cand2.out1), (node.out2, cand2.out2)],
# [(node.out1, cand3.out1), (node.out2, cand3.out2)]]]
self.scheduled = []
# List of (node, candidate) pairs, where we tried to replace node by
# candidate, but it failed. This is used to avoid infinite loops.
self.blacklist = []
for node in env.toposort():
self.on_import(env, node)
def add_requirements(self, env):
# Added by default
#env.extend(toolbox.ReplaceValidate())
pass
def on_change_input(self, env, node, i, r, new_r):
# If inputs to node change, it is not guaranteed that it is distinct
# from the other nodes in nodes_seen
if node in self.nodes_seen:
self.nodes_seen.discard(node)
self.process_node(env, node)
def apply_constant_merge(self, env):
seen_constants = set()
const_sig = _metadict() # variable -> variable.signature() (for constants)
const_sig_inv = _metadict() # signature -> variable (for constants)
for node in _list_of_nodes(env):
for i, c in enumerate([r for r in node.inputs if isinstance(r, graph.Constant)]):
if id(c) in seen_constants:
continue
else:
seen_constants.add(id(c))
if isinstance(new_r, graph.Constant):
self.process_constant(env, new_r)
def on_import(self, env, node):
for c in node.inputs:
if isinstance(c, graph.Constant):
self.process_constant(env, c)
self.process_node(env, node)
def on_prune(self, env, node):
self.nodes_seen.discard(node)
for c in node.inputs:
if isinstance(c, graph.Constant) and (len(c.clients) <= 1):
# This was the last node using this constant
sig = self.const_sig[c]
self.const_sig.discard(c)
self.const_sig_inv.discard(sig)
self.seen_constants.discard(id(c))
def process_constant(self, env, c):
"""Check if a constant can be merged, and queue that replacement"""
if id(c) in self.seen_constants:
return
sig = c.signature()
other_c = const_sig_inv.get(sig, None)
other_c = self.const_sig_inv.get(sig, None)
if other_c is not None:
# multiple names will clobber each other..
# we adopt convention to keep the last name
if c.name:
other_c.name = c.name
env.replace_validate(c, other_c, reason='Constant Merge')
self.scheduled.append([[(c, other_c)]])
else:
#this is a new constant
const_sig[c] = sig
const_sig_inv[sig] = c
self.const_sig[c] = sig
self.const_sig_inv[sig] = c
self.seen_constants.add(id(c))
def apply_node_merge(self, env):
# we clear the dicts because the Constants signatures are not necessarily hashable
# and it's more efficient to give them an integer like the other Variables
nodes_seen = {}
def process_node(self, env, node):
"""Check if a node can be merged, and queue that replacement."""
if node in self.nodes_seen:
return
for node_idx, node in enumerate(_list_of_nodes(env)):
#
# these asserts ensure that the env has set the clients field properly the clients
# should at least contain `node` itself!
#
# These asserts ensure that the env has set the clients field properly.
# The clients should at least contain `node` itself!
if node.inputs:
assert len(node.inputs[0].clients) > 0
assert (node,0) in node.inputs[0].clients
merge_candidates = [(nodes_seen[c],c) for (c,i) in node.inputs[0].clients if c in nodes_seen]
assert (node, 0) in node.inputs[0].clients
merge_candidates = [c for (c, i) in node.inputs[0].clients
if c in self.nodes_seen]
else:
merge_candidates = []
merge_candidates.sort()
nodes_seen[node] = node_idx
#print 'NODE', node, merge_candidates, node.inputs[0].clients
for candidate_idx, candidate in merge_candidates:
replacement_candidates = []
for candidate in merge_candidates:
if candidate is node:
continue
if len(node.inputs) != len(candidate.inputs):
continue
inputs_match = all(node_in is cand_in for node_in, cand_in in zip(node.inputs, candidate.inputs))
inputs_match = all(node_in is cand_in
for node_in, cand_in in zip(node.inputs, candidate.inputs))
if inputs_match and node.op == candidate.op:
assert node is not candidate
#
#transfer clients from node to candidate
#
success = True
assert len(node.outputs) == len(candidate.outputs)
if (node, candidate) in self.blacklist:
# They were already tried, and there was an error
continue
# Schedule transfer of clients from node to candidate
#self.nodes_scheduled.setdefault(node, [])
#self.nodes_scheduled[node].append(candidate)
pairs = zip(node.outputs, candidate.outputs)
#transfer names
......@@ -334,24 +397,55 @@ class MergeOptimizer(Optimizer):
#it's arbitrary... one of the names has to go
if node_output.name:
cand_output.name = node_output.name
try:
env.replace_all_validate(pairs, reason="Merge")
except InconsistencyError, e:
success = False
if success:
#break out of the candidate loop
break
replacement_candidates.append(pairs)
if replacement_candidates:
self.scheduled.append(replacement_candidates)
else:
#try the next candidate
pass
self.nodes_seen.add(node)
class MergeOptimizer(Optimizer):
"""
Merges parts of the graph that are identical and redundant.
The basic principle is that if two Applies have ops that compare equal, and
identical inputs, then they do not both need to be computed. The clients of
one are transferred to the other and one of them is removed from the graph.
This procedure is carried out in input->output order through the graph.
The first step of merging is constant-merging, so that all clients of an
int(1) for example, are transferred to a particular instance of int(1).
"""
def __init__(self):
Optimizer.__init__(self)
def add_requirements(self, env):
# Added by default
#env.extend(toolbox.ReplaceValidate())
if not hasattr(env, 'merge_feature'):
env.extend(MergeFeature())
#TODO: Consider splitting this into a separate optimizer (SeqOptimizer)
def apply(self, env):
if not self.skip_const_merge:
self.apply_constant_merge(env)
self.apply_node_merge(env)
# Constant and non-constant are now applied in the same phase.
# I am not sure why, but it seems to be faster this way.
sched = env.merge_feature.scheduled
while sched:
pairs_list = sched.pop()
success = True
for pairs in pairs_list:
try:
env.replace_all_validate(pairs, 'Merge')
except InconsistencyError:
success = False
env.merge_feature.blacklist.append(
(pairs[0][0].owner, pairs[0][1].owner))
if success:
break
# clear blacklist
env.merge_feature.blacklist = []
merge_optimizer = MergeOptimizer()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论