提交 11be7be2 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

optimizer fiesta

上级 deff95dc
......@@ -3,12 +3,12 @@
import unittest
import gof
from tensor_opt import *
import tensor
from tensor import Tensor
from gof import Env
from elemwise import DimShuffle
from theano import gof
from theano.tensor_opt import *
from theano import tensor
from theano.tensor import Tensor
from theano.gof import Env
from theano.elemwise import DimShuffle
import numpy
#import scalar_opt
......@@ -68,7 +68,7 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(x, (1, 0)), (1, 0))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{1,0}(x))]")
lift_dimshuffle.optimize(g)
dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[x]")
def test_merge2(self):
......@@ -76,7 +76,7 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]", str(g))
lift_dimshuffle.optimize(g)
dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g))
def test_elim3(self):
......@@ -84,7 +84,7 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}(DimShuffle{0,x,1}(x)))]", str(g))
lift_dimshuffle.optimize(g)
dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[x]", str(g))
def test_lift(self):
......@@ -92,10 +92,134 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = x + y + z
g = Env([x, y, z], [e])
self.failUnless(str(g) == "[add(InplaceDimShuffle{x,0,1}(add(InplaceDimShuffle{x,0}(x), y)), z)]", str(g))
lift_dimshuffle.optimize(g)
dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g))
from theano.tensor import *
from theano.sandbox import pprint
class _test_canonize(unittest.TestCase):
def test_muldiv(self):
x, y, z = matrices('xyz')
a, b, c, d = matrices('abcd')
# e = (2.0 * x) / (2.0 * y)
# e = (2.0 * x) / (4.0 * y)
# e = x / (y / z)
# e = (x * y) / x
# e = (x / y) * (y / z) * (z / x)
# e = (a / b) * (b / c) * (c / d)
# e = (a * b) / (b * c) / (c * d)
# e = 2 * x / 2
# e = x / y / x
e = (x / x) * (y / y)
g = Env([x, y, z, a, b, c, d], [e])
print pprint.pp.process(g.outputs[0])
mul_canonizer.optimize(g)
gof.TopoOptimizer(gof.LocalOptGroup(local_fill_cut, local_fill_lift), order = 'out_to_in').optimize(g)
print pprint.pp.process(g.outputs[0])
# def test_plusmin(self):
# x, y, z = inputs()
# a, b, c, d = more_inputs()
# # e = x - x
# # e = (2.0 + x) - (2.0 + y)
# # e = (2.0 + x) - (4.0 + y)
# # e = x - (y - z)
# # e = (x + y) - x
# # e = (x - y) + (y - z) + (z - x)
# # e = (a - b) + (b - c) + (c - d)
# # e = x + -y
# # e = a - b - b + a + b + c + b - c
# # e = x + log(y) - x + y
# e = 2.0 + x + 4.0
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# gof.ConstantFinder().optimize(g)
# addfn = lambda *inputs: sum(inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
# print g
# def test_both(self):
# x, y, z = inputs()
# a, b, c, d = more_inputs()
# e0 = (x * y / x)
# e = e0 + e0 - e0
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# gof.ConstantFinder().optimize(g)
# mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# divfn = lambda x, y: x / y
# invfn = lambda x: 1 / x
# Canonizer(Mul, Div, Inv, mulfn, divfn, invfn).optimize(g)
# addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
# print g
# def test_group_powers(self):
# x, y, z, a, b, c, d = floats('xyzabcd')
###################
# c1, c2 = constant(1.), constant(2.)
# #e = pow(x, c1) * pow(x, y) / pow(x, 7.0) # <-- fucked
# #f = -- moving from div(mul.out, pow.out) to pow(x, sub.out)
# e = div(mul(pow(x, 2.0), pow(x, y)), pow(x, 7.0))
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# print g.inputs, g.outputs, g.orphans
# f = sub(add(2.0, y), add(7.0))
# g.replace(e, pow(x, f))
# print g
# print g.inputs, g.outputs, g.orphans
# g.replace(f, sub(add(2.0, y), add(7.0))) # -- moving from sub(add.out, add.out) to sub(add.out, add.out)
# print g
# print g.inputs, g.outputs, g.orphans
###################
# # e = x * exp(y) * exp(z)
# # e = x * pow(x, y) * pow(x, z)
# # e = pow(x, y) / pow(x, z)
# e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) # <-- fucked
# # e = pow(x - x, y)
# # e = pow(x, 2.0 + y - 7.0)
# # e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) / pow(x, z)
# # e = pow(x, 2.0 + y - 7.0 - z)
# # e = x ** y / x ** y
# # e = x ** y / x ** (y - 1.0)
# # e = exp(x) * a * exp(y) / exp(z)
# g = Env([x, y, z, a, b, c, d], [e])
# g.extend(gof.PrintListener(g))
# print g, g.orphans
# mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# divfn = lambda x, y: x / y
# invfn = lambda x: 1 / x
# Canonizer(mul, div, inv, mulfn, divfn, invfn, group_powers).optimize(g)
# print g, g.orphans
# addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(add, sub, neg, addfn, subfn, negfn).optimize(g)
# print g, g.orphans
# pow2one_float.optimize(g)
# pow2x_float.optimize(g)
# print g, g.orphans
# class _test_cliques(unittest.TestCase):
# def test_straightforward(self):
......
......@@ -237,6 +237,7 @@ class Elemwise(Op):
is left-completed to the greatest number of dimensions with 1s
using DimShuffle.
"""
inputs = map(as_tensor, inputs)
shadow = self.scalar_op.make_node(*[Scalar(dtype = t.type.dtype)() for t in inputs])
......@@ -303,11 +304,10 @@ class Elemwise(Op):
if node is None:
# the gradient contains a constant, translate it as
# an equivalent Tensor of size 1 and proper number of dimensions
b = [1] * nd
res = TensorConstant(Tensor(dtype = r.type.dtype,
broadcastable = b),
numpy.asarray(r.data).reshape(b))
return res
broadcastable = ()),
numpy.asarray(r.data)) # .reshape(b)
return DimShuffle((), ['x']*nd, inplace = True)(res)
new_r = Elemwise(node.op, {})(*[transform(input) for input in node.inputs])
return new_r
ret = []
......
......@@ -18,9 +18,9 @@ from op import \
Op
from opt import \
Optimizer, SeqOptimizer, \
Optimizer, optimizer, SeqOptimizer, \
MergeOptimizer, MergeOptMerge, \
LocalOptimizer, LocalOptGroup, LocalOpKeyOptGroup, \
LocalOptimizer, local_optimizer, LocalOptGroup, LocalOpKeyOptGroup, \
OpSub, OpRemove, PatternSub, \
NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer
......
......@@ -382,17 +382,17 @@ class Env(utils.object2):
"Same as len(self.clients(r))."
return len(self.clients(r))
def edge(self, r):
return r in self.inputs or r in self.orphans
def follow(self, r):
node = r.owner
if self.edge(r):
return None
else:
if node is None:
raise Exception("what the fuck")
return node.inputs
# def edge(self, r):
# return r in self.inputs or r in self.orphans
# def follow(self, r):
# node = r.owner
# if self.edge(r):
# return None
# else:
# if node is None:
# raise Exception("what the fuck")
# return node.inputs
def check_integrity(self):
"""
......
......@@ -56,6 +56,16 @@ class Optimizer:
pass
class FromFunctionOptimizer(Optimizer):
def __init__(self, fn):
self.apply = fn
def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate)
def optimizer(f):
return FromFunctionOptimizer(f)
class SeqOptimizer(Optimizer, list):
"""
......@@ -137,6 +147,7 @@ class MergeOptimizer(Optimizer):
sig = r.signature()
other_r = inv_cid.get(sig, None)
if other_r is not None:
if r.name: other_r.name = r.name
env.replace_validate(r, other_r)
else:
cid[r] = sig
......@@ -155,8 +166,12 @@ class MergeOptimizer(Optimizer):
success = False
if dup is not None:
success = True
pairs = zip(node.outputs, dup.outputs)
for output, new_output in pairs:
if output.name and not new_output.name:
new_output.name = output.name
try:
env.replace_all_validate(zip(node.outputs, dup.outputs))
env.replace_all_validate(pairs)
except InconsistencyError, e:
success = False
if not success:
......@@ -189,17 +204,27 @@ class LocalOptimizer(utils.object2):
raise utils.AbstractFunctionError()
class FromFunctionLocalOptimizer(LocalOptimizer):
def __init__(self, fn):
self.transform = fn
def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate)
def local_optimizer(f):
return FromFunctionLocalOptimizer(f)
class LocalOptGroup(LocalOptimizer):
def __init__(self, optimizers):
def __init__(self, *optimizers):
self.opts = optimizers
self.reentrant = any(getattr(opt, 'reentrant', True), optimizers)
self.retains_inputs = all(getattr(opt, 'retains_inputs', False), optimizers)
self.reentrant = any(getattr(opt, 'reentrant', True) for opt in optimizers)
self.retains_inputs = all(getattr(opt, 'retains_inputs', False) for opt in optimizers)
def transform(self, node):
for opt in self.opts:
repl = opt.transform(node)
if repl is not False:
if repl:
return repl
......@@ -547,3 +572,43 @@ class OpKeyOptimizer(NavigatorOptimizer):
def keep_going(exc, nav, repl_pairs):
pass
#################
### Utilities ###
#################
def _check_chain(r, chain):
chain = list(reversed(chain))
while chain:
elem = chain.pop()
if elem is None:
if not r.owner is None:
return False
elif r.owner is None:
return False
elif isinstance(elem, op.Op):
if not r.owner.op == elem:
return False
else:
try:
if issubclass(elem, op.Op) and not isinstance(r.owner.op, elem):
return False
except TypeError:
return False
if chain:
r = r.owner.inputs[chain.pop()]
return r
def check_chain(r, *chain):
if isinstance(r, graph.Apply):
r = r.outputs[0]
return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain)))
import gof
from gof import opt
from elemwise import Elemwise, DimShuffle
import scalar
import tensor as T
import numpy as N
import operator
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
......@@ -26,7 +29,8 @@ dot_to_gemm = gof.PatternSub((T.dot, 'a', 'b'),
allow_multiple_clients = False)
class InplaceOptimizer(gof.Optimizer):
@gof.optimizer
def inplace_optimizer(self, env):
"""
Usage: inplace_optimizer.optimize(env)
......@@ -40,39 +44,34 @@ class InplaceOptimizer(gof.Optimizer):
x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
"""
def apply(self, env):
for node in list(env.nodes):
op = node.op
if not isinstance(op, Elemwise):
continue
baseline = op.inplace_pattern
candidate_outputs = [i for i in xrange(len(node.outputs)) if i not in baseline]
candidate_inputs = [i for i in xrange(len(node.inputs)) if i not in baseline.values()]
for candidate_output in candidate_outputs:
for candidate_input in candidate_inputs:
inplace_pattern = dict(baseline, **{candidate_output: candidate_input})
try:
new = Elemwise(op.scalar_op, inplace_pattern).make_node(op.inputs)
env.replace_all_validate(dict(zip(node.outputs, new.outputs)))
except:
continue
candidate_inputs.remove(candidate_input)
node = new
baseline = inplace_pattern
break
def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate)
inplace_optimizer = InplaceOptimizer()
class DimShuffleLifter(gof.LocalOptimizer):
for node in list(env.nodes):
op = node.op
if not isinstance(op, Elemwise):
continue
baseline = op.inplace_pattern
candidate_outputs = [i for i in xrange(len(node.outputs)) if i not in baseline]
candidate_inputs = [i for i in xrange(len(node.inputs)) if i not in baseline.values()]
for candidate_output in candidate_outputs:
for candidate_input in candidate_inputs:
inplace_pattern = dict(baseline, **{candidate_output: candidate_input})
try:
new = Elemwise(op.scalar_op, inplace_pattern).make_node(op.inputs)
env.replace_all_validate(dict(zip(node.outputs, new.outputs)))
except:
continue
candidate_inputs.remove(candidate_input)
node = new
baseline = inplace_pattern
break
######################
# DimShuffle lifters #
######################
@gof.local_optimizer
def local_dimshuffle_lift(node):
"""
Usage: lift_dimshuffle.optimize(env)
"Lifts" DimShuffle through Elemwise operations and merges
consecutive DimShuffles. Basically, applies the following
transformations on the whole graph:
......@@ -83,188 +82,824 @@ class DimShuffleLifter(gof.LocalOptimizer):
After this transform, clusters of Elemwise operations are
void of DimShuffle operations.
"""
op = node.op
if not isinstance(op, DimShuffle):
return False
input = node.inputs[0]
inode = input.owner
if inode and isinstance(inode.op, Elemwise):
return inode.op.make_node(*[DimShuffle(input.type.broadcastable,
op.new_order,
op.inplace)(input) for input in inode.inputs]).outputs
if inode and isinstance(inode.op, DimShuffle):
new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in op.new_order]
inplace = op.inplace and inode.op.inplace
iinput = inode.inputs[0]
if new_order == range(len(new_order)):
return [iinput]
else:
return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs
dimshuffle_lift = gof.TopoOptimizer(local_dimshuffle_lift, order = 'out_to_in')
#################
# Shape lifters #
#################
@gof.local_optimizer
def local_shape_lift_elemwise(node):
"""
shape(elemwise_op(..., x, ...)) -> shape(x)
def transform(self, node):
op = node.op
if not isinstance(op, DimShuffle):
return False
Where x contains the maximal shape information.
"""
if not opt.check_chain(node, T.shape, T.Elemwise):
return False
output = node.inputs[0]
parent = output.owner
input = node.inputs[0]
inode = input.owner
if inode and isinstance(inode.op, Elemwise):
return inode.op.make_node(*[DimShuffle(input.type.broadcastable,
op.new_order,
op.inplace)(input) for input in inode.inputs]).outputs
if inode and isinstance(inode.op, DimShuffle):
new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in op.new_order]
inplace = op.inplace and inode.op.inplace
iinput = inode.inputs[0]
if new_order == range(len(new_order)):
return [iinput]
else:
return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs
for input in parent.inputs:
if input.type.broadcastable == output.type.broadcastable:
return T.shape.make_node(input).outputs
return False
lift_dimshuffle = gof.TopoOptimizer(DimShuffleLifter(), order = 'out_to_in')
@gof.local_optimizer
def local_shape_lift_sum(node):
"""
shape(sum{n}(x)) -> [shape(x)[0], ..., shape(x)[n-1], shape(x)[n+1], ...]
"""
if not opt.check_chain(node, T.shape, T.Sum):
return False
input = node.inputs[0].owner.inputs[0]
axis = node.inputs[0].owner.op.axis
if axis is None:# or len(axis) != 1:
axis = range(input.type.ndim)
ish = T.shape(input)
return T.make_lvector.make_node(*(ish[i] for i in xrange(input.type.ndim) if i not in axis)).outputs
# return T.vertical_stack.make_node(ish[:axis], ish[axis+1:]).outputs
class Canonizer(gof.Optimizer):
@gof.local_optimizer
def local_shape_lift_dot(node):
"""
Simplification tool.
shape(dot(a, b)) -> [shape(a)[0], shape(b)[1]]
"""
if not opt.check_chain(node, T.shape, T.dot):
return False
a, b = node.inputs[0].owner.inputs
return T.make_lvector.make_node(T.shape(a)[0], T.shape(b)[1]).outputs
Usage: Canonizer(main, inverse, reciprocal, mainfn, invfn, recfn, transform)
* main: a suitable Op class that is commutative, associative and takes
one to an arbitrary number of inputs, e.g. Add or Mul
* inverse: an Op class such that inverse(main(x, y), y) == x
e.g. Sub or Div
* reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
e.g. Neg or Inv
local_shape_lift = opt.LocalOptGroup(local_shape_lift_elemwise,
local_shape_lift_sum,
local_shape_lift_dot)
* mainfn, invfn, recfn: functions that behave just like the previous three
Ops, but on true scalars (e.g. their impl)
* transform: a function that maps (numerator, denominatur) where numerator
and denominator are lists of Result instances, to new lists
where further simplifications may have been applied.
################
# Fill lifters #
################
Examples:
add_canonizer = Canonizer(Add, Sub, Neg, lambda *inputs: sum(inputs), ...)
mul_canonizer = Canonizer(Mul, Div, Inv, lambda *inputs: product(inputs), ...)
Examples of optimizations mul_canonizer can perform:
x / x -> 1
(x * y) / x -> y
x / y / x -> 1 / y
x / y / z -> x / (y * z)
x / (y / z) -> (x * z) / y
(a / b) * (b / c) * (c / d) -> a / d
(2.0 * x) / (4.0 * y) -> (0.5 * x) / y
2 * x / 2 -> x
def encompasses_broadcastable(b1, b2):
if len(b1) < len(b2):
return False
b1 = b1[-len(b2):]
return not any(v1 and not v2 for v1, v2 in zip(b1, b2))
def merge_broadcastables(broadcastables):
return [all(bcast) for bcast in zip(*broadcastables)]
@gof.local_optimizer
def local_fill_lift(node):
"""
fill(f(a), b) -> fill(a, b)
If a.type == f(a).type.
fill(a, b) -> b
If a.type == b.type.
"""
if not opt.check_chain(node, T.fill):
return False
model, filling = node.inputs
def __init__(self, main, inverse, reciprocal, mainfn, invfn, recfn, transform = None):
mb, fb = model.type.broadcastable, filling.type.broadcastable
if model.type.dtype == filling.type.dtype and encompasses_broadcastable(fb, mb):
return [filling]
parent = model.owner
if parent is None:
return False
for input in parent.inputs:
if input.type == model.type:
return [T.fill(input, filling)]
return False
##################
# Subtensor opts #
##################
@gof.local_optimizer
def local_subtensor_make_vector(node):
"""
[a,b,c][0] -> a
[a,b,c][0:2] -> [a,b]
If the index or slice is constant.
"""
if not opt.check_chain(node, T.Subtensor, T.MakeVector):
return False
idxlist = node.op.idx_list
if len(idxlist) != 1:
return False
idx = idxlist[0]
if isinstance(idx, int):
return [node.inputs[0].owner.inputs[idx]]
try:
return T.make_vector(*(node.owner.inputs[0].owner.inputs.__getslice__(idx)))
except TypeError:
return False
##################
# Middleman cuts #
##################
@gof.local_optimizer
def local_fill_cut(node):
"""
f(fill(a,b), c) -> f(b, c)
If c.type == a.type.
"""
if not opt.check_chain(node, T.Elemwise):
return False
output = node.outputs[0]
try:
reference = [input
for input in node.inputs
if input.type == output.type and (not input.owner or input.owner.op != T.fill)][0]
except IndexError:
return False
new_inputs = []
for input in node.inputs:
if opt.check_chain(input, T.fill):
model, filling = input.owner.inputs
if encompasses_broadcastable(reference.type.broadcastable,
filling.type.broadcastable):
new_inputs.append(filling)
continue
new_inputs.append(input)
if new_inputs == node.inputs:
return False
return node.op.make_node(*new_inputs).outputs
@gof.local_optimizer
def local_fill_sink(node):
"""
f(fill(a, b), fill(c, d), e) -> fill(a, fill(c, f(b, d, e)))
"""
if not (node.op and isinstance(node.op, T.Elemwise) and node.op != T.fill):
return False
models = []
inputs = []
for input in node.inputs:
if input.owner and input.owner.op == T.fill:
models.append(input.owner.inputs[0])
inputs.append(input.owner.inputs[1])
else:
inputs.append(input)
if inputs == node.inputs:
return False
c = node.op(*inputs)
for model in models:
c = T.fill(model, c)
return [c]
################
# Canonization #
################
class Canonizer(gof.LocalOptimizer):
def __init__(self, main, inverse, reciprocal, calculate):
self.main = main
self.inverse = inverse
self.reciprocal = reciprocal
self.mainfn = mainfn
self.invfn = invfn
self.recfn = recfn
self.neutral = mainfn()
self.transform = transform
self.calculate = calculate
def get_num_denum(self, input):
if input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]:
return [input], []
num = []
denum = []
parent = input.owner
pairs = [self.get_num_denum(input) for input in parent.inputs]
if parent.op == self.main:
num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs))
denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs))
elif parent.op == self.inverse:
num = pairs[0][0] + pairs[1][1]
denum = pairs[0][1] + pairs[1][0]
elif parent.op == self.reciprocal:
num = pairs[0][1]
denum = pairs[0][0]
return num, denum
def merge_num_denum(self, num, denum):
ln, ld = len(num), len(denum)
if not ln and not ld:
return T.as_tensor(self.calculate([], []))
if not ln:
return self.reciprocal(self.merge_num_denum(denum, []))
if not ld:
if ln == 1:
return num[0]
else:
return self.main(*num)
return self.inverse(self.merge_num_denum(num, []),
self.merge_num_denum(denum, []))
def get_constant(self, v):
if isinstance(v, gof.Constant):
return v.data
if v.owner and isinstance(v.owner.op, DimShuffle):
return self.get_constant(v.owner.inputs[0])
return None
def simplify(self, num, denum):
return self.simplify_constants(*self.simplify_factors(num, denum))
def simplify_factors(self, num, denum):
for v in list(num):
if v in denum:
num.remove(v)
denum.remove(v)
return num, denum
def simplify_constants(self, orig_num, orig_denum):
num, denum = list(orig_num), list(orig_denum)
numct, denumct = [], []
ncc, dcc = 0, 0
for v in orig_num:
ct = self.get_constant(v)
if ct is not None:
ncc += 1
num.remove(v)
numct.append(ct)
for v in orig_denum:
ct = self.get_constant(v)
if ct is not None:
dcc += 1
denum.remove(v)
denumct.append(ct)
ct = self.calculate(numct, denumct, aslist = True)
if len(ct) and ncc == 1 and dcc == 0:
return orig_num, orig_denum
return ct + num, denum
def apply(self, env):
def transform(self, node):
op = node.op
inputs = node.inputs
out = node.outputs[0]
if op not in [self.main, self.inverse, self.reciprocal]:
return False
iops = set(input.owner.op for input in inputs if input.owner)
reorg = False
if op == self.main:
reorg = len(iops.intersection([self.main, self.inverse, self.reciprocal])) != 0
elif op == self.inverse:
reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0
elif op == self.reciprocal:
reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0
orig_num, orig_denum = self.get_num_denum(node.outputs[0])
num, denum = list(orig_num), list(orig_denum)
num, denum = self.simplify(num, denum)
if not reorg and orig_num == num and orig_denum == denum:
return False
def edge(r):
return r.owner is None
def follow(r):
return None if r.owner is None else r.owner.inputs
new = self.merge_num_denum(num, denum)
if new.type != out.type:
new = T.fill(out, new)
return [new]
def mul_calculate(num, denum, aslist = False):
v = reduce(N.multiply, num, 1.0) / reduce(N.multiply, denum, 1.0)
if aslist:
if N.all(v == 1):
return []
else:
return [v]
return v
local_mul_canonizer = Canonizer(T.mul, T.div, T.inv, mul_calculate)
mul_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_mul_canonizer, local_fill_sink), order = 'in_to_out')
def add_calculate(num, denum, aslist = False):
v = reduce(N.add, num, 0.0) - reduce(N.add, denum, 0.0)
if aslist:
if N.all(v == 0):
return []
else:
return [v]
return v
local_add_canonizer = Canonizer(T.add, T.sub, T.neg, add_calculate)
add_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_add_canonizer, local_fill_sink), order = 'in_to_out')
##################
# Distributivity #
##################
def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore = 0):
score = len(num) + len(denum) # score is number of operations saved, higher is better
new_pos_pairs = itertools.starmap(local_mul_canonizer.simplify,
[(n+num, d+denum) for (n, d) in plus_pairs])
new_neg_pairs = itertools.starmap(local_mul_canonizer.simplify,
[(n+num, d+denum) for (n, d) in plus_pairs])
for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs + new_neg_pairs):
# We calculate how many operations we are saving with the new num and denum
score += len(n) + len(d) - len(nn) - len(dd)
if score < minscore:
return False, pos_pairs, neg_pairs
return True, new_pos_pairs, new_neg_pairs
@gof.local_optimizer
def local_greedy_distributor(node):
"""
The following expressions are simplified:
((a/x + b/y) * x * y) --> a*y + b*x
((a/x + b) * x) --> a + b*x
def canonize(r):
The following expressions are not:
((a + b) * x) -X-> a*x + b*x
"""
out = node.outputs[0]
num, denum = local_mul_canonizer.get_num_denum(out)
if len(num) == 1 and not denum:
return False
new_num = []
for entry in num:
pos, neg = local_add_canonizer.get_num_denum(entry)
if len(pos) == 1 and not neg:
new_num.append(entry)
continue
pos_pairs = map(local_mul_canonizer.get_num_denum, pos)
neg_pairs = map(local_mul_canonizer.get_num_denum, neg)
next = follow(r)
if next is None:
return
# class Canonizer(gof.LocalOptimizer):
# def __init__(self, main, inverse, reciprocal, simplify_constants, constant_op):
# self.main = main
# self.inverse = inverse
# self.reciprocal = reciprocal
# self.simplify_constants = simplify_constants
# self.constant_op = constant_op
# def get_num_denum(self, input, depth):
# if depth == 0 or input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]:
# return [input], []
# num = []
# denum = []
# parent = input.owner
# pairs = [self.get_num_denum(input, depth - 1) for input in parent.inputs]
# if parent.op == self.main:
# num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs))
# denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs))
# elif parent.op == self.inverse:
# num = pairs[0][0] + pairs[1][1]
# denum = pairs[0][1] + pairs[1][0]
# elif parent.op == self.reciprocal:
# num = pairs[0][1]
# denum = pairs[0][0]
# return num, denum
# def deep_num_denum(self, node):
# op = node.op
# if op == self.main:
# num, denum = self.get_num_denum(inputs)
# elif op == self.inverse:
# assert len(inputs) == 2
# n1, d1 = self.get_num_denum(inputs[:1])
# n2, d2 = self.get_num_denum(inputs[1:])
# num, denum = n1+d2, d1+n2
# elif op == self.reciprocal:
# denum, num = self.get_num_denum(inputs)
# else:
# num, denum = [node.outputs[0]], []
# return num, denum
# def get_num_denum(self, inputs):
# num = []
# denum = []
# for input in inputs:
# if input.owner is None:
# num.append(input)
# continue
# parent = input.owner
# if parent.op == self.main:
# num += parent.inputs
# elif parent.op == self.inverse:
# num += parent.inputs[:1]
# denum += parent.inputs[1:]
# elif parent.op == self.reciprocal:
# denum += parent.inputs
# else:
# num.append(input)
# return num, denum
# def merge_num_denum(self, num, denum, outtype):
# ln, ld = len(num), len(denum)
# if not ln and not ld:
# return outtype.filter(self.simplify_constants([], []))
# if not ln:
# return self.reciprocal(self.merge_num_denum(denum, [], outtype))
# if not ld:
# if ln == 1:
# return num[0]
# else:
# return self.main(*num)
# return self.inverse(self.merge_num_denum(num, [], outtype),
# self.merge_num_denum(denum, [], outtype))
# def get_constant(self, v):
# if isinstance(v, gof.Constant):
# return v.data
# if v.owner and isinstance(v.owner.op, DimShuffle):
# return self.get_constant(v.owner.inputs[0])
# return None
# def simplify(self, num, denum):
# numct, denumct = [], []
# ncc, dcc = 0, 0
# for v in list(num):
# if v in denum:
# num.remove(v)
# denum.remove(v)
# continue
# ct = self.get_constant(v)
# if ct is not None:
# ncc += 1
# num.remove(v)
# numct.append(ct)
# for v in list(denum):
# ct = self.get_constant(v)
# if ct is not None:
# dcc += 1
# denum.remove(v)
# denumct.append(ct)
# ct = self.simplify_constants(numct, denumct)
# if ct is None:
# return ncc+dcc>0, None, num, denum
# ctop = self.constant_op.get(ct)
# if ctop is not None:
# return True, ctop, num, denum
# return not (ncc==1 and dcc==0), None, [ct]+num, denum
# def transform(self, node):
# op = node.op
# inputs = node.inputs
# if op == self.main:
# num, denum = self.get_num_denum(inputs)
# elif op == self.inverse:
# assert len(inputs) == 2
# n1, d1 = self.get_num_denum(inputs[:1])
# n2, d2 = self.get_num_denum(inputs[1:])
# num, denum = n1+d2, d1+n2
# elif op == self.reciprocal:
# denum, num = self.get_num_denum(inputs)
# else:
# return False
# change, ctop, num2, denum2 = self.simplify(num, denum)
# if change:
# num, denum = num2, denum2
# # print node, ct, num, denum
# # ctop = ct != [] and self.constant_op.get(ct[0], None)
# # if not ctop:
# # num = ct + num
# new = self.merge_num_denum(num, denum, node.outputs[0].type)
# if ctop:
# new = ctop(new)
# print new.owner.op, op, new.owner.inputs, inputs
# if new.owner and new.owner.op == op and all((new_input.owner new.owner.inputs == inputs:
# return False
# return [new]
# @gof.local_optimizer
# def local_cut_middlemen(node):
# op = node.op
# if isinstance(op, Elemwise):
# aaaaaaa
# # @gof.local_optimizer
# # def local_merge_mul(node):
# # op = node.op
# # if op != mul:
# # return False
# # num, denum = _get_num_denum(node.inputs)
# # if num == node.inputs and denum == []:
# # return False
# # return _
# class Lift(gof.LocalOptimizer):
# def __init__(self, op, lifters, chooser):
# self.op = op
# self.lifters = lifters
# self.chooser = chooser
# def op_key(self):
# return self.op
# def transform(self, node):
# if not node.op == self.op:
# return False
# candidates = [node.inputs[0]]
# seen = set(candidates)
# while True:
# candidate = candidates.pop()
# for lifter in self.lifters:
# new_candidates = lifter(candidate)
# if not new_candidates:
# break
# else:
# candidates.append(candidate)
# new_op = self.op(self.chooser(candidates))
# return new_op
# class Canonizer(gof.Optimizer):
# """
# Simplification tool.
# Usage: Canonizer(main, inverse, reciprocal, mainfn, invfn, recfn, transform)
# * main: a suitable Op class that is commutative, associative and takes
# one to an arbitrary number of inputs, e.g. Add or Mul
# * inverse: an Op class such that inverse(main(x, y), y) == x
# e.g. Sub or Div
# * reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
# e.g. Neg or Inv
# * mainfn, invfn, recfn: functions that behave just like the previous three
# Ops, but on true scalars (e.g. their impl)
# * transform: a function that maps (numerator, denominatur) where numerator
# and denominator are lists of Result instances, to new lists
# where further simplifications may have been applied.
# Examples:
# add_canonizer = Canonizer(Add, Sub, Neg, lambda *inputs: sum(inputs), ...)
# mul_canonizer = Canonizer(Mul, Div, Inv, lambda *inputs: product(inputs), ...)
# Examples of optimizations mul_canonizer can perform:
# x / x -> 1
# (x * y) / x -> y
# x / y / x -> 1 / y
# x / y / z -> x / (y * z)
# x / (y / z) -> (x * z) / y
# (a / b) * (b / c) * (c / d) -> a / d
# (2.0 * x) / (4.0 * y) -> (0.5 * x) / y
# 2 * x / 2 -> x
# """
# def __init__(self, main, inverse, reciprocal, mainfn, invfn, recfn, transform = None):
# self.main = main
# self.inverse = inverse
# self.reciprocal = reciprocal
# self.mainfn = mainfn
# self.invfn = invfn
# self.recfn = recfn
# self.neutral = mainfn()
# self.transform = transform
# def apply(self, env):
# def edge(r):
# return r.owner is None
# def follow(r):
# return None if r.owner is None else r.owner.inputs
# def canonize(r):
# next = follow(r)
# if next is None:
# return
def flatten(r, nclients_check = True):
# Collapses a tree of main/inverse/reciprocal Ops (aka Mul/Div/Inv or Add/Sub/Neg)
# into a list of numerators and a list of denominators
# e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, (Inv, y)), Div(x, Div(z, a))) -> [x, x, a], [z, y]
if edge(r):
return [r], []
node = r.owner
op = node.op
# def flatten(r, nclients_check = True):
# # Collapses a tree of main/inverse/reciprocal Ops (aka Mul/Div/Inv or Add/Sub/Neg)
# # into a list of numerators and a list of denominators
# # e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, (Inv, y)), Div(x, Div(z, a))) -> [x, x, a], [z, y]
# if edge(r):
# return [r], []
# node = r.owner
# op = node.op
results = [r2.type == r.type and flatten(r2) or ([r2], []) for r2 in node.inputs]
if op == self.main and (not nclients_check or env.nclients(r) == 1):
nums = [x[0] for x in results]
denums = [x[1] for x in results]
elif op == self.inverse and (not nclients_check or env.nclients(r) == 1):
# num, denum of the second argument are added to the denum, num respectively
nums = [results[0][0], results[1][1]]
denums = [results[0][1], results[1][0]]
elif op == self.reciprocal and (not nclients_check or env.nclients(r) == 1):
# num, denum of the sole argument are added to the denum, num respectively
nums = [results[0][1]]
denums = [results[0][0]]
else:
return [r], []
return reduce(list.__add__, nums), reduce(list.__add__, denums)
num, denum = flatten(r, False)
if (num, denum) == ([r], []):
for input in (follow(r) or []):
canonize(input)
return
# Terms that are both in the num and denum lists cancel each other
for d in list(denum):
if d in list(num):
# list.remove only removes the element once
num.remove(d)
denum.remove(d)
# We identify the constants in num and denum
numct, num = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, num)
denumct, denum = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, denum)
#print numct, num
#print denumct, denum
print num, denum
# All constants in num and denum are combined into a single constant which we add to num (unless it's a neutral constant)
v = self.invfn(self.mainfn(*[x.data for x in numct]), self.mainfn(*[x.data for x in denumct]))
if v != self.neutral:
num.insert(0, C(v))
# We optimize the num and denum lists further if requested
if self.transform is not None:
num, denum = self.transform(env, num, denum)
def make(factors):
# Combines the factors using self.main (aka Mul) depending
# on the number of elements.
n = len(factors)
if n == 0:
return None
elif n == 1:
return factors[0]
else:
return self.main(*factors)
numr, denumr = make(num), make(denum)
# results = [r2.type == r.type and flatten(r2) or ([r2], []) for r2 in node.inputs]
# if op == self.main and (not nclients_check or env.nclients(r) == 1):
# nums = [x[0] for x in results]
# denums = [x[1] for x in results]
# elif op == self.inverse and (not nclients_check or env.nclients(r) == 1):
# # num, denum of the second argument are added to the denum, num respectively
# nums = [results[0][0], results[1][1]]
# denums = [results[0][1], results[1][0]]
# elif op == self.reciprocal and (not nclients_check or env.nclients(r) == 1):
# # num, denum of the sole argument are added to the denum, num respectively
# nums = [results[0][1]]
# denums = [results[0][0]]
# else:
# return [r], []
# return reduce(list.__add__, nums), reduce(list.__add__, denums)
# num, denum = flatten(r, False)
# if (num, denum) == ([r], []):
# for input in (follow(r) or []):
# canonize(input)
# return
# # Terms that are both in the num and denum lists cancel each other
# for d in list(denum):
# if d in list(num):
# # list.remove only removes the element once
# num.remove(d)
# denum.remove(d)
# # We identify the constants in num and denum
# numct, num = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, num)
# denumct, denum = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, denum)
# #print numct, num
# #print denumct, denum
# print num, denum
# # All constants in num and denum are combined into a single constant which we add to num (unless it's a neutral constant)
# v = self.invfn(self.mainfn(*[x.data for x in numct]), self.mainfn(*[x.data for x in denumct]))
# if v != self.neutral:
# num.insert(0, C(v))
# # We optimize the num and denum lists further if requested
# if self.transform is not None:
# num, denum = self.transform(env, num, denum)
# def make(factors):
# # Combines the factors using self.main (aka Mul) depending
# # on the number of elements.
# n = len(factors)
# if n == 0:
# return None
# elif n == 1:
# return factors[0]
# else:
# return self.main(*factors)
# numr, denumr = make(num), make(denum)
if numr is None:
if denumr is None:
# Everything cancelled each other so we're left with
# the neutral element.
new_r = gof.Constant(r.type, self.neutral)
else:
# There's no numerator so we use reciprocal
new_r = self.reciprocal(denumr)
else:
if denumr is None:
new_r = numr
else:
new_r = self.inverse(numr, denumr)
# if numr is None:
# if denumr is None:
# # Everything cancelled each other so we're left with
# # the neutral element.
# new_r = gof.Constant(r.type, self.neutral)
# else:
# # There's no numerator so we use reciprocal
# new_r = self.reciprocal(denumr)
# else:
# if denumr is None:
# new_r = numr
# else:
# new_r = self.inverse(numr, denumr)
# Hopefully this won't complain!
env.replace(r, new_r)
# # Hopefully this won't complain!
# env.replace(r, new_r)
for factor in num + denum:
canonize(factor)
# for factor in num + denum:
# canonize(factor)
for output in env.outputs:
canonize(output)
# for output in env.outputs:
# canonize(output)
_mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
_divfn = lambda x, y: x / y
_invfn = lambda x: 1 / x
mul_canonizer = Canonizer(T.mul, T.div, T.inv, _mulfn, _divfn, _invfn)
# _mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# _divfn = lambda x, y: x / y
# _invfn = lambda x: 1 / x
# mul_canonizer = Canonizer(T.mul, T.div, T.inv, _mulfn, _divfn, _invfn)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论