提交 824f5fa7 authored 作者: James Bergstra's avatar James Bergstra

added Event to debugmode to catch optimization instability better

上级 750d96fd
...@@ -43,6 +43,38 @@ def debugprint(a, prefix='', depth=-1, done=None, file=sys.stdout): ...@@ -43,6 +43,38 @@ def debugprint(a, prefix='', depth=-1, done=None, file=sys.stdout):
return file return file
class Event(object):
def __init__(self, kind, node, idx=None, reason=None):
self.kind = kind
if node == 'output':
self.node = 'output'
self.op = 'output'
else:
self.node = node
self.op = node.op
self.idx = idx
self.reason = reason
def __str__(self):
if self.kind == 'change':
return ' '.join(['change',
self.reason,
str(self.op),
str(self.idx),
str(len(self.node.inputs))])
else:
return str(self.__dict__)
def __eq__(self, other):
rval = type(self) == type(other)
if rval:
for attr in ['kind', 'op', 'idx', 'reason']:
rval = rval and getattr(self, attr) == getattr(other, attr)
return rval
def __ne__(self, other):
return not (self == other)
class ResultEquivalenceTracker(object): class ResultEquivalenceTracker(object):
def __init__(self): def __init__(self):
self.env = None self.env = None
...@@ -57,12 +89,14 @@ class ResultEquivalenceTracker(object): ...@@ -57,12 +89,14 @@ class ResultEquivalenceTracker(object):
self.reasons = {} self.reasons = {}
self.replaced_by = {} self.replaced_by = {}
self.snapshots = {} self.snapshots = {}
self.event_list = []
def on_detach(self, env): def on_detach(self, env):
assert env is self.env assert env is self.env
self.env = None self.env = None
def on_prune(self, env, node): def on_prune(self, env, node):
self.event_list.append(Event('prune', node))
#print 'PRUNING NODE', node, id(node) #print 'PRUNING NODE', node, id(node)
assert node in self.active_nodes assert node in self.active_nodes
assert node not in self.inactive_nodes assert node not in self.inactive_nodes
...@@ -70,6 +104,8 @@ class ResultEquivalenceTracker(object): ...@@ -70,6 +104,8 @@ class ResultEquivalenceTracker(object):
self.inactive_nodes.add(node) self.inactive_nodes.add(node)
def on_import(self, env, node): def on_import(self, env, node):
self.event_list.append(Event('import', node))
#print 'NEW NODE', node, id(node) #print 'NEW NODE', node, id(node)
assert node not in self.active_nodes assert node not in self.active_nodes
self.active_nodes.add(node) self.active_nodes.add(node)
...@@ -93,6 +129,7 @@ class ResultEquivalenceTracker(object): ...@@ -93,6 +129,7 @@ class ResultEquivalenceTracker(object):
def on_change_input(self, env, node, i, r, new_r, reason=None): def on_change_input(self, env, node, i, r, new_r, reason=None):
#print 'CHANGE by', reason, 'to use', new_r, type(new_r) #print 'CHANGE by', reason, 'to use', new_r, type(new_r)
self.event_list.append(Event('change', node, reason=str(reason), idx=i))
self.reasons.setdefault(new_r, []) self.reasons.setdefault(new_r, [])
self.replaced_by.setdefault(new_r, []) self.replaced_by.setdefault(new_r, [])
...@@ -291,7 +328,6 @@ class OptCheckLinker(OpWiseCLinker): ...@@ -291,7 +328,6 @@ class OptCheckLinker(OpWiseCLinker):
# because the incorrect result detected here will cause # because the incorrect result detected here will cause
# subsequent outputs to be incorrect. # subsequent outputs to be incorrect.
raise Exception("OptCheckFailure") raise Exception("OptCheckFailure")
print >> sys.stderr, 'OptCheck PASS'
if 0: #OLD CODE if 0: #OLD CODE
#print out the summary of the first problematic equivalence group #print out the summary of the first problematic equivalence group
...@@ -321,7 +357,9 @@ NODEFAULT = ['NODEFAULT'] ...@@ -321,7 +357,9 @@ NODEFAULT = ['NODEFAULT']
class OptCheckFunctionMaker(FunctionMaker): class OptCheckFunctionMaker(FunctionMaker):
def __init__(self, inputs, outputs, optimizer, def __init__(self, inputs, outputs, optimizer,
accept_inplace = False, function_builder = Function): chances_for_optimizer_to_screw_up = 10,
accept_inplace = False,
function_builder = Function):
""" """
:type inputs: a list of SymbolicInput instances :type inputs: a list of SymbolicInput instances
...@@ -350,17 +388,39 @@ class OptCheckFunctionMaker(FunctionMaker): ...@@ -350,17 +388,39 @@ class OptCheckFunctionMaker(FunctionMaker):
expanded_inputs = reduce(list.__add__, [list(z) for x, y, z in indices], []) expanded_inputs = reduce(list.__add__, [list(z) for x, y, z in indices], [])
# make the env # make the env
env, additional_outputs, equivalence_tracker = optcheck_env(expanded_inputs, outputs, accept_inplace) for i in xrange(chances_for_optimizer_to_screw_up):
self.env = env env, additional_outputs, equivalence_tracker = optcheck_env(expanded_inputs, outputs, accept_inplace)
env.equivalence_tracker = equivalence_tracker
# optimize the env
optimizer(env)
if i:
li = env.equivalence_tracker.event_list
l0 = env0.equivalence_tracker.event_list
if li != l0 :
print >> sys.stderr, "WARNING: Optimization process is unstable"
for j in xrange(max(len(li), len(l0))):
if li[j] != l0[j]:
print >> sys.stderr, "* ", j
print >> sys.stderr, " ", str(li[j]) if j < len(li) else '-'
print >> sys.stderr, " ", str(l0[j]) if j < len(l0) else '-'
else:
pass
linker = OptCheckLinker() print >> sys.stderr, "EXITING"
sys.exit(1)
break
else:
print >> sys.stdout, "OPTCHECK: optimization", i, "of", len(li), "events was stable."
else:
env0 = env
# optimize the env
optimizer(env)
env.equivalence_tracker = equivalence_tracker del env0
self.env = env
#equivalence_tracker.printstuff() #equivalence_tracker.printstuff()
linker = OptCheckLinker()
#the 'no_borrow' outputs are the ones for which that we can't return the internal storage pointer. #the 'no_borrow' outputs are the ones for which that we can't return the internal storage pointer.
no_borrow = [output for output, spec in zip(env.outputs, outputs+additional_outputs) if not spec.borrow] no_borrow = [output for output, spec in zip(env.outputs, outputs+additional_outputs) if not spec.borrow]
...@@ -487,11 +547,14 @@ class OptCheck(Mode): ...@@ -487,11 +547,14 @@ class OptCheck(Mode):
# function_module.function # function_module.function
def function_maker(self, i,o,m, *args, **kwargs): def function_maker(self, i,o,m, *args, **kwargs):
assert m is self assert m is self
return OptCheckFunctionMaker(i, o, self.optimizer, *args, **kwargs) return OptCheckFunctionMaker(i, o, self.optimizer,
def __init__(self, optimizer='fast_run'): chances_for_optimizer_to_screw_up=self.stability_patience,
*args, **kwargs)
def __init__(self, optimizer='fast_run', stability_patience=10):
super(OptCheck, self).__init__( super(OptCheck, self).__init__(
optimizer=optimizer, optimizer=optimizer,
linker=OptCheckLinker) linker=OptCheckLinker)
self.stability_patience = stability_patience
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论