提交 e8aff982 authored 作者: Frederic's avatar Frederic

Faster opt by not doing useless stuff.

This makes the Canonizer take 4x less time in a test case.
上级 2b7b0305
......@@ -2867,12 +2867,13 @@ class Canonizer(gof.LocalOptimizer):
else:
return v
def simplify(self, num, denum):
def simplify(self, num, denum, out_type):
"""
Shorthand for:
self.simplify_constants(*self.simplify_factors(num, denum))
"""
rval = self.simplify_constants(*self.simplify_factors(num, denum))
rval = self.simplify_constants(*self.simplify_factors(num, denum),
out_type=out_type)
for reason, simplifier in self.external_simplifiers:
# TODO: document that 'reason' is associated with this
# simplification to help auditing when things go
......@@ -2896,7 +2897,7 @@ class Canonizer(gof.LocalOptimizer):
denum.remove(v)
return num, denum
def simplify_constants(self, orig_num, orig_denum):
def simplify_constants(self, orig_num, orig_denum, out_type=None):
"""
Finds all constants in orig_num and orig_denum (using
......@@ -2914,7 +2915,6 @@ class Canonizer(gof.LocalOptimizer):
# Lists representing the numerator and denumerator
num, denum = list(orig_num), list(orig_denum)
out_type = self.merge_num_denum(orig_num, orig_denum).type
# Lists representing the *constant* elements of num and denum
numct, denumct = [], []
......@@ -3001,7 +3001,7 @@ class Canonizer(gof.LocalOptimizer):
# Here we make the canonical version of the graph around this node
# See the documentation of get_num_denum and simplify
orig_num, orig_denum = self.get_num_denum(node.outputs[0])
num, denum = self.simplify(list(orig_num), list(orig_denum))
num, denum = self.simplify(list(orig_num), list(orig_denum), out.type)
def same(x, y):
return len(x) == len(y) and all(N.all(xe == ye) for xe, ye in
......@@ -3873,7 +3873,8 @@ register_canonicalize(local_add_canonizer, name='local_add_canonizer')
##################
def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore=0):
def distribute_greedy(pos_pairs, neg_pairs, num, denum,
out_type, minscore=0):
# each pair in pos_pairs and neg_pairs is a num/denum pair. this
# function attempts to add num and denum to the corresponding parts
# of each pair, and counts how many multiplications/divisions can
......@@ -3889,10 +3890,10 @@ def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore=0):
# score is number of operations saved, higher is better
score = len(num) + div_cost * len(denum)
new_pos_pairs = list(itertools.starmap(local_mul_canonizer.simplify,
[(n + num, d + denum) for (n, d)
[(n + num, d + denum, out_type) for (n, d)
in pos_pairs]))
new_neg_pairs = list(itertools.starmap(local_mul_canonizer.simplify,
[(n + num, d + denum) for (n, d)
[(n + num, d + denum, out_type) for (n, d)
in neg_pairs]))
for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs +
new_neg_pairs):
......@@ -3905,7 +3906,7 @@ def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore=0):
return True, new_pos_pairs, new_neg_pairs
def attempt_distribution(factor, num, denum):
def attempt_distribution(factor, num, denum, out_type):
# we try to insert each num and each denum in the factor
# returns: changes?, new_factor, new_num, new_denum
# if there are changes, new_num and new_denum contain all the numerators
......@@ -3918,13 +3919,13 @@ def attempt_distribution(factor, num, denum):
change = False
for n in list(num):
success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs,
neg_pairs, [n], [])
neg_pairs, [n], [], out_type)
if success:
change = True
num.remove(n)
for d in list(denum):
success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs,
neg_pairs, [], [d])
neg_pairs, [], [d], out_type)
if success:
change = True
denum.remove(d)
......@@ -3969,12 +3970,13 @@ def local_greedy_distributor(node):
change = False
out_type = out.type
for candidate in list(num):
if candidate not in num:
continue
num.remove(candidate)
_change, candidate, num, denum = attempt_distribution(candidate,
num, denum)
num, denum, out_type)
change |= _change
new_num.append(candidate)
......@@ -3983,7 +3985,7 @@ def local_greedy_distributor(node):
continue
denum.remove(candidate)
_change, candidate, denum, num = attempt_distribution(candidate,
denum, num)
denum, num, out_type)
change |= _change
new_denum.append(candidate)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论