提交 e8aff982 authored 作者: Frederic's avatar Frederic

Faster opt by not doing useless stuff.

This make the Canonizer take 4x less time in a test case.
上级 2b7b0305
...@@ -2867,12 +2867,13 @@ class Canonizer(gof.LocalOptimizer): ...@@ -2867,12 +2867,13 @@ class Canonizer(gof.LocalOptimizer):
else: else:
return v return v
def simplify(self, num, denum): def simplify(self, num, denum, out_type):
""" """
Shorthand for: Shorthand for:
self.simplify_constants(*self.simplify_factors(num, denum)) self.simplify_constants(*self.simplify_factors(num, denum))
""" """
rval = self.simplify_constants(*self.simplify_factors(num, denum)) rval = self.simplify_constants(*self.simplify_factors(num, denum),
out_type=out_type)
for reason, simplifier in self.external_simplifiers: for reason, simplifier in self.external_simplifiers:
# TODO: document that 'reason' is associated with this # TODO: document that 'reason' is associated with this
# simplification to help auditing when things go # simplification to help auditing when things go
...@@ -2896,7 +2897,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -2896,7 +2897,7 @@ class Canonizer(gof.LocalOptimizer):
denum.remove(v) denum.remove(v)
return num, denum return num, denum
def simplify_constants(self, orig_num, orig_denum): def simplify_constants(self, orig_num, orig_denum, out_type=None):
""" """
Finds all constants in orig_num and orig_denum (using Finds all constants in orig_num and orig_denum (using
...@@ -2914,7 +2915,6 @@ class Canonizer(gof.LocalOptimizer): ...@@ -2914,7 +2915,6 @@ class Canonizer(gof.LocalOptimizer):
# Lists representing the numerator and denumerator # Lists representing the numerator and denumerator
num, denum = list(orig_num), list(orig_denum) num, denum = list(orig_num), list(orig_denum)
out_type = self.merge_num_denum(orig_num, orig_denum).type
# Lists representing the *constant* elements of num and denum # Lists representing the *constant* elements of num and denum
numct, denumct = [], [] numct, denumct = [], []
...@@ -3001,7 +3001,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -3001,7 +3001,7 @@ class Canonizer(gof.LocalOptimizer):
# Here we make the canonical version of the graph around this node # Here we make the canonical version of the graph around this node
# See the documentation of get_num_denum and simplify # See the documentation of get_num_denum and simplify
orig_num, orig_denum = self.get_num_denum(node.outputs[0]) orig_num, orig_denum = self.get_num_denum(node.outputs[0])
num, denum = self.simplify(list(orig_num), list(orig_denum)) num, denum = self.simplify(list(orig_num), list(orig_denum), out.type)
def same(x, y): def same(x, y):
return len(x) == len(y) and all(N.all(xe == ye) for xe, ye in return len(x) == len(y) and all(N.all(xe == ye) for xe, ye in
...@@ -3873,7 +3873,8 @@ register_canonicalize(local_add_canonizer, name='local_add_canonizer') ...@@ -3873,7 +3873,8 @@ register_canonicalize(local_add_canonizer, name='local_add_canonizer')
################## ##################
def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore=0): def distribute_greedy(pos_pairs, neg_pairs, num, denum,
out_type, minscore=0):
# each pair in pos_pairs and neg_pairs is a num/denum pair. this # each pair in pos_pairs and neg_pairs is a num/denum pair. this
# function attempts to add num and denum to the corresponding parts # function attempts to add num and denum to the corresponding parts
# of each pair, and counts how many multiplications/divisions can # of each pair, and counts how many multiplications/divisions can
...@@ -3889,10 +3890,10 @@ def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore=0): ...@@ -3889,10 +3890,10 @@ def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore=0):
# score is number of operations saved, higher is better # score is number of operations saved, higher is better
score = len(num) + div_cost * len(denum) score = len(num) + div_cost * len(denum)
new_pos_pairs = list(itertools.starmap(local_mul_canonizer.simplify, new_pos_pairs = list(itertools.starmap(local_mul_canonizer.simplify,
[(n + num, d + denum) for (n, d) [(n + num, d + denum, out_type) for (n, d)
in pos_pairs])) in pos_pairs]))
new_neg_pairs = list(itertools.starmap(local_mul_canonizer.simplify, new_neg_pairs = list(itertools.starmap(local_mul_canonizer.simplify,
[(n + num, d + denum) for (n, d) [(n + num, d + denum, out_type) for (n, d)
in neg_pairs])) in neg_pairs]))
for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs + for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs +
new_neg_pairs): new_neg_pairs):
...@@ -3905,7 +3906,7 @@ def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore=0): ...@@ -3905,7 +3906,7 @@ def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore=0):
return True, new_pos_pairs, new_neg_pairs return True, new_pos_pairs, new_neg_pairs
def attempt_distribution(factor, num, denum): def attempt_distribution(factor, num, denum, out_type):
# we try to insert each num and each denum in the factor # we try to insert each num and each denum in the factor
# returns: changes?, new_factor, new_num, new_denum # returns: changes?, new_factor, new_num, new_denum
# if there are changes, new_num and new_denum contain all the numerators # if there are changes, new_num and new_denum contain all the numerators
...@@ -3918,13 +3919,13 @@ def attempt_distribution(factor, num, denum): ...@@ -3918,13 +3919,13 @@ def attempt_distribution(factor, num, denum):
change = False change = False
for n in list(num): for n in list(num):
success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs, success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs,
neg_pairs, [n], []) neg_pairs, [n], [], out_type)
if success: if success:
change = True change = True
num.remove(n) num.remove(n)
for d in list(denum): for d in list(denum):
success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs, success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs,
neg_pairs, [], [d]) neg_pairs, [], [d], out_type)
if success: if success:
change = True change = True
denum.remove(d) denum.remove(d)
...@@ -3969,12 +3970,13 @@ def local_greedy_distributor(node): ...@@ -3969,12 +3970,13 @@ def local_greedy_distributor(node):
change = False change = False
out_type = out.type
for candidate in list(num): for candidate in list(num):
if candidate not in num: if candidate not in num:
continue continue
num.remove(candidate) num.remove(candidate)
_change, candidate, num, denum = attempt_distribution(candidate, _change, candidate, num, denum = attempt_distribution(candidate,
num, denum) num, denum, out_type)
change |= _change change |= _change
new_num.append(candidate) new_num.append(candidate)
...@@ -3983,7 +3985,7 @@ def local_greedy_distributor(node): ...@@ -3983,7 +3985,7 @@ def local_greedy_distributor(node):
continue continue
denum.remove(candidate) denum.remove(candidate)
_change, candidate, denum, num = attempt_distribution(candidate, _change, candidate, denum, num = attempt_distribution(candidate,
denum, num) denum, num, out_type)
change |= _change change |= _change
new_denum.append(candidate) new_denum.append(candidate)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论