提交 d9eebdd5 作者: James Bergstra

merge

...@@ -63,6 +63,7 @@ class ResultEquivalenceTracker(object): ...@@ -63,6 +63,7 @@ class ResultEquivalenceTracker(object):
self.env = env self.env = env
self.all_results_ever = [] self.all_results_ever = []
self.reasons = {} self.reasons = {}
self.replaced_by = {}
self.snapshots = {} self.snapshots = {}
def on_detach(self, env): def on_detach(self, env):
...@@ -91,23 +92,26 @@ class ResultEquivalenceTracker(object): ...@@ -91,23 +92,26 @@ class ResultEquivalenceTracker(object):
self.equiv[r] = set([r]) self.equiv[r] = set([r])
self.all_results_ever.append(r) self.all_results_ever.append(r)
self.reasons.setdefault(r, []) self.reasons.setdefault(r, [])
self.replaced_by.setdefault(r, [])
self.snapshots.setdefault(r, []) self.snapshots.setdefault(r, [])
for r in node.inputs: for r in node.inputs:
self.reasons.setdefault(r, []) self.reasons.setdefault(r, [])
self.replaced_by.setdefault(r, [])
self.snapshots.setdefault(r, []) self.snapshots.setdefault(r, [])
def on_change_input(self, env, node, i, r, new_r, reason=None): def on_change_input(self, env, node, i, r, new_r, reason=None):
#print 'CHANGE by', reason, 'to use', new_r, type(new_r) #print 'CHANGE by', reason, 'to use', new_r, type(new_r)
self.reasons.setdefault(new_r, []) self.reasons.setdefault(new_r, [])
self.replaced_by.setdefault(new_r, [])
self.snapshots.setdefault(new_r, []) self.snapshots.setdefault(new_r, [])
if (reason, r) not in self.reasons[new_r]: if (reason, r) not in self.reasons[new_r]:
self.reasons[new_r].append((reason, r)) self.reasons[new_r].append((reason, r))
self.replaced_by[r].append((reason, new_r))
self.snapshots[new_r].append(( self.snapshots[new_r].append((
reason, reason,
debugprint(r.owner, prefix=' ', depth=6, file=StringIO()).getvalue(), debugprint(r.owner, prefix=' ', depth=6, file=StringIO()).getvalue(),
debugprint(new_r.owner,prefix=' ', depth=6, file=StringIO()).getvalue())) debugprint(new_r.owner,prefix=' ', depth=6, file=StringIO()).getvalue()))
self.reasons[r].append(('replaced by', new_r))
if r in self.equiv: if r in self.equiv:
r_set = self.equiv[r] r_set = self.equiv[r]
...@@ -165,9 +169,16 @@ class OptCheckLinker(OpWiseCLinker): ...@@ -165,9 +169,16 @@ class OptCheckLinker(OpWiseCLinker):
def make_all(self, profiler = None, input_storage = None, output_storage = None): def make_all(self, profiler = None, input_storage = None, output_storage = None):
env = self.env env = self.env
#order = env.toposort() #order = env.toposort()
#Compute a topological ordering that IGNORES the destroy_map of destructive Ops.
#This will be OK, because every thunk is evaluated on a copy of its input.
# If the copy.copy function produces an object that is aliased to the original one,
# then this evaluation mode will not work. It works for ndarrays.
order_outputs = copy.copy(env.equivalence_tracker.all_results_ever) order_outputs = copy.copy(env.equivalence_tracker.all_results_ever)
order_outputs.reverse() order_outputs.reverse()
order = graph.io_toposort(env.inputs, order_outputs) order = graph.io_toposort(env.inputs, order_outputs)
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage(env, order, input_storage, output_storage) input_storage, output_storage, storage_map = link.map_storage(env, order, input_storage, output_storage)
...@@ -235,32 +246,61 @@ class OptCheckLinker(OpWiseCLinker): ...@@ -235,32 +246,61 @@ class OptCheckLinker(OpWiseCLinker):
problematic = set() problematic = set()
r_vals = {} r_vals = {}
assert len(thunks) == len(order) assert len(thunks) == len(order)
for r in env.inputs:
r_vals[r] = copy.copy(storage_map[r][0])
# compute the value of all results
for i, (thunk, node) in enumerate(zip(thunks, order)): for i, (thunk, node) in enumerate(zip(thunks, order)):
thunk() thunk()
for r in node.outputs: for r in node.outputs:
r_set = env.equivalence_tracker.equiv[r] assert r not in r_vals
this_r_val = copy.copy(storage_map[r][0]) this_r_val = copy.copy(storage_map[r][0])
r_vals[r] = this_r_val r_vals[r] = this_r_val
assert this_r_val is not None
if id(r_set) not in equiv_vals:
#print 'get correct', r_set
equiv_vals[id(r_set)] = this_r_val
else:
correct_r_val = equiv_vals[id(r_set)]
# TODO: use r.type.val_cmp(correct_r_val, this_r_val)
# That function doesn't exist yet though..
if type(correct_r_val) != type(this_r_val):
problematic.add(r)
elif type(correct_r_val) is numpy.ndarray:
if not numpy.allclose(correct_r_val, this_r_val):
problematic.add(r)
else:
print 'Ignoring comparison of instances of', type(correct_r_val)
if problematic: # iterate over results looking for values that don't match the values of the
# results they replaced. This is the sign of a broken optimization.
for i, (thunk, node) in enumerate(zip(thunks, order)):
for new_r in node.outputs:
for reason, r in env.equivalence_tracker.reasons[new_r]:
problem = False
#check if the value for new_r doesn't match the value for r
new_r_val = r_vals[new_r]
r_val = r_vals[r]
if type(new_r_val) != type(r_val):
problem = True
elif type(new_r_val) is numpy.ndarray:
if not numpy.allclose(new_r_val, r_val):
problem = True
else:
print >> sys.stderr, 'WARNING: OptCheck Ignoring comparison of instances of', type(new_r_val)
if problem:
print "OPTCHECK FAILURE"
print " Result:", id(new_r), new_r
print " Op", new_r.owner
print " Value Type:", type(new_r_val)
print " Old Value: ", r_val
print " Value: ", new_r_val
print " Reason: ", [(str(reason), id(old_r)) for reason, old_r in env.equivalence_tracker.reasons[new_r]]
print " Snapshots:"
for s in env.equivalence_tracker.snapshots[new_r]:
print " BEFORE"
print s[1]
print " AFTER"
print s[2]
print ""
# There is no point in continuing to check for more problems,
# because the incorrect result detected here will cause
# subsequent outputs to be incorrect.
raise Exception("OptCheckFailure")
if 0:
#print out the summary of the first problematic equivalence group #print out the summary of the first problematic equivalence group
min_member = [] min_member = []
for problem_r in problematic: for problem_r in problematic:
...@@ -273,22 +313,7 @@ class OptCheckLinker(OpWiseCLinker): ...@@ -273,22 +313,7 @@ class OptCheckLinker(OpWiseCLinker):
problematic_set = min_member[0][1] problematic_set = min_member[0][1]
print "OPTCHECK FAILURE"
for r in problematic_set:
print " Result:", id(r), r
print " Op", r.owner
print " Value Type:", type(r_vals[r])
print " Value: ", r_vals[r]
print " Reason: ", [(str(reason), id(old_r)) for reason, old_r in env.equivalence_tracker.reasons[r]]
print " Snapshots:"
for s in env.equivalence_tracker.snapshots[r]:
print " BEFORE"
print s[1]
print " AFTER"
print s[2]
print ""
raise Exception("OptCheckFailure")
except: except:
raise_with_op(node) raise_with_op(node)
......
...@@ -403,17 +403,29 @@ class GemmLocalOptimizer(LocalOptimizer): ...@@ -403,17 +403,29 @@ class GemmLocalOptimizer(LocalOptimizer):
rval = beta_L_plus_alpha_M(sL, mL, -sR, mR) rval = beta_L_plus_alpha_M(sL, mL, -sR, mR)
return rval return rval
if node.op == T.add: if node.op == T.add:
# arguments of the form scalar * matrix
sM_list = [] sM_list = []
# arguments that can be interpreted as scalar * matrix
sM_orig = []
# arguments not of the form scalar * matrix (i.e., vectors, scalars)
other_inputs = [] other_inputs = []
for input in node.inputs: for input in node.inputs:
tmp = _as_isolated_scalar_times_matrix(input) tmp = _as_isolated_scalar_times_matrix(input)
if tmp: if tmp:
sM_list.append(tmp) sM_list.append(tmp)
sM_orig.append(input)
elif _is_real_matrix(input): elif _is_real_matrix(input):
sM_list.append((1.0, input)) sM_list.append((1.0, input))
sM_orig.append(input)
else: else:
other_inputs.append(input) other_inputs.append(input)
assert len(sM_list) == len(sM_orig)
assert len(sM_list) + len(other_inputs) == len(node.inputs)
if len(sM_list) == 2: if len(sM_list) == 2:
(sL, mL), (sR, mR) = sM_list (sL, mL), (sR, mR) = sM_list
gemm_of_sM_list = beta_L_plus_alpha_M(sL, mL, sR, mR) gemm_of_sM_list = beta_L_plus_alpha_M(sL, mL, sR, mR)
...@@ -425,16 +437,34 @@ class GemmLocalOptimizer(LocalOptimizer): ...@@ -425,16 +437,34 @@ class GemmLocalOptimizer(LocalOptimizer):
else: else:
return gemm_of_sM_list return gemm_of_sM_list
else: else:
# Try every pair in the sM_list, trying to turn it into a gemm operation
for i in xrange(len(sM_list) - 1): for i in xrange(len(sM_list) - 1):
for j in xrange(i+1, len(sM_list)): for j in xrange(i+1, len(sM_list)):
assert i != j
sL, mL = sM_list[i] sL, mL = sM_list[i]
sR, mR = sM_list[j] sR, mR = sM_list[j]
gemm_of_sM_list = beta_L_plus_alpha_M(sL, mL, sR, mR) gemm_of_sM_list = beta_L_plus_alpha_M(sL, mL, sR, mR)
if gemm_of_sM_list: if gemm_of_sM_list:
assert len(gemm_of_sM_list) == 1 assert len(gemm_of_sM_list) == 1
inputs_without_ij = \ inputs_without_ij = \
[input for k, input in enumerate(node.inputs) if k not in (i,j)] [input for k, input in enumerate(sM_orig) if k not in (i,j)]
return [T.add( *(inputs_without_ij + gemm_of_sM_list + other_inputs))]
new_add_inputs = (inputs_without_ij + gemm_of_sM_list + other_inputs)
if False: #SUPER DEBUG MODE :(
if len(new_add_inputs) + 1 != len(node.inputs):
print 'inputs', node.inputs
print 'sM, other', sM_list, other_inputs
print 'i,j', i, j
print 'gemm', gemm_of_sM_list
print 'without ij', inputs_without_ij
print 'new inputs', new_add_inputs
sys.exit(1)
# this should be True because we've combined a pair of arguments
# into a single GEMM
assert len(new_add_inputs) + 1 == len(node.inputs)
return [T.add(*new_add_inputs)]
return False return False
@staticmethod @staticmethod
...@@ -443,7 +473,7 @@ class GemmLocalOptimizer(LocalOptimizer): ...@@ -443,7 +473,7 @@ class GemmLocalOptimizer(LocalOptimizer):
if not isinstance(exc, InconsistencyError): if not isinstance(exc, InconsistencyError):
traceback.print_exc() traceback.print_exc()
else: else:
#print 'GEMM caused cycle, forget it.' #print 'GEMM caused cycle, it happens.'
pass pass
@staticmethod @staticmethod
......
...@@ -142,6 +142,9 @@ class test_greedy_distribute(unittest.TestCase): ...@@ -142,6 +142,9 @@ class test_greedy_distribute(unittest.TestCase):
r1 = f(4,1.e-6, [1.5,2], [2.3,3.1]) r1 = f(4,1.e-6, [1.5,2], [2.3,3.1])
r2 = f(4,1.e-6, [1.5,2], [2.3,3.1]) r2 = f(4,1.e-6, [1.5,2], [2.3,3.1])
assert numpy.all(r0 == r1)
assert numpy.all(r0 == r2)
class test_canonize(unittest.TestCase): class test_canonize(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论