提交 58e51d18 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added math_optimizer

上级 2e8d00e9
...@@ -102,6 +102,19 @@ from theano.tensor import * ...@@ -102,6 +102,19 @@ from theano.tensor import *
from theano.sandbox import pprint from theano.sandbox import pprint
class _test_greedy_distribute(unittest.TestCase):
def test_main(self):
a, b, c, d, x, y, z = matrices('abcdxyz')
e = (a/z + b/x) * x * z
g = Env([a,b,c,d,x,y,z], [e])
print pprint.pp.process(g.outputs[0])
mul_canonizer.optimize(g)
gof.TopoOptimizer(gof.LocalOptGroup(local_fill_cut, local_fill_lift), order = 'out_to_in').optimize(g)
gof.TopoOptimizer(gof.LocalOptGroup(local_greedy_distributor), order = 'out_to_in').optimize(g)
print pprint.pp.process(g.outputs[0])
class _test_canonize(unittest.TestCase): class _test_canonize(unittest.TestCase):
def test_muldiv(self): def test_muldiv(self):
......
...@@ -615,6 +615,9 @@ class Log(UnaryScalarOp): ...@@ -615,6 +615,9 @@ class Log(UnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz / x, return gz / x,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
#todo: the version using log2 seems to be very slightly faster
# on some machines for some reason, check if it's worth switching
#return "%(z)s = log2(%(x)s) * 0.69314718055994529;" % locals()
return "%(z)s = log(%(x)s);" % locals() return "%(z)s = log(%(x)s);" % locals()
log = Log(upgrade_to_float, name = 'log') log = Log(upgrade_to_float, name = 'log')
......
...@@ -6,6 +6,7 @@ import scalar ...@@ -6,6 +6,7 @@ import scalar
import tensor as T import tensor as T
import numpy as N import numpy as N
import operator import operator
import itertools
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c) # gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0) # Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
...@@ -332,11 +333,14 @@ class Canonizer(gof.LocalOptimizer): ...@@ -332,11 +333,14 @@ class Canonizer(gof.LocalOptimizer):
return self.inverse(self.merge_num_denum(num, []), return self.inverse(self.merge_num_denum(num, []),
self.merge_num_denum(denum, [])) self.merge_num_denum(denum, []))
def get_constant(self, v): @classmethod
def get_constant(cls, v):
if isinstance(v, N.generic):
return v
if isinstance(v, gof.Constant): if isinstance(v, gof.Constant):
return v.data return v.data
if v.owner and isinstance(v.owner.op, DimShuffle): if v.owner and isinstance(v.owner.op, DimShuffle):
return self.get_constant(v.owner.inputs[0]) return cls.get_constant(v.owner.inputs[0])
return None return None
def simplify(self, num, denum): def simplify(self, num, denum):
...@@ -366,7 +370,9 @@ class Canonizer(gof.LocalOptimizer): ...@@ -366,7 +370,9 @@ class Canonizer(gof.LocalOptimizer):
denum.remove(v) denum.remove(v)
denumct.append(ct) denumct.append(ct)
ct = self.calculate(numct, denumct, aslist = True) ct = self.calculate(numct, denumct, aslist = True)
if len(ct) and ncc == 1 and dcc == 0: # if len(ct) and ncc == 1 and dcc == 0:
# return orig_num, orig_denum
if orig_num and ct == self.get_constant(orig_num[0]):
return orig_num, orig_denum return orig_num, orig_denum
return ct + num, denum return ct + num, denum
...@@ -398,6 +404,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -398,6 +404,7 @@ class Canonizer(gof.LocalOptimizer):
new = T.fill(out, new) new = T.fill(out, new)
return [new] return [new]
def mul_calculate(num, denum, aslist = False): def mul_calculate(num, denum, aslist = False):
v = reduce(N.multiply, num, 1.0) / reduce(N.multiply, denum, 1.0) v = reduce(N.multiply, num, 1.0) / reduce(N.multiply, denum, 1.0)
if aslist: if aslist:
...@@ -408,7 +415,26 @@ def mul_calculate(num, denum, aslist = False): ...@@ -408,7 +415,26 @@ def mul_calculate(num, denum, aslist = False):
return v return v
local_mul_canonizer = Canonizer(T.mul, T.div, T.inv, mul_calculate) local_mul_canonizer = Canonizer(T.mul, T.div, T.inv, mul_calculate)
mul_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_mul_canonizer, local_fill_sink), order = 'in_to_out')
@gof.local_optimizer
def local_neg_to_mul(node):
if node.op == T.neg:
return [-1.0 * node.inputs[0]]
else:
return False
@gof.local_optimizer
def local_mul_to_neg(node):
if node.op == T.mul and local_mul_canonizer.get_constant(node.inputs[0]) == -1.0:
return [-local_mul_canonizer.merge_num_denum(node.inputs[1:], [])]
else:
return False
neg_to_mul = gof.TopoOptimizer(gof.LocalOptGroup(local_neg_to_mul), order = 'out_to_in')
mul_to_neg = gof.TopoOptimizer(gof.LocalOptGroup(local_mul_to_neg), order = 'out_to_in')
mul_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut, local_fill_sink), order = 'in_to_out')
def add_calculate(num, denum, aslist = False): def add_calculate(num, denum, aslist = False):
v = reduce(N.add, num, 0.0) - reduce(N.add, denum, 0.0) v = reduce(N.add, num, 0.0) - reduce(N.add, denum, 0.0)
...@@ -420,7 +446,7 @@ def add_calculate(num, denum, aslist = False): ...@@ -420,7 +446,7 @@ def add_calculate(num, denum, aslist = False):
return v return v
local_add_canonizer = Canonizer(T.add, T.sub, T.neg, add_calculate) local_add_canonizer = Canonizer(T.add, T.sub, T.neg, add_calculate)
add_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_add_canonizer, local_fill_sink), order = 'in_to_out') add_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_add_canonizer, local_fill_cut, local_fill_sink), order = 'in_to_out')
################## ##################
...@@ -429,49 +455,145 @@ add_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_add_canonizer, local_f ...@@ -429,49 +455,145 @@ add_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_add_canonizer, local_f
def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore = 0): def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore = 0):
score = len(num) + len(denum) # score is number of operations saved, higher is better # each pair in pos_pairs and neg_pairs is a num/denum pair. this
new_pos_pairs = itertools.starmap(local_mul_canonizer.simplify, # function attempts to add num and denum to the corresponding parts
[(n+num, d+denum) for (n, d) in plus_pairs]) # of each pair, and counts how many multiplications/divisions can
new_neg_pairs = itertools.starmap(local_mul_canonizer.simplify, # be saved in that way.
[(n+num, d+denum) for (n, d) in plus_pairs])
# each division is counted like div_cost multiplications
# (typically, division costs more so we are willing to multiply more
# in order to divide less)
# 1.5 was obtained through an informal test and may very well be
# platform dependent
div_cost = 1.5
score = len(num) + div_cost * len(denum) # score is number of operations saved, higher is better
new_pos_pairs = list(itertools.starmap(local_mul_canonizer.simplify,
[(n+num, d+denum) for (n, d) in pos_pairs]))
new_neg_pairs = list(itertools.starmap(local_mul_canonizer.simplify,
[(n+num, d+denum) for (n, d) in neg_pairs]))
for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs + new_neg_pairs): for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs + new_neg_pairs):
# We calculate how many operations we are saving with the new num and denum # We calculate how many operations we are saving with the new num and denum
score += len(n) + len(d) - len(nn) - len(dd) score += len(n) + div_cost * len(d) - len(nn) - div_cost * len(dd)
if score < minscore: if score <= minscore:
# the change is not applied because it adds too many operations
return False, pos_pairs, neg_pairs return False, pos_pairs, neg_pairs
return True, new_pos_pairs, new_neg_pairs return True, new_pos_pairs, new_neg_pairs
def attempt_distribution(factor, num, denum):
# we try to insert each num and each denum in the factor
# returns: changes?, new_factor, new_num, new_denum
# if there are changes, new_num and new_denum contain all the numerators
# and denumerators that could not be distributed in the factor
pos, neg = local_add_canonizer.get_num_denum(factor)
if len(pos) == 1 and not neg:
return False, factor, num, denum
pos_pairs = map(local_mul_canonizer.get_num_denum, pos)
neg_pairs = map(local_mul_canonizer.get_num_denum, neg)
change = False
for n in list(num):
success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs, neg_pairs, [n], [])
if success:
change = True
num.remove(n)
for d in list(denum):
success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs, neg_pairs, [], [d])
if success:
change = True
denum.remove(d)
if not change:
return change, factor, num, denum
else:
return change, local_add_canonizer.merge_num_denum(
list(itertools.starmap(local_mul_canonizer.merge_num_denum, pos_pairs)),
list(itertools.starmap(local_mul_canonizer.merge_num_denum, neg_pairs))), num, denum
@gof.local_optimizer @gof.local_optimizer
def local_greedy_distributor(node): def local_greedy_distributor(node):
""" """
This optimization tries to apply distributivity of multiplication
to addition in order to reduce the number of multiplications
and/or divisions that must be done. The algorithm weighs division
more than multiplication to account for the former's slightly
greater computational cost.
The following expressions are simplified: The following expressions are simplified:
((a/x + b/y) * x * y) --> a*y + b*x 1. ((a/x + b/y) * x * y) --> a*y + b*x
((a/x + b) * x) --> a + b*x 2. ((a/x + b) * x) --> a + b*x
The following expressions are not simplified:
3. ((a + b) * x) -/-> a*x + b*x
The following expressions are not: This optimization aims to reduce computational cost. It may also
((a + b) * x) -X-> a*x + b*x increase numerical stability, e.g. when x and/or y tend to 0 in
example 1.
""" """
out = node.outputs[0] out = node.outputs[0]
num, denum = local_mul_canonizer.get_num_denum(out) num, denum = local_mul_canonizer.get_num_denum(out)
if len(num) == 1 and not denum: if len(num) == 1 and not denum:
return False return False
new_num = []
for entry in num: new_num, new_denum = [], []
pos, neg = local_add_canonizer.get_num_denum(entry)
if len(pos) == 1 and not neg: change = False
new_num.append(entry)
for candidate in list(num):
if candidate not in num:
continue continue
pos_pairs = map(local_mul_canonizer.get_num_denum, pos) num.remove(candidate)
neg_pairs = map(local_mul_canonizer.get_num_denum, neg) _change, candidate, num, denum = attempt_distribution(candidate, num, denum)
change |= _change
if change:
new_num.append(candidate)
for candidate in list(denum):
if candidate not in denum:
continue
denum.remove(candidate)
_change, candidate, denum, num = attempt_distribution(candidate, denum, num)
change |= _change
if change:
new_denum.append(candidate)
if not change:
return False
new_num += num
new_denum += denum
return [local_mul_canonizer.merge_num_denum(new_num, new_denum)]
def out2in(*local_opts):
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts), order = 'out_to_in')
def in2out(*local_opts):
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts), order = 'in_to_out')
def _math_optimizer():
pass_1 = in2out(local_fill_sink)
pass_2 = out2in(local_dimshuffle_lift, local_shape_lift, local_fill_lift)#, local_fill_cut)
pass_3 = out2in(local_subtensor_make_vector, local_fill_cut)
canonizer = in2out(local_add_canonizer,
local_mul_canonizer,
local_fill_sink)
pass_4 = out2in(local_greedy_distributor)
return gof.SeqOptimizer(pass_1,
pass_2,
pass_3,
neg_to_mul,
canonizer,
pass_4,
mul_to_neg)
math_optimizer = _math_optimizer()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论