提交 9b3fb435 authored 作者: James Bergstra's avatar James Bergstra

minor code changes and much documentation for DebugMode

上级 4eebb1c9
""" Provides `DebugMode` """Provides `DebugMode`, an evaluation mode for debugging theano internals."""
__docformat__ = "restructuredtext en"
"""
import time, copy, sys import time, copy, sys
from StringIO import StringIO from StringIO import StringIO
from .. import gof import numpy
from .. import gof
from ..gof import Env, graph, utils, link from ..gof import Env, graph, utils, link
from ..gof.link import WrapLinkerMany, raise_with_op from ..gof.link import WrapLinkerMany, raise_with_op
from ..gof.cutils import run_cthunk from ..gof.cutils import run_cthunk
from ..gof.cc import OpWiseCLinker, CLinker from ..gof.cc import OpWiseCLinker, CLinker
import numpy
from ..compile.function_module import (FunctionMaker, from ..compile.function_module import (FunctionMaker,
Function, Function,
infer_reuse_pattern, infer_reuse_pattern,
...@@ -20,96 +18,138 @@ from ..compile.function_module import (FunctionMaker, ...@@ -20,96 +18,138 @@ from ..compile.function_module import (FunctionMaker,
SymbolicInputKit, SymbolicInputKit,
SymbolicOutput, SymbolicOutput,
Supervisor) Supervisor)
from ..compile.mode import Mode, register_mode
class DebugModeError(Exception): class DebugModeError(Exception):
"""Generic Exception raised to indicate an internal theano problem"""
pass pass
class BadClinkerOutput(DebugModeError): class BadClinkerOutput(DebugModeError):
"""Exception: a c implementation and python implementation don't agree""" """Exception: an Op's c_code and perform implementations don't agree."""
r = None r = None
"""TODO""" """The `Result` instance for which conflicting values were computed"""
a = None val_py = None
"""TODO""" """The value computed by `r.owner.op.perform`"""
b = None val_c = None
"""TODO""" """The value computed by `r.owner.op.c_code`"""
def __init__(self, r, a, b): def __init__(self, r, val_py, val_c):
"""Initialize members""" """Initialize members"""
super(BadClinkerOutput, self).__init__() super(BadClinkerOutput, self).__init__()
self.r = r self.r = r
self.a = a self.val_py = val_py
self.b = b self.val_c = val_c
def offending_op(self):
return type(self.r.owner.op)
class BadOptimization(DebugModeError): class BadOptimization(DebugModeError):
"""Exception: some result and its substitute take different runtime values.""" """Exception: some result and its substitute take different runtime values.
"""
new_r = None new_r = None
"""TODO""" """A `Result` instance that took a different value from `old_r`, but which replaced `old_r`."""
r_val = None old_r = None
"""TODO""" """A `Result` instance that was replaced by `new_r`."""
old_r_val = None
"""The value computed for `old_r`."""
new_r_val = None new_r_val = None
"""TODO""" """The value computed for `new_r`."""
reasons = [] reason = None
"""TODO""" """An object that indicates why old_r was turned into new_r.
snapshots = [] Convention is that this is the name of the optimization that requested the replacement.
"""TODO""" """
def __init__(self, new_r, r_val, new_r_val, reasons, snapshots): old_graph = ""
"""A multiline string representation of the graph leading to old_r, at the time of the replacement."""
new_graph = ""
"""A multiline string representation of the graph leading to new_r, at the time of the replacement."""
def __init__(self, old_r, new_r, old_r_val, new_r_val, reason, old_graph, new_graph):
"""Initialize members""" """Initialize members"""
super(BadOptimization, self).__init__() super(BadOptimization, self).__init__()
self.old_r = old_r
self.new_r = new_r self.new_r = new_r
self.r_val = r_val self.old_r_val = old_r_val
self.new_r_val = new_r_val self.new_r_val = new_r_val
self.reasons = reasons self.reason = reason
self.snapshots = snapshots self.old_graph = old_graph
self.new_graph = new_graph
#def __str__(self):
#return self.str_diagnostic() #debatable...
def str_diagnostic(self): def str_diagnostic(self):
"""TODO: what does this mean? How to interpret? """ """Return a pretty multiline string representating the cause of the exception"""
sio = StringIO() sio = StringIO()
print >> sio, " Result:", id(self.new_r), self.new_r print >> sio, " Result: id", id(self.new_r), self.new_r
print >> sio, " Op", self.new_r.owner print >> sio, " Op", self.new_r.owner
print >> sio, " Value Type:", type(self.new_r_val) print >> sio, " Value Type:", type(self.new_r_val)
print >> sio, " Old Value: ", self.r_val print >> sio, " Old Value: ", self.old_r_val
print >> sio, " Value: ", self.new_r_val print >> sio, " New Value: ", self.new_r_val
print >> sio, " Reason: ", [(str(reason), id(old_r)) for reason, old_r in self.reasons[self.new_r]] print >> sio, " Reason: ", str(self.reason)
print >> sio, " Snapshots:" print >> sio, " Old Graph:"
for s in self.snapshots[self.new_r]: print >> sio, self.old_graph
print >> sio, " BEFORE" print >> sio, " New Graph:"
print >> sio, s[1] print >> sio, self.new_graph
print >> sio, " AFTER"
print >> sio, s[2]
return sio.getvalue() return sio.getvalue()
def debugprint(a, prefix='', depth=-1, done=None, file=sys.stdout): def _debugprint(r, prefix='', depth=-1, done=None, file=sys.stdout):
"""Print the graph ending at `a` to given depth.
:param r: Result instance
:param prefix: prefix to each line (typically some number of spaces)
:param depth: maximum recursion depth (Default -1 for unlimited).
:param done: set of Apply instances that have already been printed
:param file: file-like object to which to print
"""
if depth==0: if depth==0:
return return
done = set() if done is None else done done = set() if done is None else done
if hasattr(a, 'op'): if hasattr(r.owner, 'op'):
# this result is the output of computation,
# so just print out the apply
a = r.owner
print >> file, prefix, a.op, id(a) print >> file, prefix, a.op, id(a)
if id(a) not in done: if id(a) not in done:
done.add(id(a)) done.add(id(a))
for i in a.inputs: for i in a.inputs:
if i.owner: _debugprint(i, prefix+' ', depth=depth-1, done=done, file=file)
debugprint(i.owner, prefix+' ', depth=depth-1, done=done, file=file)
else:
print >> file, prefix+' ', i, id(i)
else: else:
print >> file, prefix+' ', a, id(a) #this is a result
print >> file, prefix, r, id(r)
return file return file
class Event(object): class _EnvEvent(object):
"""A record of an event in the life of an Env.
The __eq__ function is important here, as it is the basis for comparing optimization runs.
"""
kind = ""
"""One of 'import', 'change', 'prune'"""
node = None
"""Either 'output' or an Apply instance"""
op = None
"""Either 'output' or an Op instance"""
idx = None
"""change events involve an position index of the input result"""
reason = None
"""change events sometimes have a reason"""
def __init__(self, kind, node, idx=None, reason=None): def __init__(self, kind, node, idx=None, reason=None):
self.kind = kind self.kind = kind
if node == 'output': if node == 'output':
...@@ -134,6 +174,8 @@ class Event(object): ...@@ -134,6 +174,8 @@ class Event(object):
def __eq__(self, other): def __eq__(self, other):
rval = type(self) == type(other) rval = type(self) == type(other)
if rval: if rval:
# nodes are not compared because this comparison is supposed to be true for
# corresponding events that happen in different Env instances (different graphs)
for attr in ['kind', 'op', 'idx', 'reason']: for attr in ['kind', 'op', 'idx', 'reason']:
rval = rval and getattr(self, attr) == getattr(other, attr) rval = rval and getattr(self, attr) == getattr(other, attr)
return rval return rval
...@@ -141,7 +183,33 @@ class Event(object): ...@@ -141,7 +183,33 @@ class Event(object):
def __ne__(self, other): def __ne__(self, other):
return not (self == other) return not (self == other)
class ResultEquivalenceTracker(object): class _ResultEquivalenceTracker(object):
"""A Env Feature that keeps tabs on an Env and tries to detect problems."""
env = None
"""WRITEME"""
equiv = None
"""WRITEME"""
active_nodes = None
"""WRITEME"""
inactive_nodes = None
"""WRITEME"""
all_results_ever = None
"""WRITEME"""
reasons = None
"""WRITEME"""
replaced_by = None
"""WRITEME"""
event_list = None
"""WRITEME"""
def __init__(self): def __init__(self):
self.env = None self.env = None
...@@ -154,7 +222,6 @@ class ResultEquivalenceTracker(object): ...@@ -154,7 +222,6 @@ class ResultEquivalenceTracker(object):
self.all_results_ever = [] self.all_results_ever = []
self.reasons = {} self.reasons = {}
self.replaced_by = {} self.replaced_by = {}
self.snapshots = {}
self.event_list = [] self.event_list = []
def on_detach(self, env): def on_detach(self, env):
...@@ -162,7 +229,7 @@ class ResultEquivalenceTracker(object): ...@@ -162,7 +229,7 @@ class ResultEquivalenceTracker(object):
self.env = None self.env = None
def on_prune(self, env, node): def on_prune(self, env, node):
self.event_list.append(Event('prune', node)) self.event_list.append(_EnvEvent('prune', node))
#print 'PRUNING NODE', node, id(node) #print 'PRUNING NODE', node, id(node)
assert node in self.active_nodes assert node in self.active_nodes
assert node not in self.inactive_nodes assert node not in self.inactive_nodes
...@@ -170,7 +237,7 @@ class ResultEquivalenceTracker(object): ...@@ -170,7 +237,7 @@ class ResultEquivalenceTracker(object):
self.inactive_nodes.add(node) self.inactive_nodes.add(node)
def on_import(self, env, node): def on_import(self, env, node):
self.event_list.append(Event('import', node)) self.event_list.append(_EnvEvent('import', node))
#print 'NEW NODE', node, id(node) #print 'NEW NODE', node, id(node)
assert node not in self.active_nodes assert node not in self.active_nodes
...@@ -187,26 +254,30 @@ class ResultEquivalenceTracker(object): ...@@ -187,26 +254,30 @@ class ResultEquivalenceTracker(object):
self.all_results_ever.append(r) self.all_results_ever.append(r)
self.reasons.setdefault(r, []) self.reasons.setdefault(r, [])
self.replaced_by.setdefault(r, []) self.replaced_by.setdefault(r, [])
self.snapshots.setdefault(r, [])
for r in node.inputs: for r in node.inputs:
self.reasons.setdefault(r, []) self.reasons.setdefault(r, [])
self.replaced_by.setdefault(r, []) self.replaced_by.setdefault(r, [])
self.snapshots.setdefault(r, [])
def on_change_input(self, env, node, i, r, new_r, reason=None): def on_change_input(self, env, node, i, r, new_r, reason=None):
#print 'CHANGE by', reason, 'to use', new_r, type(new_r) #print 'CHANGE by', reason, 'to use', new_r, type(new_r)
self.event_list.append(Event('change', node, reason=str(reason), idx=i)) self.event_list.append(_EnvEvent('change', node, reason=str(reason), idx=i))
self.reasons.setdefault(new_r, []) self.reasons.setdefault(new_r, [])
self.replaced_by.setdefault(new_r, []) self.replaced_by.setdefault(new_r, [])
self.snapshots.setdefault(new_r, [])
if (reason, r) not in self.reasons[new_r]: append_reason = True
self.reasons[new_r].append((reason, r)) for tup in self.reasons[new_r]:
if tup[0] == reason and tup[1] is r:
append_reason = False
if append_reason:
# N.B. compute the _debugprint now, because future optimizations will change the
# graph
self.reasons[new_r].append((reason
, r
, _debugprint(r, prefix=' ', depth=6, file=StringIO()).getvalue()
, _debugprint(new_r, prefix=' ', depth=6, file=StringIO()).getvalue()))
self.replaced_by[r].append((reason, new_r)) self.replaced_by[r].append((reason, new_r))
self.snapshots[new_r].append((
reason,
debugprint(r.owner, prefix=' ', depth=6, file=StringIO()).getvalue(),
debugprint(new_r.owner,prefix=' ', depth=6, file=StringIO()).getvalue()))
if r in self.equiv: if r in self.equiv:
r_set = self.equiv[r] r_set = self.equiv[r]
...@@ -240,13 +311,13 @@ class ResultEquivalenceTracker(object): ...@@ -240,13 +311,13 @@ class ResultEquivalenceTracker(object):
for e in self.equiv[key]: for e in self.equiv[key]:
print ' ', e print ' ', e
def optcheck_env(input_specs, output_specs, accept_inplace = False): def _optcheck_env(input_specs, output_specs, accept_inplace = False):
orig_inputs = [spec.result for spec in input_specs] orig_inputs = [spec.result for spec in input_specs]
updates = [spec.update for spec in input_specs if spec.update] updates = [spec.update for spec in input_specs if spec.update]
orig_outputs = [spec.result for spec in output_specs] + updates orig_outputs = [spec.result for spec in output_specs] + updates
inputs, outputs = gof.graph.clone(orig_inputs, orig_outputs) inputs, outputs = gof.graph.clone(orig_inputs, orig_outputs)
equivalence_tracker = ResultEquivalenceTracker() equivalence_tracker = _ResultEquivalenceTracker()
env = gof.env.Env(inputs, outputs, env = gof.env.Env(inputs, outputs,
features=[equivalence_tracker, features=[equivalence_tracker,
gof.DestroyHandler(do_imports_on_attach=False)]) gof.DestroyHandler(do_imports_on_attach=False)])
...@@ -261,7 +332,7 @@ def optcheck_env(input_specs, output_specs, accept_inplace = False): ...@@ -261,7 +332,7 @@ def optcheck_env(input_specs, output_specs, accept_inplace = False):
return env, map(SymbolicOutput, updates), equivalence_tracker return env, map(SymbolicOutput, updates), equivalence_tracker
class DebugModeLinker(gof.link.LocalLinker): class _Linker(gof.link.LocalLinker):
def __init__(self, maker): def __init__(self, maker):
super(gof.LocalLinker, self).__init__() super(gof.LocalLinker, self).__init__()
self.env = None self.env = None
...@@ -269,7 +340,7 @@ class DebugModeLinker(gof.link.LocalLinker): ...@@ -269,7 +340,7 @@ class DebugModeLinker(gof.link.LocalLinker):
def accept(self, env, no_recycling = []): def accept(self, env, no_recycling = []):
if self.env is not None and self.env is not env: if self.env is not None and self.env is not env:
assert type(self) is DebugModeLinker assert type(self) is _Linker
return type(self)(self.env, self.maker).accept(env, no_recycling) return type(self)(self.env, self.maker).accept(env, no_recycling)
self.env = env self.env = env
self.no_recycling = no_recycling self.no_recycling = no_recycling
...@@ -386,7 +457,7 @@ class DebugModeLinker(gof.link.LocalLinker): ...@@ -386,7 +457,7 @@ class DebugModeLinker(gof.link.LocalLinker):
if r in r_vals: if r in r_vals:
# r has been constant-folded # r has been constant-folded
if not r.type.values_eq_enough(r_vals[r], storage_map[r][0]): if not r.type.values_eq_enough(r_vals[r], storage_map[r][0]):
raise Exception('BadConstantFold', (r, r_vals[r], raise DebugModeError('BadConstantFold', (r, r_vals[r],
storage_map[r][0])) #TODO: make a proper exception class for this storage_map[r][0])) #TODO: make a proper exception class for this
else: else:
...@@ -410,16 +481,24 @@ class DebugModeLinker(gof.link.LocalLinker): ...@@ -410,16 +481,24 @@ class DebugModeLinker(gof.link.LocalLinker):
# compares the version from thunk_py (in r_vals) # compares the version from thunk_py (in r_vals)
# to the version produced by thunk_c (in storage_map) # to the version produced by thunk_c (in storage_map)
if not r.type.values_eq_enough(r_vals[r], storage_map[r][0]): if not r.type.values_eq_enough(r_vals[r], storage_map[r][0]):
raise BadClinkerOutput(r, r_vals[r], storage_map[r][0]) raise BadClinkerOutput(r, val_py=r_vals[r], val_c=storage_map[r][0])
except: except:
raise_with_op(node) raise_with_op(node)
# iterate over results looking for values that don't match the values of the # iterate over results looking for values that don't match the values of the
# results they replaced. This is the sign of a broken optimization. # results they replaced. This is the sign of a broken optimization.
# A basic premise of how theano works is that every node that is replaced during optimization should compute the same thing as its replacement.
# Normally such replacements run instead of the originals.
# This Mode runs the original and the replacement, and then checks that they both compute the
# same thing.
# If their values are different, the optimization that created the replacement is probably broken.
for i, node in enumerate(order): for i, node in enumerate(order):
for new_r in node.outputs: for new_r in node.outputs:
for reason, r in env.equivalence_tracker.reasons[new_r]: for reason, r, old_graph_str, new_graph_str in env.equivalence_tracker.reasons[new_r]:
problem = False problem = False
#check if the value for new_r doesn't match the value for r #check if the value for new_r doesn't match the value for r
...@@ -428,17 +507,21 @@ class DebugModeLinker(gof.link.LocalLinker): ...@@ -428,17 +507,21 @@ class DebugModeLinker(gof.link.LocalLinker):
assert r.type == new_r.type assert r.type == new_r.type
if not r.type.values_eq_enough(r_val, new_r_val): if not r.type.values_eq_enough(r_val, new_r_val):
raise BadOptimization(new_r, r_val, new_r_val, raise BadOptimization(old_r=r,
env.equivalence_tracker.reasons, new_r=new_r,
env.equivalence_tracker.snapshots) old_r_val=r_val,
new_r_val=new_r_val,
reason=reason,
old_graph=old_graph_str,
new_graph=new_graph_str)
f.allow_gc = True f.allow_gc = True
return f, [link.Container(input, storage) for input, storage in zip(env.inputs, input_storage)], \ return f, [link.Container(input, storage) for input, storage in zip(env.inputs, input_storage)], \
[link.Container(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \ [link.Container(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \
thunks_py, order thunks_py, order
NODEFAULT = ['NODEFAULT'] _NODEFAULT = ['NODEFAULT']
class DebugModeFunctionMaker(FunctionMaker): #inheritance buys a few helper functions class _Maker(FunctionMaker): #inheritance buys a few helper functions
verbose = 0 verbose = 0
"""Verbosity level of compile-time and run-time checks. (Default 0: silent)""" """Verbosity level of compile-time and run-time checks. (Default 0: silent)"""
...@@ -474,7 +557,7 @@ class DebugModeFunctionMaker(FunctionMaker): #inheritance buys a few helper func ...@@ -474,7 +557,7 @@ class DebugModeFunctionMaker(FunctionMaker): #inheritance buys a few helper func
# make the env # make the env
for i in xrange(mode.stability_patience): for i in xrange(mode.stability_patience):
env, additional_outputs, equivalence_tracker = optcheck_env(expanded_inputs, outputs, accept_inplace) env, additional_outputs, equivalence_tracker = _optcheck_env(expanded_inputs, outputs, accept_inplace)
env.equivalence_tracker = equivalence_tracker env.equivalence_tracker = equivalence_tracker
# optimize the env # optimize the env
optimizer(env) optimizer(env)
...@@ -505,7 +588,7 @@ class DebugModeFunctionMaker(FunctionMaker): #inheritance buys a few helper func ...@@ -505,7 +588,7 @@ class DebugModeFunctionMaker(FunctionMaker): #inheritance buys a few helper func
self.env = env self.env = env
#equivalence_tracker.printstuff() #equivalence_tracker.printstuff()
linker = DebugModeLinker(self) linker = _Linker(self)
#the 'no_borrow' outputs are the ones for which that we can't return the internal storage pointer. #the 'no_borrow' outputs are the ones for which that we can't return the internal storage pointer.
...@@ -565,7 +648,7 @@ class DebugModeFunctionMaker(FunctionMaker): #inheritance buys a few helper func ...@@ -565,7 +648,7 @@ class DebugModeFunctionMaker(FunctionMaker): #inheritance buys a few helper func
input_storage += [default[i].storage for i in indices] input_storage += [default[i].storage for i in indices]
else: else:
raise ValueError('Not enough storage for SymbolicInputKit', input, indices, default) raise ValueError('Not enough storage for SymbolicInputKit', input, indices, default)
default = NODEFAULT default = _NODEFAULT
else: else:
input_storage += [[None] for i in indices] input_storage += [[None] for i in indices]
else: else:
...@@ -581,7 +664,7 @@ class DebugModeFunctionMaker(FunctionMaker): #inheritance buys a few helper func ...@@ -581,7 +664,7 @@ class DebugModeFunctionMaker(FunctionMaker): #inheritance buys a few helper func
# Even though a SymbolicInputKit represents more than one input, # Even though a SymbolicInputKit represents more than one input,
# we still only have one entry for the defaults list. # we still only have one entry for the defaults list.
if isinstance(input, SymbolicInputKit): if isinstance(input, SymbolicInputKit):
if default is NODEFAULT: if default is _NODEFAULT:
_defaults.append((False, False, None)) _defaults.append((False, False, None))
elif default is None: elif default is None:
_defaults.append((True, True, None)) _defaults.append((True, True, None))
...@@ -620,31 +703,36 @@ class DebugModeFunctionMaker(FunctionMaker): #inheritance buys a few helper func ...@@ -620,31 +703,36 @@ class DebugModeFunctionMaker(FunctionMaker): #inheritance buys a few helper func
from ..compile.mode import Mode, register_mode
class DebugMode(Mode): class DebugMode(Mode):
"""Evaluation Mode that detects optimization errors. """Evaluation Mode that detects internal theano errors.
This mode catches several kinds of internal error:
- inconsistent c_code and perform implementations (see `BadClinkerOutput`)
- incorrect optimizations (see `BadOptimization`)
- stochastic optimization ordering
A basic premise of how theano works is that every node that is replaced during optimization should compute the same thing as its replacement. If there are no internal errors, this mode behaves like FAST_RUN or FAST_COMPILE, but takes
a little longer and uses more memory.
Normally such replacements run instead of the originals. If there are internal errors, this mode will raise Exceptions (see above) and write
This Mode runs the original and the replacement, and then checks that they both compute the diagnostic information to a file.
same thing.
If their values are different, the optimization that created the replacement is probably
broken.
""" """
# This function will be used to create a FunctionMaker in # This function will be used to create a FunctionMaker in
# function_module.function # function_module.function
def function_maker(self, i,o,m, *args, **kwargs): def function_maker(self, i,o,m, *args, **kwargs):
assert m is self assert m is self
return DebugModeFunctionMaker(i, o, self.optimizer, self, *args, **kwargs) return _Maker(i, o, self.optimizer, self, *args, **kwargs)
def __init__(self, def __init__(self,
optimizer='fast_run', optimizer='fast_run',
stability_patience=10, stability_patience=10,
check_c_code=True): check_c_code=True):
super(DebugMode, self).__init__( super(DebugMode, self).__init__(
optimizer=optimizer, optimizer=optimizer,
linker=DebugModeLinker) linker=_Linker)
self.stability_patience = stability_patience self.stability_patience = stability_patience
self.check_c_code = check_c_code self.check_c_code = check_c_code
register_mode('DEBUG_MODE',DebugMode(optimizer='fast_run')) register_mode('DEBUG_MODE',DebugMode(optimizer='fast_run'))
...@@ -214,7 +214,7 @@ def test_badoptimization(): ...@@ -214,7 +214,7 @@ def test_badoptimization():
3, 3,
numpy.asarray([[0.,1.,2.],[3.,4.,5.],[6.,7.,8.]])) numpy.asarray([[0.,1.,2.],[3.,4.,5.],[6.,7.,8.]]))
except debugmode.BadOptimization, e: except debugmode.BadOptimization, e:
assert str(e.reasons[e.new_r][0][0]) == 'insert_broken_csc' assert str(e.reason) == 'insert_broken_csc'
return #TEST PASS return #TEST PASS
assert False assert False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论