Commit 3c818691, authored by Pascal Lamblin

Change merge feature to avoid looping through env.

Replacements for equivalent nodes and constants are now queued when the graph is modified, through the MergeFeature, and processed when the optimization is applied. This means that applying the MergeOptimization when there is nothing to be merged is now almost instantaneous, even for big graphs. It is also faster when lots of nodes have to be merged (empirical 40% speed-up for a graph with ~5000 nodes).
Parent: a96d5716
...@@ -204,7 +204,7 @@ optdb.register('merge1', gof.MergeOptimizer(), ...@@ -204,7 +204,7 @@ optdb.register('merge1', gof.MergeOptimizer(),
0, 'fast_run', 'fast_compile') 0, 'fast_run', 'fast_compile')
optdb.register('canonicalize', gof.EquilibriumDB(), # rearranges elemwise expressions optdb.register('canonicalize', gof.EquilibriumDB(), # rearranges elemwise expressions
1, 'fast_run', 'fast_compile') 1, 'fast_run', 'fast_compile')
optdb.register('merge1.2', gof.MergeOptimizer(skip_const_merge=False), optdb.register('merge1.2', gof.MergeOptimizer(),
1.2, 'fast_run', 'fast_compile') 1.2, 'fast_run', 'fast_compile')
optdb.register('Print1.21', PrintCurrentEnv('Post-canonicalize'), optdb.register('Print1.21', PrintCurrentEnv('Post-canonicalize'),
1.21,)# 'fast_run', 'fast_compile') 1.21,)# 'fast_run', 'fast_compile')
......
...@@ -230,6 +230,26 @@ class _metadict: ...@@ -230,6 +230,26 @@ class _metadict:
self.l[i] = (item, value) self.l[i] = (item, value)
return return
self.l.append((item, value)) self.l.append((item, value))
def __delitem__(self, item):
    """Delete `item` from the metadict.

    Hashable items live in the dict ``self.d``; unhashable ones live in the
    association list ``self.l``.  Raises KeyError if `item` is absent.
    """
    try:
        # Fast path: hashable keys are stored in self.d.
        if item in self.d:
            del self.d[item]
            return
    except TypeError:
        # Unhashable keys (the reason _metadict exists at all) are kept in
        # the association list instead; fall through and search it.
        pass
    for i, (key, _val) in enumerate(self.l):
        if key == item:
            del self.l[i]
            return
    raise KeyError(item)
def discard(self, item):
    """Delete `item` from the metadict if present; never raises KeyError.

    Like ``__delitem__`` but silent when `item` is absent, mirroring
    ``set.discard``.
    """
    try:
        # Fast path: hashable keys are stored in self.d.
        if item in self.d:
            del self.d[item]
            return
    except TypeError:
        # Unhashable keys are kept in the association list; search it below.
        pass
    for i, (key, _val) in enumerate(self.l):
        if key == item:
            del self.l[i]
            return
def get(self, item, default): def get(self, item, default):
try: try:
return self.d[item] return self.d[item]
...@@ -252,80 +272,123 @@ class _metadict: ...@@ -252,80 +272,123 @@ class _metadict:
return "(%s, %s)" % (self.d, self.l) return "(%s, %s)" % (self.d, self.l)
class MergeOptimizer(Optimizer): class MergeFeature(object):
""" """
Merges parts of the graph that are identical and redundant. Keeps track of variables in env that cannot be merged together.
The basic principle is that if two Applies have ops that compare equal, and That way, the MergeOptimizer can remember the result of the last merge
identical inputs, then they do not both need to be computed. The clients of pass on the env.
one are transferred to the other and one of them is removed from the graph.
This procedure is carried out in input->output order through the graph.
The first step of merging is constant-merging, so that all clients of an
int(1) for example, are transferred to a particular instance of int(1).
""" """
def on_attach(self, env):
    """Attach this feature to `env` and seed its merge bookkeeping.

    Registers itself as ``env.merge_feature`` (only one per env is allowed)
    and scans the existing graph via ``on_import`` so that already-present
    nodes and constants are considered for merging.
    """
    assert not hasattr(env, 'merge_feature')
    env.merge_feature = self

    ## For constants
    self.seen_constants = set()
    # variable -> signature (for constants)
    self.const_sig = _metadict()
    # signature -> variable (for constants)
    self.const_sig_inv = _metadict()

    ## For all variables
    # Set of distinct (not mergeable) nodes
    self.nodes_seen = set()

    # Each element of scheduled is a list of list of (out, new_out) pairs.
    # Each list of pairs represent the substitution needed to replace all
    # the outputs of a node with the outputs of a replacement candidate.
    # Each node can have several candidates. For instance, if "node" has
    # 2 outputs, and there are 3 replacement candidates, we will have:
    # self.scheduled = [
    #    [[(node.out1, cand1.out1), (node.out2, cand1.out2)],
    #     [(node.out1, cand2.out1), (node.out2, cand2.out2)],
    #     [(node.out1, cand3.out1), (node.out2, cand3.out2)]]]
    self.scheduled = []

    # List of (node, candidate) pairs, where we tried to replace node by
    # candidate, but it failed. This is used to avoid infinite loops.
    self.blacklist = []

    for node in env.toposort():
        self.on_import(env, node)
def on_change_input(self, env, node, i, r, new_r):
    """Callback: input `i` of `node` changed from `r` to `new_r`.

    Re-examines `node` (its new inputs may make it mergeable with another
    node) and considers the new constant, if any, for constant merging.
    """
    # If inputs to node change, it is not guaranteed that it is distinct
    # from the other nodes in nodes_seen.
    if node in self.nodes_seen:
        self.nodes_seen.discard(node)
        self.process_node(env, node)

    if isinstance(new_r, graph.Constant):
        self.process_constant(env, new_r)
def on_import(self, env, node):
    """Callback: `node` was added to `env`.

    Queues merge candidates for each constant input of `node`, then for
    `node` itself.
    """
    for c in node.inputs:
        if isinstance(c, graph.Constant):
            self.process_constant(env, c)

    self.process_node(env, node)
def on_prune(self, env, node):
    """Callback: `node` was removed from `env`; drop merge bookkeeping for it.

    Forgets `node` as a distinct node, and un-registers any constant input
    for which `node` was the last remaining client.
    """
    self.nodes_seen.discard(node)
    for c in node.inputs:
        # NOTE(review): assumes c.clients is updated before this callback
        # fires, so <= 1 means `node` was the last user of `c` -- confirm
        # against the env's client-bookkeeping order.
        if isinstance(c, graph.Constant) and (len(c.clients) <= 1):
            # This was the last node using this constant
            sig = self.const_sig[c]
            self.const_sig.discard(c)
            self.const_sig_inv.discard(sig)
            self.seen_constants.discard(id(c))
def process_constant(self, env, c):
    """Check if a constant can be merged, and queue that replacement.

    If another constant with the same signature has already been seen,
    schedule replacing `c` by it; otherwise record `c` as the canonical
    constant for its signature.  Tracked by ``id(c)`` so each constant
    object is processed at most once.
    """
    if id(c) in self.seen_constants:
        return
    sig = c.signature()
    other_c = self.const_sig_inv.get(sig, None)
    if other_c is not None:
        # multiple names will clobber each other..
        # we adopt convention to keep the last name
        if c.name:
            other_c.name = c.name
        self.scheduled.append([[(c, other_c)]])
    else:
        # this is a new constant
        self.const_sig[c] = sig
        self.const_sig_inv[sig] = c
    self.seen_constants.add(id(c))
def process_node(self, env, node):
    """Check if a node can be merged, and queue that replacement.

    A candidate must share `node`'s first input, have the same number of
    identical inputs and an equal op.  All viable candidates are queued as
    one entry in ``self.scheduled``; if there is none, `node` is recorded
    as distinct in ``self.nodes_seen``.
    """
    if node in self.nodes_seen:
        return

    # These asserts ensure that the env has set the clients field properly.
    # The clients should at least contain `node` itself!
    if node.inputs:
        assert len(node.inputs[0].clients) > 0
        assert (node, 0) in node.inputs[0].clients
        # Any merge candidate must share our first input, so scanning that
        # input's clients is enough.
        merge_candidates = [c for (c, i) in node.inputs[0].clients
                            if c in self.nodes_seen]
    else:
        merge_candidates = []

    replacement_candidates = []
    for candidate in merge_candidates:
        if candidate is node:
            continue
        if len(node.inputs) != len(candidate.inputs):
            continue

        inputs_match = all(node_in is cand_in
                           for node_in, cand_in in zip(node.inputs,
                                                       candidate.inputs))

        if inputs_match and node.op == candidate.op:
            if (node, candidate) in self.blacklist:
                # They were already tried, and there was an error
                continue

            # Schedule transfer of clients from node to candidate.
            # Materialized as a list: a bare zip iterator would be exhausted
            # by the name-transfer loop below under Python 3.
            pairs = list(zip(node.outputs, candidate.outputs))

            # transfer names
            for node_output, cand_output in pairs:
                # it's arbitrary... one of the names has to go
                if node_output.name:
                    cand_output.name = node_output.name

            replacement_candidates.append(pairs)

    if replacement_candidates:
        self.scheduled.append(replacement_candidates)
    else:
        self.nodes_seen.add(node)
class MergeOptimizer(Optimizer):
    """
    Merges parts of the graph that are identical and redundant.

    The basic principle is that if two Applies have ops that compare equal, and
    identical inputs, then they do not both need to be computed. The clients of
    one are transferred to the other and one of them is removed from the graph.
    This procedure is carried out in input->output order through the graph.

    The first step of merging is constant-merging, so that all clients of an
    int(1) for example, are transferred to a particular instance of int(1).
    """

    def __init__(self):
        Optimizer.__init__(self)

    def add_requirements(self, env):
        # The MergeFeature queues the replacements as the graph is modified,
        # so apply() only has to process the accumulated backlog.
        if not hasattr(env, 'merge_feature'):
            env.extend(MergeFeature())

    def apply(self, env):
        """Process the replacements queued by the env's MergeFeature."""
        # Constant and non-constant merges are now applied in the same phase.
        sched = env.merge_feature.scheduled
        while sched:
            pairs_list = sched.pop()
            for pairs in pairs_list:
                # Reset per candidate: a failure on one candidate must not
                # prevent breaking out when a later candidate succeeds
                # (otherwise we would keep replacing already-merged outputs
                # and spuriously blacklist them).
                success = True
                try:
                    env.replace_all_validate(pairs, 'Merge')
                except InconsistencyError:
                    success = False
                    env.merge_feature.blacklist.append(
                        (pairs[0][0].owner, pairs[0][1].owner))
                if success:
                    break
        # clear blacklist
        env.merge_feature.blacklist = []
merge_optimizer = MergeOptimizer() merge_optimizer = MergeOptimizer()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论