提交 ec7877c4 authored 作者: James Bergstra's avatar James Bergstra

more focused output from debugmode

上级 21e3f3f6
...@@ -63,6 +63,7 @@ class ResultEquivalenceTracker(object): ...@@ -63,6 +63,7 @@ class ResultEquivalenceTracker(object):
self.env = env self.env = env
self.all_results_ever = [] self.all_results_ever = []
self.reasons = {} self.reasons = {}
self.replaced_by = {}
self.snapshots = {} self.snapshots = {}
def on_detach(self, env): def on_detach(self, env):
...@@ -91,23 +92,26 @@ class ResultEquivalenceTracker(object): ...@@ -91,23 +92,26 @@ class ResultEquivalenceTracker(object):
self.equiv[r] = set([r]) self.equiv[r] = set([r])
self.all_results_ever.append(r) self.all_results_ever.append(r)
self.reasons.setdefault(r, []) self.reasons.setdefault(r, [])
self.replaced_by.setdefault(r, [])
self.snapshots.setdefault(r, []) self.snapshots.setdefault(r, [])
for r in node.inputs: for r in node.inputs:
self.reasons.setdefault(r, []) self.reasons.setdefault(r, [])
self.replaced_by.setdefault(r, [])
self.snapshots.setdefault(r, []) self.snapshots.setdefault(r, [])
def on_change_input(self, env, node, i, r, new_r, reason=None): def on_change_input(self, env, node, i, r, new_r, reason=None):
#print 'CHANGE by', reason, 'to use', new_r, type(new_r) #print 'CHANGE by', reason, 'to use', new_r, type(new_r)
self.reasons.setdefault(new_r, []) self.reasons.setdefault(new_r, [])
self.replaced_by.setdefault(new_r, [])
self.snapshots.setdefault(new_r, []) self.snapshots.setdefault(new_r, [])
if (reason, r) not in self.reasons[new_r]: if (reason, r) not in self.reasons[new_r]:
self.reasons[new_r].append((reason, r)) self.reasons[new_r].append((reason, r))
self.replaced_by[r].append((reason, new_r))
self.snapshots[new_r].append(( self.snapshots[new_r].append((
reason, reason,
debugprint(r.owner, prefix=' ', depth=6, file=StringIO()).getvalue(), debugprint(r.owner, prefix=' ', depth=6, file=StringIO()).getvalue(),
debugprint(new_r.owner,prefix=' ', depth=6, file=StringIO()).getvalue())) debugprint(new_r.owner,prefix=' ', depth=6, file=StringIO()).getvalue()))
self.reasons[r].append(('replaced by', new_r))
if r in self.equiv: if r in self.equiv:
r_set = self.equiv[r] r_set = self.equiv[r]
...@@ -165,9 +169,16 @@ class OptCheckLinker(OpWiseCLinker): ...@@ -165,9 +169,16 @@ class OptCheckLinker(OpWiseCLinker):
def make_all(self, profiler = None, input_storage = None, output_storage = None): def make_all(self, profiler = None, input_storage = None, output_storage = None):
env = self.env env = self.env
#order = env.toposort() #order = env.toposort()
#Compute a topological ordering that IGNORES the destroy_map of destructive Ops.
#This will be OK, because every thunk is evaluated on a copy of its input.
# If the copy.copy function produces an object that is aliased to the original one,
# then this evaluation mode will not work. It works for ndarrays.
order_outputs = copy.copy(env.equivalence_tracker.all_results_ever) order_outputs = copy.copy(env.equivalence_tracker.all_results_ever)
order_outputs.reverse() order_outputs.reverse()
order = graph.io_toposort(env.inputs, order_outputs) order = graph.io_toposort(env.inputs, order_outputs)
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage(env, order, input_storage, output_storage) input_storage, output_storage, storage_map = link.map_storage(env, order, input_storage, output_storage)
...@@ -235,32 +246,61 @@ class OptCheckLinker(OpWiseCLinker): ...@@ -235,32 +246,61 @@ class OptCheckLinker(OpWiseCLinker):
problematic = set() problematic = set()
r_vals = {} r_vals = {}
assert len(thunks) == len(order) assert len(thunks) == len(order)
for r in env.inputs:
r_vals[r] = copy.copy(storage_map[r][0])
# compute the value of all results
for i, (thunk, node) in enumerate(zip(thunks, order)): for i, (thunk, node) in enumerate(zip(thunks, order)):
thunk() thunk()
for r in node.outputs: for r in node.outputs:
r_set = env.equivalence_tracker.equiv[r] assert r not in r_vals
this_r_val = copy.copy(storage_map[r][0]) this_r_val = copy.copy(storage_map[r][0])
r_vals[r] = this_r_val r_vals[r] = this_r_val
assert this_r_val is not None
if id(r_set) not in equiv_vals:
#print 'get correct', r_set
equiv_vals[id(r_set)] = this_r_val
else:
correct_r_val = equiv_vals[id(r_set)]
# TODO: use r.type.val_cmp(correct_r_val, this_r_val)
# That function doesn't exist yet though..
if type(correct_r_val) != type(this_r_val):
problematic.add(r)
elif type(correct_r_val) is numpy.ndarray:
if not numpy.allclose(correct_r_val, this_r_val):
problematic.add(r)
else:
print 'Ignoring comparison of instances of', type(correct_r_val)
if problematic: # iterate over results looking for values that don't match the values of the
# results they replaced. This is the sign of a broken optimization.
for i, (thunk, node) in enumerate(zip(thunks, order)):
for new_r in node.outputs:
for reason, r in env.equivalence_tracker.reasons[new_r]:
problem = False
#check if the value for new_r doesn't match the value for r
new_r_val = r_vals[new_r]
r_val = r_vals[r]
if type(new_r_val) != type(r_val):
problem = True
elif type(new_r_val) is numpy.ndarray:
if not numpy.allclose(new_r_val, r_val):
problem = True
else:
print >> sys.stderr, 'WARNING: OptCheck Ignoring comparison of instances of', type(new_r_val)
if problem:
print "OPTCHECK FAILURE"
print " Result:", id(new_r), new_r
print " Op", new_r.owner
print " Value Type:", type(new_r_val)
print " Old Value: ", r_val
print " Value: ", new_r_val
print " Reason: ", [(str(reason), id(old_r)) for reason, old_r in env.equivalence_tracker.reasons[new_r]]
print " Snapshots:"
for s in env.equivalence_tracker.snapshots[new_r]:
print " BEFORE"
print s[1]
print " AFTER"
print s[2]
print ""
# There is no point in continuing to check for more problems,
# because the incorrect result detected here will cause
# subsequent outputs to be incorrect.
raise Exception("OptCheckFailure")
if 0:
#print out the summary of the first problematic equivalence group #print out the summary of the first problematic equivalence group
min_member = [] min_member = []
for problem_r in problematic: for problem_r in problematic:
...@@ -273,22 +313,7 @@ class OptCheckLinker(OpWiseCLinker): ...@@ -273,22 +313,7 @@ class OptCheckLinker(OpWiseCLinker):
problematic_set = min_member[0][1] problematic_set = min_member[0][1]
print "OPTCHECK FAILURE"
for r in problematic_set:
print " Result:", id(r), r
print " Op", r.owner
print " Value Type:", type(r_vals[r])
print " Value: ", r_vals[r]
print " Reason: ", [(str(reason), id(old_r)) for reason, old_r in env.equivalence_tracker.reasons[r]]
print " Snapshots:"
for s in env.equivalence_tracker.snapshots[r]:
print " BEFORE"
print s[1]
print " AFTER"
print s[2]
print ""
raise Exception("OptCheckFailure")
except: except:
raise_with_op(node) raise_with_op(node)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论