提交 608c1a07 作者: Joseph Turian

merge

...@@ -171,6 +171,10 @@ class Env(utils.object2): ...@@ -171,6 +171,10 @@ class Env(utils.object2):
Updates the list of clients of r with new_clients. Updates the list of clients of r with new_clients.
""" """
if set(r.clients).intersection(set(new_clients)):
print 'RCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in r.clients]
print 'NCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in new_clients]
assert not set(r.clients).intersection(set(new_clients))
r.clients += new_clients r.clients += new_clients
def __remove_clients__(self, r, clients_to_remove, prune = True): def __remove_clients__(self, r, clients_to_remove, prune = True):
...@@ -182,6 +186,10 @@ class Env(utils.object2): ...@@ -182,6 +186,10 @@ class Env(utils.object2):
""" """
for entry in clients_to_remove: for entry in clients_to_remove:
r.clients.remove(entry) r.clients.remove(entry)
if entry in r.clients:
print 'ENTRY', repr(entry), type(entry[0])
print 'CLIENTS', repr(r.clients)
assert entry not in r.clients # an op,i pair should be unique
if not r.clients: if not r.clients:
if prune: if prune:
self.__prune_r__([r]) self.__prune_r__([r])
...@@ -194,8 +202,11 @@ class Env(utils.object2): ...@@ -194,8 +202,11 @@ class Env(utils.object2):
def __import_r__(self, results): def __import_r__(self, results):
# Imports the owners of the results # Imports the owners of the results
for node in set(r.owner for r in results if r.owner is not None): r_owner_done = set()
self.__import__(node) for node in [r.owner for r in results if r.owner is not None]:
if node not in r_owner_done:
r_owner_done.add(node)
self.__import__(node)
for r in results: for r in results:
if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs: if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs:
raise TypeError("Undeclared input", r) raise TypeError("Undeclared input", r)
...@@ -319,8 +330,8 @@ class Env(utils.object2): ...@@ -319,8 +330,8 @@ class Env(utils.object2):
# because it makes it easier to implement some optimizations for multiple-output ops # because it makes it easier to implement some optimizations for multiple-output ops
return return
for node, i in list(r.clients): for node, i in list(r.clients): #copy the client list for iteration
assert node == 'output' and self.outputs[i] is r or node.inputs[i] is r assert (node == 'output' and self.outputs[i] is r) or (node.inputs[i] is r)
self.change_input(node, i, new_r, reason=reason) self.change_input(node, i, new_r, reason=reason)
def replace_all(self, pairs, reason=None): def replace_all(self, pairs, reason=None):
......
...@@ -187,21 +187,27 @@ class MergeOptimizer(Optimizer): ...@@ -187,21 +187,27 @@ class MergeOptimizer(Optimizer):
env.extend(toolbox.ReplaceValidate()) env.extend(toolbox.ReplaceValidate())
def apply_constant_merge(self, env): def apply_constant_merge(self, env):
seen_constants = set()
const_sig = _metadict() # result -> result.signature() (for constants) const_sig = _metadict() # result -> result.signature() (for constants)
const_sig_inv = _metadict() # signature -> result (for constants) const_sig_inv = _metadict() # signature -> result (for constants)
for i, c in enumerate([r for r in env.results if isinstance(r, graph.Constant)]): for node in _list_of_nodes(env):
sig = c.signature() for i, c in enumerate([r for r in node.inputs if isinstance(r, graph.Constant)]):
other_c = const_sig_inv.get(sig, None) if id(c) in seen_constants:
if other_c is not None: continue
# multiple names will clobber each other.. else:
# we adopt convention to keep the last name seen_constants.add(id(c))
if c.name: sig = c.signature()
other_c.name = c.name other_c = const_sig_inv.get(sig, None)
env.replace_validate(c, other_c, reason='Constant Merge') if other_c is not None:
else: # multiple names will clobber each other..
#this is a new constant # we adopt convention to keep the last name
const_sig[c] = sig if c.name:
const_sig_inv[sig] = c other_c.name = c.name
env.replace_validate(c, other_c, reason='Constant Merge')
else:
#this is a new constant
const_sig[c] = sig
const_sig_inv[sig] = c
def exptime_apply_node_merge(self, env): def exptime_apply_node_merge(self, env):
# we clear the dicts because the Constants signatures are not necessarily hashable # we clear the dicts because the Constants signatures are not necessarily hashable
...@@ -242,19 +248,20 @@ class MergeOptimizer(Optimizer): ...@@ -242,19 +248,20 @@ class MergeOptimizer(Optimizer):
# we clear the dicts because the Constants signatures are not necessarily hashable # we clear the dicts because the Constants signatures are not necessarily hashable
# and it's more efficient to give them an integer like the other Results # and it's more efficient to give them an integer like the other Results
nodes_seen = set() nodes_seen = {}
for node in _list_of_nodes(env): for node_idx, node in enumerate(_list_of_nodes(env)):
# #
# these asserts ensure that the env has set the clients field properly the clients # these asserts ensure that the env has set the clients field properly the clients
# should at least contain `node` itself! # should at least contain `node` itself!
# #
assert len(node.inputs[0].clients) > 0 assert len(node.inputs[0].clients) > 0
assert (node,0) in node.inputs[0].clients assert (node,0) in node.inputs[0].clients
merge_candidates = [c for (c,i) in node.inputs[0].clients if c in nodes_seen] merge_candidates = [(nodes_seen[c],c) for (c,i) in node.inputs[0].clients if c in nodes_seen]
nodes_seen.add(node) merge_candidates.sort()
nodes_seen[node] = node_idx
#print 'NODE', node, merge_candidates, node.inputs[0].clients #print 'NODE', node, merge_candidates, node.inputs[0].clients
for candidate in merge_candidates: for candidate_idx, candidate in merge_candidates:
if len(node.inputs) != len(candidate.inputs): if len(node.inputs) != len(candidate.inputs):
continue continue
inputs_match = all(node_in is cand_in for node_in, cand_in in zip(node.inputs, candidate.inputs)) inputs_match = all(node_in is cand_in for node_in, cand_in in zip(node.inputs, candidate.inputs))
...@@ -626,8 +633,8 @@ class NavigatorOptimizer(Optimizer): ...@@ -626,8 +633,8 @@ class NavigatorOptimizer(Optimizer):
def warn(exc, nav, repl_pairs, local_opt): def warn(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer: print traceback """failure_callback for NavigatorOptimizer: print traceback
""" """
print "WARNING: Optimization failure due to: ", local_opt print >> sys.stderr, "WARNING: Optimization failure due to: ", local_opt
print "TRACEBACK:" print >> sys.stderr, "TRACEBACK:"
traceback.print_exc() traceback.print_exc()
@staticmethod @staticmethod
def warn_inplace(exc, nav, repl_pairs, local_opt): def warn_inplace(exc, nav, repl_pairs, local_opt):
......
...@@ -43,6 +43,38 @@ def debugprint(a, prefix='', depth=-1, done=None, file=sys.stdout): ...@@ -43,6 +43,38 @@ def debugprint(a, prefix='', depth=-1, done=None, file=sys.stdout):
return file return file
class Event(object):
    """Record of a single graph modification ('prune', 'import', 'change', ...).

    Events are compared by value so that the event lists produced by two
    separate optimization runs can be checked against each other.

    :param kind: string naming the kind of event (e.g. 'change').
    :param node: the node the event applies to, or the string 'output',
        which is accepted in place of a node.
    :param idx: input index concerned by the event, if any.
    :param reason: description of why the event happened, if any.

    Note: equality deliberately ignores `node` itself (only `kind`, `op`,
    `idx` and `reason` are compared).  No `__hash__` consistent with this
    equality is defined, so avoid using events as dict keys or set members.
    """

    # Attributes that take part in equality comparison.
    _compared_attrs = ('kind', 'op', 'idx', 'reason')

    def __init__(self, kind, node, idx=None, reason=None):
        self.kind = kind
        if node == 'output':
            # The 'output' placeholder has no `op` attribute to read,
            # so the placeholder string is stored for both fields.
            self.node = 'output'
            self.op = 'output'
        else:
            self.node = node
            self.op = node.op
        self.idx = idx
        self.reason = reason

    def __str__(self):
        if self.kind != 'change':
            return str(self.__dict__)
        if self.node == 'output':
            # The placeholder string has no `inputs`; print a dash instead
            # of failing with AttributeError while formatting a report.
            n_inputs = '-'
        else:
            n_inputs = str(len(self.node.inputs))
        # str(self.reason) keeps the join from failing when reason is None.
        return ' '.join(['change',
                         str(self.reason),
                         str(self.op),
                         str(self.idx),
                         n_inputs])

    def __eq__(self, other):
        # Only events of exactly the same class can be equal.
        if type(self) != type(other):
            return False
        return all(getattr(self, attr) == getattr(other, attr)
                   for attr in self._compared_attrs)

    def __ne__(self, other):
        return not (self == other)
class ResultEquivalenceTracker(object): class ResultEquivalenceTracker(object):
def __init__(self): def __init__(self):
self.env = None self.env = None
...@@ -57,12 +89,14 @@ class ResultEquivalenceTracker(object): ...@@ -57,12 +89,14 @@ class ResultEquivalenceTracker(object):
self.reasons = {} self.reasons = {}
self.replaced_by = {} self.replaced_by = {}
self.snapshots = {} self.snapshots = {}
self.event_list = []
def on_detach(self, env): def on_detach(self, env):
assert env is self.env assert env is self.env
self.env = None self.env = None
def on_prune(self, env, node): def on_prune(self, env, node):
self.event_list.append(Event('prune', node))
#print 'PRUNING NODE', node, id(node) #print 'PRUNING NODE', node, id(node)
assert node in self.active_nodes assert node in self.active_nodes
assert node not in self.inactive_nodes assert node not in self.inactive_nodes
...@@ -70,6 +104,8 @@ class ResultEquivalenceTracker(object): ...@@ -70,6 +104,8 @@ class ResultEquivalenceTracker(object):
self.inactive_nodes.add(node) self.inactive_nodes.add(node)
def on_import(self, env, node): def on_import(self, env, node):
self.event_list.append(Event('import', node))
#print 'NEW NODE', node, id(node) #print 'NEW NODE', node, id(node)
assert node not in self.active_nodes assert node not in self.active_nodes
self.active_nodes.add(node) self.active_nodes.add(node)
...@@ -93,6 +129,7 @@ class ResultEquivalenceTracker(object): ...@@ -93,6 +129,7 @@ class ResultEquivalenceTracker(object):
def on_change_input(self, env, node, i, r, new_r, reason=None): def on_change_input(self, env, node, i, r, new_r, reason=None):
#print 'CHANGE by', reason, 'to use', new_r, type(new_r) #print 'CHANGE by', reason, 'to use', new_r, type(new_r)
self.event_list.append(Event('change', node, reason=str(reason), idx=i))
self.reasons.setdefault(new_r, []) self.reasons.setdefault(new_r, [])
self.replaced_by.setdefault(new_r, []) self.replaced_by.setdefault(new_r, [])
...@@ -291,7 +328,6 @@ class OptCheckLinker(OpWiseCLinker): ...@@ -291,7 +328,6 @@ class OptCheckLinker(OpWiseCLinker):
# because the incorrect result detected here will cause # because the incorrect result detected here will cause
# subsequent outputs to be incorrect. # subsequent outputs to be incorrect.
raise Exception("OptCheckFailure") raise Exception("OptCheckFailure")
print >> sys.stderr, 'OptCheck PASS'
if 0: #OLD CODE if 0: #OLD CODE
#print out the summary of the first problematic equivalence group #print out the summary of the first problematic equivalence group
...@@ -321,7 +357,9 @@ NODEFAULT = ['NODEFAULT'] ...@@ -321,7 +357,9 @@ NODEFAULT = ['NODEFAULT']
class OptCheckFunctionMaker(FunctionMaker): class OptCheckFunctionMaker(FunctionMaker):
def __init__(self, inputs, outputs, optimizer, def __init__(self, inputs, outputs, optimizer,
accept_inplace = False, function_builder = Function): chances_for_optimizer_to_screw_up = 10,
accept_inplace = False,
function_builder = Function):
""" """
:type inputs: a list of SymbolicInput instances :type inputs: a list of SymbolicInput instances
...@@ -350,17 +388,39 @@ class OptCheckFunctionMaker(FunctionMaker): ...@@ -350,17 +388,39 @@ class OptCheckFunctionMaker(FunctionMaker):
expanded_inputs = reduce(list.__add__, [list(z) for x, y, z in indices], []) expanded_inputs = reduce(list.__add__, [list(z) for x, y, z in indices], [])
# make the env # make the env
env, additional_outputs, equivalence_tracker = optcheck_env(expanded_inputs, outputs, accept_inplace) for i in xrange(chances_for_optimizer_to_screw_up):
self.env = env env, additional_outputs, equivalence_tracker = optcheck_env(expanded_inputs, outputs, accept_inplace)
env.equivalence_tracker = equivalence_tracker
# optimize the env
optimizer(env)
if i:
li = env.equivalence_tracker.event_list
l0 = env0.equivalence_tracker.event_list
if li != l0 :
print >> sys.stderr, "WARNING: Optimization process is unstable"
for j in xrange(max(len(li), len(l0))):
if li[j] != l0[j]:
print >> sys.stderr, "* ", j
print >> sys.stderr, " ", str(li[j]) if j < len(li) else '-'
print >> sys.stderr, " ", str(l0[j]) if j < len(l0) else '-'
else:
pass
linker = OptCheckLinker() print >> sys.stderr, "EXITING"
sys.exit(1)
break
else:
print >> sys.stdout, "OPTCHECK: optimization", i, "of", len(li), "events was stable."
else:
env0 = env
# optimize the env
optimizer(env)
env.equivalence_tracker = equivalence_tracker del env0
self.env = env
#equivalence_tracker.printstuff() #equivalence_tracker.printstuff()
linker = OptCheckLinker()
#the 'no_borrow' outputs are the ones for which that we can't return the internal storage pointer. #the 'no_borrow' outputs are the ones for which that we can't return the internal storage pointer.
no_borrow = [output for output, spec in zip(env.outputs, outputs+additional_outputs) if not spec.borrow] no_borrow = [output for output, spec in zip(env.outputs, outputs+additional_outputs) if not spec.borrow]
...@@ -487,11 +547,14 @@ class OptCheck(Mode): ...@@ -487,11 +547,14 @@ class OptCheck(Mode):
# function_module.function # function_module.function
def function_maker(self, i,o,m, *args, **kwargs): def function_maker(self, i,o,m, *args, **kwargs):
assert m is self assert m is self
return OptCheckFunctionMaker(i, o, self.optimizer, *args, **kwargs) return OptCheckFunctionMaker(i, o, self.optimizer,
def __init__(self, optimizer='fast_run'): chances_for_optimizer_to_screw_up=self.stability_patience,
*args, **kwargs)
def __init__(self, optimizer='fast_run', stability_patience=10):
super(OptCheck, self).__init__( super(OptCheck, self).__init__(
optimizer=optimizer, optimizer=optimizer,
linker=OptCheckLinker) linker=OptCheckLinker)
self.stability_patience = stability_patience
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
from .. import gof from .. import gof
from ..gof import opt, InconsistencyError, TopoOptimizer from ..gof import opt, InconsistencyError, TopoOptimizer, graph
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
from .. import scalar from .. import scalar
import basic as T import basic as T
...@@ -47,7 +47,7 @@ def insert_inplace_optimizer(env): ...@@ -47,7 +47,7 @@ def insert_inplace_optimizer(env):
x + y + z -> x += y += z x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y) (x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
""" """
for node in list(env.nodes): for node in list(graph.io_toposort(env.inputs, env.outputs)):
op = node.op op = node.op
if not isinstance(op, Elemwise): if not isinstance(op, Elemwise):
continue continue
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论