提交 608c1a07 authored 作者: Joseph Turian's avatar Joseph Turian

merge

......@@ -171,6 +171,10 @@ class Env(utils.object2):
Updates the list of clients of r with new_clients.
"""
if set(r.clients).intersection(set(new_clients)):
print 'RCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in r.clients]
print 'NCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in new_clients]
assert not set(r.clients).intersection(set(new_clients))
r.clients += new_clients
def __remove_clients__(self, r, clients_to_remove, prune = True):
......@@ -182,6 +186,10 @@ class Env(utils.object2):
"""
for entry in clients_to_remove:
r.clients.remove(entry)
if entry in r.clients:
print 'ENTRY', repr(entry), type(entry[0])
print 'CLIENTS', repr(r.clients)
assert entry not in r.clients # an op,i pair should be unique
if not r.clients:
if prune:
self.__prune_r__([r])
......@@ -194,8 +202,11 @@ class Env(utils.object2):
def __import_r__(self, results):
# Imports the owners of the results
for node in set(r.owner for r in results if r.owner is not None):
self.__import__(node)
r_owner_done = set()
for node in [r.owner for r in results if r.owner is not None]:
if node not in r_owner_done:
r_owner_done.add(node)
self.__import__(node)
for r in results:
if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs:
raise TypeError("Undeclared input", r)
......@@ -319,8 +330,8 @@ class Env(utils.object2):
# because it makes it easier to implement some optimizations for multiple-output ops
return
for node, i in list(r.clients):
assert node == 'output' and self.outputs[i] is r or node.inputs[i] is r
for node, i in list(r.clients): #copy the client list for iteration
assert (node == 'output' and self.outputs[i] is r) or (node.inputs[i] is r)
self.change_input(node, i, new_r, reason=reason)
def replace_all(self, pairs, reason=None):
......
......@@ -187,21 +187,27 @@ class MergeOptimizer(Optimizer):
env.extend(toolbox.ReplaceValidate())
def apply_constant_merge(self, env):
seen_constants = set()
const_sig = _metadict() # result -> result.signature() (for constants)
const_sig_inv = _metadict() # signature -> result (for constants)
for i, c in enumerate([r for r in env.results if isinstance(r, graph.Constant)]):
sig = c.signature()
other_c = const_sig_inv.get(sig, None)
if other_c is not None:
# multiple names will clobber each other..
# we adopt convention to keep the last name
if c.name:
other_c.name = c.name
env.replace_validate(c, other_c, reason='Constant Merge')
else:
#this is a new constant
const_sig[c] = sig
const_sig_inv[sig] = c
for node in _list_of_nodes(env):
for i, c in enumerate([r for r in node.inputs if isinstance(r, graph.Constant)]):
if id(c) in seen_constants:
continue
else:
seen_constants.add(id(c))
sig = c.signature()
other_c = const_sig_inv.get(sig, None)
if other_c is not None:
# multiple names will clobber each other..
# we adopt convention to keep the last name
if c.name:
other_c.name = c.name
env.replace_validate(c, other_c, reason='Constant Merge')
else:
#this is a new constant
const_sig[c] = sig
const_sig_inv[sig] = c
def exptime_apply_node_merge(self, env):
# we clear the dicts because the Constants signatures are not necessarily hashable
......@@ -242,19 +248,20 @@ class MergeOptimizer(Optimizer):
# we clear the dicts because the Constants signatures are not necessarily hashable
# and it's more efficient to give them an integer like the other Results
nodes_seen = set()
nodes_seen = {}
for node in _list_of_nodes(env):
for node_idx, node in enumerate(_list_of_nodes(env)):
#
# these asserts ensure that the env has set the clients field properly the clients
# should at least contain `node` itself!
#
assert len(node.inputs[0].clients) > 0
assert (node,0) in node.inputs[0].clients
merge_candidates = [c for (c,i) in node.inputs[0].clients if c in nodes_seen]
nodes_seen.add(node)
merge_candidates = [(nodes_seen[c],c) for (c,i) in node.inputs[0].clients if c in nodes_seen]
merge_candidates.sort()
nodes_seen[node] = node_idx
#print 'NODE', node, merge_candidates, node.inputs[0].clients
for candidate in merge_candidates:
for candidate_idx, candidate in merge_candidates:
if len(node.inputs) != len(candidate.inputs):
continue
inputs_match = all(node_in is cand_in for node_in, cand_in in zip(node.inputs, candidate.inputs))
......@@ -626,8 +633,8 @@ class NavigatorOptimizer(Optimizer):
def warn(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer: print traceback
"""
print "WARNING: Optimization failure due to: ", local_opt
print "TRACEBACK:"
print >> sys.stderr, "WARNING: Optimization failure due to: ", local_opt
print >> sys.stderr, "TRACEBACK:"
traceback.print_exc()
@staticmethod
def warn_inplace(exc, nav, repl_pairs, local_opt):
......
......@@ -43,6 +43,38 @@ def debugprint(a, prefix='', depth=-1, done=None, file=sys.stdout):
return file
class Event(object):
def __init__(self, kind, node, idx=None, reason=None):
self.kind = kind
if node == 'output':
self.node = 'output'
self.op = 'output'
else:
self.node = node
self.op = node.op
self.idx = idx
self.reason = reason
def __str__(self):
if self.kind == 'change':
return ' '.join(['change',
self.reason,
str(self.op),
str(self.idx),
str(len(self.node.inputs))])
else:
return str(self.__dict__)
def __eq__(self, other):
rval = type(self) == type(other)
if rval:
for attr in ['kind', 'op', 'idx', 'reason']:
rval = rval and getattr(self, attr) == getattr(other, attr)
return rval
def __ne__(self, other):
return not (self == other)
class ResultEquivalenceTracker(object):
def __init__(self):
self.env = None
......@@ -57,12 +89,14 @@ class ResultEquivalenceTracker(object):
self.reasons = {}
self.replaced_by = {}
self.snapshots = {}
self.event_list = []
def on_detach(self, env):
assert env is self.env
self.env = None
def on_prune(self, env, node):
self.event_list.append(Event('prune', node))
#print 'PRUNING NODE', node, id(node)
assert node in self.active_nodes
assert node not in self.inactive_nodes
......@@ -70,6 +104,8 @@ class ResultEquivalenceTracker(object):
self.inactive_nodes.add(node)
def on_import(self, env, node):
self.event_list.append(Event('import', node))
#print 'NEW NODE', node, id(node)
assert node not in self.active_nodes
self.active_nodes.add(node)
......@@ -93,6 +129,7 @@ class ResultEquivalenceTracker(object):
def on_change_input(self, env, node, i, r, new_r, reason=None):
#print 'CHANGE by', reason, 'to use', new_r, type(new_r)
self.event_list.append(Event('change', node, reason=str(reason), idx=i))
self.reasons.setdefault(new_r, [])
self.replaced_by.setdefault(new_r, [])
......@@ -291,7 +328,6 @@ class OptCheckLinker(OpWiseCLinker):
# because the incorrect result detected here will cause
# subsequent outputs to be incorrect.
raise Exception("OptCheckFailure")
print >> sys.stderr, 'OptCheck PASS'
if 0: #OLD CODE
#print out the summary of the first problematic equivalence group
......@@ -321,7 +357,9 @@ NODEFAULT = ['NODEFAULT']
class OptCheckFunctionMaker(FunctionMaker):
def __init__(self, inputs, outputs, optimizer,
accept_inplace = False, function_builder = Function):
chances_for_optimizer_to_screw_up = 10,
accept_inplace = False,
function_builder = Function):
"""
:type inputs: a list of SymbolicInput instances
......@@ -350,17 +388,39 @@ class OptCheckFunctionMaker(FunctionMaker):
expanded_inputs = reduce(list.__add__, [list(z) for x, y, z in indices], [])
# make the env
env, additional_outputs, equivalence_tracker = optcheck_env(expanded_inputs, outputs, accept_inplace)
self.env = env
for i in xrange(chances_for_optimizer_to_screw_up):
env, additional_outputs, equivalence_tracker = optcheck_env(expanded_inputs, outputs, accept_inplace)
env.equivalence_tracker = equivalence_tracker
# optimize the env
optimizer(env)
if i:
li = env.equivalence_tracker.event_list
l0 = env0.equivalence_tracker.event_list
if li != l0 :
print >> sys.stderr, "WARNING: Optimization process is unstable"
for j in xrange(max(len(li), len(l0))):
if li[j] != l0[j]:
print >> sys.stderr, "* ", j
print >> sys.stderr, " ", str(li[j]) if j < len(li) else '-'
print >> sys.stderr, " ", str(l0[j]) if j < len(l0) else '-'
else:
pass
linker = OptCheckLinker()
print >> sys.stderr, "EXITING"
sys.exit(1)
break
else:
print >> sys.stdout, "OPTCHECK: optimization", i, "of", len(li), "events was stable."
else:
env0 = env
# optimize the env
optimizer(env)
env.equivalence_tracker = equivalence_tracker
del env0
self.env = env
#equivalence_tracker.printstuff()
linker = OptCheckLinker()
#the 'no_borrow' outputs are the ones for which that we can't return the internal storage pointer.
no_borrow = [output for output, spec in zip(env.outputs, outputs+additional_outputs) if not spec.borrow]
......@@ -487,11 +547,14 @@ class OptCheck(Mode):
# function_module.function
def function_maker(self, i,o,m, *args, **kwargs):
assert m is self
return OptCheckFunctionMaker(i, o, self.optimizer, *args, **kwargs)
def __init__(self, optimizer='fast_run'):
return OptCheckFunctionMaker(i, o, self.optimizer,
chances_for_optimizer_to_screw_up=self.stability_patience,
*args, **kwargs)
def __init__(self, optimizer='fast_run', stability_patience=10):
super(OptCheck, self).__init__(
optimizer=optimizer,
linker=OptCheckLinker)
self.stability_patience = stability_patience
......@@ -5,7 +5,7 @@
from .. import gof
from ..gof import opt, InconsistencyError, TopoOptimizer
from ..gof import opt, InconsistencyError, TopoOptimizer, graph
from elemwise import Elemwise, DimShuffle
from .. import scalar
import basic as T
......@@ -47,7 +47,7 @@ def insert_inplace_optimizer(env):
x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
"""
for node in list(env.nodes):
for node in list(graph.io_toposort(env.inputs, env.outputs)):
op = node.op
if not isinstance(op, Elemwise):
continue
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论