提交 3c818691 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Change merge feature to avoid looping through env.

Replacement for equivalent nodes and constants are now queued when the graph is modified, through the MergeFeature, and processed when the optimization is applied. This means that appling the MergeOptimization when there is nothing to be merged is now almost instantaneous, even for big graphs. It is also faster when lots of nodes have to be merged (empirical 40 % speed-up for a graph with ~ 5000 nodes).
上级 a96d5716
...@@ -204,7 +204,7 @@ optdb.register('merge1', gof.MergeOptimizer(), ...@@ -204,7 +204,7 @@ optdb.register('merge1', gof.MergeOptimizer(),
0, 'fast_run', 'fast_compile') 0, 'fast_run', 'fast_compile')
optdb.register('canonicalize', gof.EquilibriumDB(), # rearranges elemwise expressions optdb.register('canonicalize', gof.EquilibriumDB(), # rearranges elemwise expressions
1, 'fast_run', 'fast_compile') 1, 'fast_run', 'fast_compile')
optdb.register('merge1.2', gof.MergeOptimizer(skip_const_merge=False), optdb.register('merge1.2', gof.MergeOptimizer(),
1.2, 'fast_run', 'fast_compile') 1.2, 'fast_run', 'fast_compile')
optdb.register('Print1.21', PrintCurrentEnv('Post-canonicalize'), optdb.register('Print1.21', PrintCurrentEnv('Post-canonicalize'),
1.21,)# 'fast_run', 'fast_compile') 1.21,)# 'fast_run', 'fast_compile')
......
...@@ -230,6 +230,26 @@ class _metadict: ...@@ -230,6 +230,26 @@ class _metadict:
self.l[i] = (item, value) self.l[i] = (item, value)
return return
self.l.append((item, value)) self.l.append((item, value))
def __delitem__(self, item):
if item in self.d:
del self.d[item]
else:
for i, (key, val) in enumerate(self.l):
if key == item:
del self.l[i]
return
raise KeyError(item)
def discard(self, item):
if item in self.d:
del self.d[item]
else:
for i, (key, val) in enumerate(self.l):
if key == item:
del self.l[i]
return
def get(self, item, default): def get(self, item, default):
try: try:
return self.d[item] return self.d[item]
...@@ -252,6 +272,140 @@ class _metadict: ...@@ -252,6 +272,140 @@ class _metadict:
return "(%s, %s)" % (self.d, self.l) return "(%s, %s)" % (self.d, self.l)
class MergeFeature(object):
"""
Keeps track of variables in env that cannot be merged together.
That way, the MergeOptimizer can remember the result of the last merge
pass on the env.
"""
def on_attach(self, env):
assert not hasattr(env, 'merge_feature')
env.merge_feature = self
## For constants
self.seen_constants = set()
# variable -> signature (for constants)
self.const_sig = _metadict()
# signature -> variable (for constants)
self.const_sig_inv = _metadict()
## For all variables
# Set of distinct (not mergeable) nodes
self.nodes_seen = set()
# Each element of scheduled is a list of list of (out, new_out) pairs.
# Each list of pairs represent the substitution needed to replace all
# the outputs of a node with the outputs of a replacement candidate.
# Each node can have several candidates. For instance, if "node" has
# 2 outputs, and there are 3 replacement candidates, we will have:
# shelf.scheduled = [
# [[(node.out1, cand1.out1), (node.out2, cand1.out2)],
# [(node.out1, cand2.out1), (node.out2, cand2.out2)],
# [(node.out1, cand3.out1), (node.out2, cand3.out2)]]]
self.scheduled = []
# List of (node, candidate) pairs, where we tried to replace node by
# candidate, but it failed. This is used to avoid infinite loops.
self.blacklist = []
for node in env.toposort():
self.on_import(env, node)
def on_change_input(self, env, node, i, r, new_r):
# If inputs to node change, it is not guaranteed that it is distinct
# from the other nodes in nodes_seen
if node in self.nodes_seen:
self.nodes_seen.discard(node)
self.process_node(env, node)
if isinstance(new_r, graph.Constant):
self.process_constant(env, new_r)
def on_import(self, env, node):
for c in node.inputs:
if isinstance(c, graph.Constant):
self.process_constant(env, c)
self.process_node(env, node)
def on_prune(self, env, node):
self.nodes_seen.discard(node)
for c in node.inputs:
if isinstance(c, graph.Constant) and (len(c.clients) <= 1):
# This was the last node using this constant
sig = self.const_sig[c]
self.const_sig.discard(c)
self.const_sig_inv.discard(sig)
self.seen_constants.discard(id(c))
def process_constant(self, env, c):
"""Check if a constant can be merged, and queue that replacement"""
if id(c) in self.seen_constants:
return
sig = c.signature()
other_c = self.const_sig_inv.get(sig, None)
if other_c is not None:
# multiple names will clobber each other..
# we adopt convention to keep the last name
if c.name:
other_c.name = c.name
self.scheduled.append([[(c, other_c)]])
else:
#this is a new constant
self.const_sig[c] = sig
self.const_sig_inv[sig] = c
self.seen_constants.add(id(c))
def process_node(self, env, node):
"""Check if a node can be merged, and queue that replacement."""
if node in self.nodes_seen:
return
# These asserts ensure that the env has set the clients field properly.
# The clients should at least contain `node` itself!
if node.inputs:
assert len(node.inputs[0].clients) > 0
assert (node, 0) in node.inputs[0].clients
merge_candidates = [c for (c, i) in node.inputs[0].clients
if c in self.nodes_seen]
else:
merge_candidates = []
replacement_candidates = []
for candidate in merge_candidates:
if candidate is node:
continue
if len(node.inputs) != len(candidate.inputs):
continue
inputs_match = all(node_in is cand_in
for node_in, cand_in in zip(node.inputs, candidate.inputs))
if inputs_match and node.op == candidate.op:
if (node, candidate) in self.blacklist:
# They were already tried, and there was an error
continue
# Schedule transfer of clients from node to candidate
#self.nodes_scheduled.setdefault(node, [])
#self.nodes_scheduled[node].append(candidate)
pairs = zip(node.outputs, candidate.outputs)
#transfer names
for node_output, cand_output in pairs:
#clobber old name with new one
#it's arbitrary... one of the names has to go
if node_output.name:
cand_output.name = node_output.name
replacement_candidates.append(pairs)
if replacement_candidates:
self.scheduled.append(replacement_candidates)
else:
self.nodes_seen.add(node)
class MergeOptimizer(Optimizer): class MergeOptimizer(Optimizer):
""" """
Merges parts of the graph that are identical and redundant. Merges parts of the graph that are identical and redundant.
...@@ -264,94 +418,34 @@ class MergeOptimizer(Optimizer): ...@@ -264,94 +418,34 @@ class MergeOptimizer(Optimizer):
The first step of merging is constant-merging, so that all clients of an The first step of merging is constant-merging, so that all clients of an
int(1) for example, are transferred to a particular instance of int(1). int(1) for example, are transferred to a particular instance of int(1).
""" """
def __init__(self, skip_const_merge=False): def __init__(self):
self.skip_const_merge = skip_const_merge Optimizer.__init__(self)
def add_requirements(self, env): def add_requirements(self, env):
# Added by default # Added by default
#env.extend(toolbox.ReplaceValidate()) #env.extend(toolbox.ReplaceValidate())
pass if not hasattr(env, 'merge_feature'):
env.extend(MergeFeature())
def apply_constant_merge(self, env):
seen_constants = set()
const_sig = _metadict() # variable -> variable.signature() (for constants)
const_sig_inv = _metadict() # signature -> variable (for constants)
for node in _list_of_nodes(env):
for i, c in enumerate([r for r in node.inputs if isinstance(r, graph.Constant)]):
if id(c) in seen_constants:
continue
else:
seen_constants.add(id(c))
sig = c.signature()
other_c = const_sig_inv.get(sig, None)
if other_c is not None:
# multiple names will clobber each other..
# we adopt convention to keep the last name
if c.name:
other_c.name = c.name
env.replace_validate(c, other_c, reason='Constant Merge')
else:
#this is a new constant
const_sig[c] = sig
const_sig_inv[sig] = c
def apply_node_merge(self, env):
# we clear the dicts because the Constants signatures are not necessarily hashable
# and it's more efficient to give them an integer like the other Variables
nodes_seen = {}
for node_idx, node in enumerate(_list_of_nodes(env)):
#
# these asserts ensure that the env has set the clients field properly the clients
# should at least contain `node` itself!
#
if node.inputs:
assert len(node.inputs[0].clients) > 0
assert (node,0) in node.inputs[0].clients
merge_candidates = [(nodes_seen[c],c) for (c,i) in node.inputs[0].clients if c in nodes_seen]
else:
merge_candidates = []
merge_candidates.sort()
nodes_seen[node] = node_idx
#print 'NODE', node, merge_candidates, node.inputs[0].clients
for candidate_idx, candidate in merge_candidates:
if len(node.inputs) != len(candidate.inputs):
continue
inputs_match = all(node_in is cand_in for node_in, cand_in in zip(node.inputs, candidate.inputs))
if inputs_match and node.op == candidate.op:
assert node is not candidate
#
#transfer clients from node to candidate
#
success = True
assert len(node.outputs) == len(candidate.outputs)
pairs = zip(node.outputs, candidate.outputs)
#transfer names
for node_output, cand_output in pairs:
#clobber old name with new one
#it's arbitrary... one of the names has to go
if node_output.name:
cand_output.name = node_output.name
try:
env.replace_all_validate(pairs, reason="Merge")
except InconsistencyError, e:
success = False
if success:
#break out of the candidate loop
break
else:
#try the next candidate
pass
#TODO: Consider splitting this into a separate optimizer (SeqOptimizer)
def apply(self, env): def apply(self, env):
if not self.skip_const_merge: # Constant and non-constant are now applied in the same phase.
self.apply_constant_merge(env) # I am not sure why, but it seems to be faster this way.
self.apply_node_merge(env) sched = env.merge_feature.scheduled
while sched:
pairs_list = sched.pop()
success = True
for pairs in pairs_list:
try:
env.replace_all_validate(pairs, 'Merge')
except InconsistencyError:
success = False
env.merge_feature.blacklist.append(
(pairs[0][0].owner, pairs[0][1].owner))
if success:
break
# clear blacklist
env.merge_feature.blacklist = []
merge_optimizer = MergeOptimizer() merge_optimizer = MergeOptimizer()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论