提交 209ec94b authored 作者: James Bergstra's avatar James Bergstra

several bugfixes to DebugMode

上级 5acbca53
...@@ -86,9 +86,13 @@ class BadOptimization(DebugModeError): ...@@ -86,9 +86,13 @@ class BadOptimization(DebugModeError):
self.old_graph = old_graph self.old_graph = old_graph
self.new_graph = new_graph self.new_graph = new_graph
def __str__(self):
return self.str_diagnostic()
def str_diagnostic(self): def str_diagnostic(self):
"""Return a pretty multiline string representating the cause of the exception""" """Return a pretty multiline string representating the cause of the exception"""
sio = StringIO() sio = StringIO()
print >> sio, "BadOptimization Error", super(BadOptimization, self).__str__()
print >> sio, " Result: id", id(self.new_r), self.new_r print >> sio, " Result: id", id(self.new_r), self.new_r
print >> sio, " Op", self.new_r.owner print >> sio, " Op", self.new_r.owner
print >> sio, " Value Type:", type(self.new_r_val) print >> sio, " Value Type:", type(self.new_r_val)
...@@ -117,8 +121,8 @@ class BadDestroyMap(DebugModeError): ...@@ -117,8 +121,8 @@ class BadDestroyMap(DebugModeError):
print >> sio, " destroy_map:", getattr(self.node.op, 'destroy_map', {}) print >> sio, " destroy_map:", getattr(self.node.op, 'destroy_map', {})
print >> sio, " changed input idx:", self.idx print >> sio, " changed input idx:", self.idx
print >> sio, " changed input type:", self.node.inputs[self.idx].type print >> sio, " changed input type:", self.node.inputs[self.idx].type
print >> sio, " old val:", self.old_val print >> sio, " repr (old val):", repr(self.old_val)
print >> sio, " new val:", self.new_val print >> sio, " repr (new val):", repr(self.new_val)
print >> sio, "" print >> sio, ""
print >> sio, " Hint: this can be caused by a deficient values_eq_enough() or __eq__() implementation that compares node input values" print >> sio, " Hint: this can be caused by a deficient values_eq_enough() or __eq__() implementation that compares node input values"
return sio.getvalue() return sio.getvalue()
...@@ -225,6 +229,92 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes): ...@@ -225,6 +229,92 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes):
else: else:
raise BadDestroyMap(node, r_idx, r_vals[r], storage_map[r][0]) raise BadDestroyMap(node, r_idx, r_vals[r], storage_map[r][0])
def _lessbroken_deepcopy(a):
if type(a) is numpy.ndarray:
rval = numpy.array(a, copy=True, dtype=a.dtype)
else:
rval = copy.deepcopy(a)
assert type(rval) == type(a)
if isinstance(rval, numpy.ndarray):
assert rval.dtype == a.dtype
return rval
def _find_bad_optimizations0(order, reasons, r_vals):
"""Use a simple algorithm to find broken optimizations. This algorithm is simple to
understand, but sometimes when there's a problem it identifies the wrong optimization as
the culprit.
"""
# iterate over results looking for values that don't match the values of the
# results they replaced. This is the sign of a broken optimization.
for i, node in enumerate(order):
for new_r in node.outputs:
for reason, r, old_graph_str, new_graph_str in reasons[new_r]:
problem = False
#check if the value for new_r doesn't match the value for r
new_r_val = r_vals[new_r]
r_val = r_vals[r]
assert r.type == new_r.type
if not r.type.values_eq_enough(r_val, new_r_val):
raise BadOptimization(old_r=r,
new_r=new_r,
old_r_val=r_val,
new_r_val=new_r_val,
reason=reason,
old_graph=old_graph_str,
new_graph=new_graph_str)
def _find_bad_optimizations1(order, reasons, r_vals):
# iterate over results looking for values that don't match the values of the
# results they replaced. This is the sign of a broken optimization.
#identify sets of results that are supposed to be equivalent
equivalence_sets = {}
program_position = {} #node -> order idx
for i, node in enumerate(order):
program_position[node] = i
for new_r in node.outputs:
equivalence_sets.setdefault(new_r, set([new_r]))
for reason, r, old_graph_str, new_graph_str in reasons[new_r]:
equivalence_sets[new_r].update(equivalence_sets.setdefault(r, set([r])))
for er in equivalence_sets[r]:
equivalence_sets[er] = equivalence_sets[new_r]
#identify equivalence sets that are broken
equivalence_sets_broken = {} #id(set) -> Bool
there_is_a_problem = False
for r, r_equiv in equivalence_sets.iteritems():
if id(r_equiv) not in equivalence_sets_broken:
equivalence_sets_broken[id(r_equiv)] = False
#loop over the results in the set comparing them to be equal enough
re0 = None
for re in r_equiv:
if re0:
new_r_val = r_vals[re]
r_val = r_vals[re0]
assert re.type == re0.type
if not re.type.values_eq_enough(r_val, new_r_val):
equivalence_sets_broken[id(r_equiv)] = True
there_is_a_problem = True
re0 = re
if there_is_a_problem:
# which broken equivalence set has the earliest-occurring element?
first_broken_set = None
for i, node in enumerate(order):
for r in node.outputs:
r_equiv = equivalence_sets[r]
if equivalence_sets_broken[id(r_equiv)]:
first_broken_set = r_equiv
#TODO finish this to produce good diagnostic information
print first_broken_set
raise Exception('broken')
class _EnvEvent(object): class _EnvEvent(object):
"""A record of an event in the life of an Env. """A record of an event in the life of an Env.
...@@ -425,12 +515,12 @@ class _Linker(gof.link.LocalLinker): ...@@ -425,12 +515,12 @@ class _Linker(gof.link.LocalLinker):
def make_all(self, profiler = None, input_storage = None, output_storage = None): def make_all(self, profiler = None, input_storage = None, output_storage = None):
env = self.env env = self.env
input_storage_ = input_storage
output_storage_ = output_storage
#order = env.toposort() #order = env.toposort()
#Compute a topological ordering that IGNORES the destroy_map of destructive Ops. #Compute a topological ordering that IGNORES the destroy_map of destructive Ops.
#This will be OK, because every thunk is evaluated on a copy of its input. #This will be OK, because every thunk is evaluated on a copy of its input.
# If the copy.copy function produces an object that is aliased to the original one,
# then this evaluation mode will not work. It works for ndarrays.
order_outputs = copy.copy(env.equivalence_tracker.all_results_ever) order_outputs = copy.copy(env.equivalence_tracker.all_results_ever)
order_outputs.reverse() order_outputs.reverse()
order = graph.io_toposort(env.inputs, order_outputs) order = graph.io_toposort(env.inputs, order_outputs)
...@@ -440,7 +530,8 @@ class _Linker(gof.link.LocalLinker): ...@@ -440,7 +530,8 @@ class _Linker(gof.link.LocalLinker):
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage(env, order, input_storage, output_storage) input_storage, output_storage, storage_map = link.map_storage(env, order,
input_storage_, output_storage_)
thunks_py = [] #python thunks thunks_py = [] #python thunks
thunks_c = [] #c thunks thunks_c = [] #c thunks
...@@ -521,11 +612,11 @@ class _Linker(gof.link.LocalLinker): ...@@ -521,11 +612,11 @@ class _Linker(gof.link.LocalLinker):
# transfer the initial values from the storage_map to the r_vals # transfer the initial values from the storage_map to the r_vals
for r in storage_map: for r in storage_map:
if storage_map[r][0] is not None: if (r.owner is None):
if r.owner is not None: if (storage_map[r][0] is None):
# DEBUG raise Exception('Missing input', r)
print r, storage_map[r], type(storage_map[r]), id(storage_map[r]) if not r.type.is_valid_value(storage_map[r][0]):
assert r.owner is None raise InvalidValueError(r, storage_map[r][0])
r_vals[r] = storage_map[r][0] r_vals[r] = storage_map[r][0]
storage_map[r][0] = None storage_map[r][0] = None
##### #####
...@@ -541,13 +632,17 @@ class _Linker(gof.link.LocalLinker): ...@@ -541,13 +632,17 @@ class _Linker(gof.link.LocalLinker):
# put a copy of each input into the storage_map # put a copy of each input into the storage_map
for r in node.inputs: for r in node.inputs:
storage_map[r][0] = copy.copy(r_vals[r]) assert isinstance(r, gof.Result)
assert r in r_vals
storage_map[r][0] = _lessbroken_deepcopy(r_vals[r])
if not r.type.is_valid_value(storage_map[r][0]):
raise InvalidValueError(r, storage_map[r][0])
thunk_py() thunk_py()
_check_inputs(node, storage_map, r_vals, dr_vals, active_order_set) _check_inputs(node, storage_map, r_vals, dr_vals, active_order_set)
#retrieve a copy of each output from the storage_map #retrieve each output from the storage_map
for r in node.outputs: for r in node.outputs:
if not r.type.is_valid_value(storage_map[r][0]): if not r.type.is_valid_value(storage_map[r][0]):
raise InvalidValueError(r, storage_map[r][0]) raise InvalidValueError(r, storage_map[r][0])
...@@ -561,7 +656,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -561,7 +656,7 @@ class _Linker(gof.link.LocalLinker):
for r in node.inputs: for r in node.inputs:
# TODO: we only need to overwrite the non-destroyed inputs # TODO: we only need to overwrite the non-destroyed inputs
storage_map[r][0] = copy.copy(r_vals[r]) storage_map[r][0] = _lessbroken_deepcopy(r_vals[r])
thunk_c() thunk_c()
...@@ -584,27 +679,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -584,27 +679,7 @@ class _Linker(gof.link.LocalLinker):
except: except:
raise_with_op(node) raise_with_op(node)
# iterate over results looking for values that don't match the values of the _find_bad_optimizations0(order, env.equivalence_tracker.reasons, r_vals)
# results they replaced. This is the sign of a broken optimization.
for i, node in enumerate(order):
for new_r in node.outputs:
for reason, r, old_graph_str, new_graph_str in env.equivalence_tracker.reasons[new_r]:
problem = False
#check if the value for new_r doesn't match the value for r
new_r_val = r_vals[new_r]
r_val = r_vals[r]
assert r.type == new_r.type
if not r.type.values_eq_enough(r_val, new_r_val):
raise BadOptimization(old_r=r,
new_r=new_r,
old_r_val=r_val,
new_r_val=new_r_val,
reason=reason,
old_graph=old_graph_str,
new_graph=new_graph_str)
##### #####
# Postcondition: the input and output results are in the storage map, nothing more # Postcondition: the input and output results are in the storage map, nothing more
...@@ -629,6 +704,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -629,6 +704,7 @@ class _Linker(gof.link.LocalLinker):
# if an input was destroyed, the destroyed value should be returned # if an input was destroyed, the destroyed value should be returned
for r in dr_vals: for r in dr_vals:
assert dr_vals[r][0] is not None
if r.owner is None: if r.owner is None:
assert r in env.inputs assert r in env.inputs
#HACK TO LOOK LIKE A REAL DESTRUCTIVE ACTION TOOK PLACE #HACK TO LOOK LIKE A REAL DESTRUCTIVE ACTION TOOK PLACE
...@@ -645,6 +721,9 @@ class _Linker(gof.link.LocalLinker): ...@@ -645,6 +721,9 @@ class _Linker(gof.link.LocalLinker):
#print output_storage #print output_storage
#print dr_vals #print dr_vals
#print storage_map #print storage_map
for r in storage_map:
if (r.owner is None):
assert storage_map[r][0] is not None
############### ###############
# Done f # Done f
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论