提交 11be7be2 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

optimizer fiesta

上级 deff95dc
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
import unittest import unittest
import gof from theano import gof
from tensor_opt import * from theano.tensor_opt import *
import tensor from theano import tensor
from tensor import Tensor from theano.tensor import Tensor
from gof import Env from theano.gof import Env
from elemwise import DimShuffle from theano.elemwise import DimShuffle
import numpy import numpy
#import scalar_opt #import scalar_opt
...@@ -68,7 +68,7 @@ class _test_dimshuffle_lift(unittest.TestCase): ...@@ -68,7 +68,7 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(x, (1, 0)), (1, 0)) e = ds(ds(x, (1, 0)), (1, 0))
g = Env([x], [e]) g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{1,0}(x))]") self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{1,0}(x))]")
lift_dimshuffle.optimize(g) dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[x]") self.failUnless(str(g) == "[x]")
def test_merge2(self): def test_merge2(self):
...@@ -76,7 +76,7 @@ class _test_dimshuffle_lift(unittest.TestCase): ...@@ -76,7 +76,7 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1)) e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
g = Env([x], [e]) g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]", str(g)) self.failUnless(str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]", str(g))
lift_dimshuffle.optimize(g) dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g)) self.failUnless(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g))
def test_elim3(self): def test_elim3(self):
...@@ -84,7 +84,7 @@ class _test_dimshuffle_lift(unittest.TestCase): ...@@ -84,7 +84,7 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0)) e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
g = Env([x], [e]) g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}(DimShuffle{0,x,1}(x)))]", str(g)) self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}(DimShuffle{0,x,1}(x)))]", str(g))
lift_dimshuffle.optimize(g) dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[x]", str(g)) self.failUnless(str(g) == "[x]", str(g))
def test_lift(self): def test_lift(self):
...@@ -92,10 +92,134 @@ class _test_dimshuffle_lift(unittest.TestCase): ...@@ -92,10 +92,134 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = x + y + z e = x + y + z
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
self.failUnless(str(g) == "[add(InplaceDimShuffle{x,0,1}(add(InplaceDimShuffle{x,0}(x), y)), z)]", str(g)) self.failUnless(str(g) == "[add(InplaceDimShuffle{x,0,1}(add(InplaceDimShuffle{x,0}(x), y)), z)]", str(g))
lift_dimshuffle.optimize(g) dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g)) self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g))
from theano.tensor import *
from theano.sandbox import pprint
class _test_canonize(unittest.TestCase):
def test_muldiv(self):
x, y, z = matrices('xyz')
a, b, c, d = matrices('abcd')
# e = (2.0 * x) / (2.0 * y)
# e = (2.0 * x) / (4.0 * y)
# e = x / (y / z)
# e = (x * y) / x
# e = (x / y) * (y / z) * (z / x)
# e = (a / b) * (b / c) * (c / d)
# e = (a * b) / (b * c) / (c * d)
# e = 2 * x / 2
# e = x / y / x
e = (x / x) * (y / y)
g = Env([x, y, z, a, b, c, d], [e])
print pprint.pp.process(g.outputs[0])
mul_canonizer.optimize(g)
gof.TopoOptimizer(gof.LocalOptGroup(local_fill_cut, local_fill_lift), order = 'out_to_in').optimize(g)
print pprint.pp.process(g.outputs[0])
# def test_plusmin(self):
# x, y, z = inputs()
# a, b, c, d = more_inputs()
# # e = x - x
# # e = (2.0 + x) - (2.0 + y)
# # e = (2.0 + x) - (4.0 + y)
# # e = x - (y - z)
# # e = (x + y) - x
# # e = (x - y) + (y - z) + (z - x)
# # e = (a - b) + (b - c) + (c - d)
# # e = x + -y
# # e = a - b - b + a + b + c + b - c
# # e = x + log(y) - x + y
# e = 2.0 + x + 4.0
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# gof.ConstantFinder().optimize(g)
# addfn = lambda *inputs: sum(inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
# print g
# def test_both(self):
# x, y, z = inputs()
# a, b, c, d = more_inputs()
# e0 = (x * y / x)
# e = e0 + e0 - e0
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# gof.ConstantFinder().optimize(g)
# mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# divfn = lambda x, y: x / y
# invfn = lambda x: 1 / x
# Canonizer(Mul, Div, Inv, mulfn, divfn, invfn).optimize(g)
# addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
# print g
# def test_group_powers(self):
# x, y, z, a, b, c, d = floats('xyzabcd')
###################
# c1, c2 = constant(1.), constant(2.)
# #e = pow(x, c1) * pow(x, y) / pow(x, 7.0) # <-- fucked
# #f = -- moving from div(mul.out, pow.out) to pow(x, sub.out)
# e = div(mul(pow(x, 2.0), pow(x, y)), pow(x, 7.0))
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# print g.inputs, g.outputs, g.orphans
# f = sub(add(2.0, y), add(7.0))
# g.replace(e, pow(x, f))
# print g
# print g.inputs, g.outputs, g.orphans
# g.replace(f, sub(add(2.0, y), add(7.0))) # -- moving from sub(add.out, add.out) to sub(add.out, add.out)
# print g
# print g.inputs, g.outputs, g.orphans
###################
# # e = x * exp(y) * exp(z)
# # e = x * pow(x, y) * pow(x, z)
# # e = pow(x, y) / pow(x, z)
# e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) # <-- fucked
# # e = pow(x - x, y)
# # e = pow(x, 2.0 + y - 7.0)
# # e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) / pow(x, z)
# # e = pow(x, 2.0 + y - 7.0 - z)
# # e = x ** y / x ** y
# # e = x ** y / x ** (y - 1.0)
# # e = exp(x) * a * exp(y) / exp(z)
# g = Env([x, y, z, a, b, c, d], [e])
# g.extend(gof.PrintListener(g))
# print g, g.orphans
# mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# divfn = lambda x, y: x / y
# invfn = lambda x: 1 / x
# Canonizer(mul, div, inv, mulfn, divfn, invfn, group_powers).optimize(g)
# print g, g.orphans
# addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(add, sub, neg, addfn, subfn, negfn).optimize(g)
# print g, g.orphans
# pow2one_float.optimize(g)
# pow2x_float.optimize(g)
# print g, g.orphans
# class _test_cliques(unittest.TestCase): # class _test_cliques(unittest.TestCase):
# def test_straightforward(self): # def test_straightforward(self):
......
...@@ -237,6 +237,7 @@ class Elemwise(Op): ...@@ -237,6 +237,7 @@ class Elemwise(Op):
is left-completed to the greatest number of dimensions with 1s is left-completed to the greatest number of dimensions with 1s
using DimShuffle. using DimShuffle.
""" """
inputs = map(as_tensor, inputs) inputs = map(as_tensor, inputs)
shadow = self.scalar_op.make_node(*[Scalar(dtype = t.type.dtype)() for t in inputs]) shadow = self.scalar_op.make_node(*[Scalar(dtype = t.type.dtype)() for t in inputs])
...@@ -303,11 +304,10 @@ class Elemwise(Op): ...@@ -303,11 +304,10 @@ class Elemwise(Op):
if node is None: if node is None:
# the gradient contains a constant, translate it as # the gradient contains a constant, translate it as
# an equivalent Tensor of size 1 and proper number of dimensions # an equivalent Tensor of size 1 and proper number of dimensions
b = [1] * nd
res = TensorConstant(Tensor(dtype = r.type.dtype, res = TensorConstant(Tensor(dtype = r.type.dtype,
broadcastable = b), broadcastable = ()),
numpy.asarray(r.data).reshape(b)) numpy.asarray(r.data)) # .reshape(b)
return res return DimShuffle((), ['x']*nd, inplace = True)(res)
new_r = Elemwise(node.op, {})(*[transform(input) for input in node.inputs]) new_r = Elemwise(node.op, {})(*[transform(input) for input in node.inputs])
return new_r return new_r
ret = [] ret = []
......
...@@ -18,9 +18,9 @@ from op import \ ...@@ -18,9 +18,9 @@ from op import \
Op Op
from opt import \ from opt import \
Optimizer, SeqOptimizer, \ Optimizer, optimizer, SeqOptimizer, \
MergeOptimizer, MergeOptMerge, \ MergeOptimizer, MergeOptMerge, \
LocalOptimizer, LocalOptGroup, LocalOpKeyOptGroup, \ LocalOptimizer, local_optimizer, LocalOptGroup, LocalOpKeyOptGroup, \
OpSub, OpRemove, PatternSub, \ OpSub, OpRemove, PatternSub, \
NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer
......
...@@ -382,17 +382,17 @@ class Env(utils.object2): ...@@ -382,17 +382,17 @@ class Env(utils.object2):
"Same as len(self.clients(r))." "Same as len(self.clients(r))."
return len(self.clients(r)) return len(self.clients(r))
def edge(self, r): # def edge(self, r):
return r in self.inputs or r in self.orphans # return r in self.inputs or r in self.orphans
def follow(self, r): # def follow(self, r):
node = r.owner # node = r.owner
if self.edge(r): # if self.edge(r):
return None # return None
else: # else:
if node is None: # if node is None:
raise Exception("what the fuck") # raise Exception("what the fuck")
return node.inputs # return node.inputs
def check_integrity(self): def check_integrity(self):
""" """
......
...@@ -56,6 +56,16 @@ class Optimizer: ...@@ -56,6 +56,16 @@ class Optimizer:
pass pass
class FromFunctionOptimizer(Optimizer):
def __init__(self, fn):
self.apply = fn
def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate)
def optimizer(f):
return FromFunctionOptimizer(f)
class SeqOptimizer(Optimizer, list): class SeqOptimizer(Optimizer, list):
""" """
...@@ -137,6 +147,7 @@ class MergeOptimizer(Optimizer): ...@@ -137,6 +147,7 @@ class MergeOptimizer(Optimizer):
sig = r.signature() sig = r.signature()
other_r = inv_cid.get(sig, None) other_r = inv_cid.get(sig, None)
if other_r is not None: if other_r is not None:
if r.name: other_r.name = r.name
env.replace_validate(r, other_r) env.replace_validate(r, other_r)
else: else:
cid[r] = sig cid[r] = sig
...@@ -155,8 +166,12 @@ class MergeOptimizer(Optimizer): ...@@ -155,8 +166,12 @@ class MergeOptimizer(Optimizer):
success = False success = False
if dup is not None: if dup is not None:
success = True success = True
pairs = zip(node.outputs, dup.outputs)
for output, new_output in pairs:
if output.name and not new_output.name:
new_output.name = output.name
try: try:
env.replace_all_validate(zip(node.outputs, dup.outputs)) env.replace_all_validate(pairs)
except InconsistencyError, e: except InconsistencyError, e:
success = False success = False
if not success: if not success:
...@@ -189,17 +204,27 @@ class LocalOptimizer(utils.object2): ...@@ -189,17 +204,27 @@ class LocalOptimizer(utils.object2):
raise utils.AbstractFunctionError() raise utils.AbstractFunctionError()
class FromFunctionLocalOptimizer(LocalOptimizer):
def __init__(self, fn):
self.transform = fn
def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate)
def local_optimizer(f):
return FromFunctionLocalOptimizer(f)
class LocalOptGroup(LocalOptimizer): class LocalOptGroup(LocalOptimizer):
def __init__(self, optimizers): def __init__(self, *optimizers):
self.opts = optimizers self.opts = optimizers
self.reentrant = any(getattr(opt, 'reentrant', True), optimizers) self.reentrant = any(getattr(opt, 'reentrant', True) for opt in optimizers)
self.retains_inputs = all(getattr(opt, 'retains_inputs', False), optimizers) self.retains_inputs = all(getattr(opt, 'retains_inputs', False) for opt in optimizers)
def transform(self, node): def transform(self, node):
for opt in self.opts: for opt in self.opts:
repl = opt.transform(node) repl = opt.transform(node)
if repl is not False: if repl:
return repl return repl
...@@ -547,3 +572,43 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -547,3 +572,43 @@ class OpKeyOptimizer(NavigatorOptimizer):
def keep_going(exc, nav, repl_pairs): def keep_going(exc, nav, repl_pairs):
pass pass
#################
### Utilities ###
#################
def _check_chain(r, chain):
chain = list(reversed(chain))
while chain:
elem = chain.pop()
if elem is None:
if not r.owner is None:
return False
elif r.owner is None:
return False
elif isinstance(elem, op.Op):
if not r.owner.op == elem:
return False
else:
try:
if issubclass(elem, op.Op) and not isinstance(r.owner.op, elem):
return False
except TypeError:
return False
if chain:
r = r.owner.inputs[chain.pop()]
return r
def check_chain(r, *chain):
if isinstance(r, graph.Apply):
r = r.outputs[0]
return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain)))
import gof import gof
from gof import opt
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
import scalar import scalar
import tensor as T import tensor as T
import numpy as N
import operator
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c) # gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0) # Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
...@@ -26,7 +29,8 @@ dot_to_gemm = gof.PatternSub((T.dot, 'a', 'b'), ...@@ -26,7 +29,8 @@ dot_to_gemm = gof.PatternSub((T.dot, 'a', 'b'),
allow_multiple_clients = False) allow_multiple_clients = False)
class InplaceOptimizer(gof.Optimizer): @gof.optimizer
def inplace_optimizer(self, env):
""" """
Usage: inplace_optimizer.optimize(env) Usage: inplace_optimizer.optimize(env)
...@@ -40,8 +44,6 @@ class InplaceOptimizer(gof.Optimizer): ...@@ -40,8 +44,6 @@ class InplaceOptimizer(gof.Optimizer):
x + y + z -> x += y += z x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y) (x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
""" """
def apply(self, env):
for node in list(env.nodes): for node in list(env.nodes):
op = node.op op = node.op
if not isinstance(op, Elemwise): if not isinstance(op, Elemwise):
...@@ -62,17 +64,14 @@ class InplaceOptimizer(gof.Optimizer): ...@@ -62,17 +64,14 @@ class InplaceOptimizer(gof.Optimizer):
baseline = inplace_pattern baseline = inplace_pattern
break break
def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate)
inplace_optimizer = InplaceOptimizer()
######################
# DimShuffle lifters #
######################
@gof.local_optimizer
class DimShuffleLifter(gof.LocalOptimizer): def local_dimshuffle_lift(node):
""" """
Usage: lift_dimshuffle.optimize(env)
"Lifts" DimShuffle through Elemwise operations and merges "Lifts" DimShuffle through Elemwise operations and merges
consecutive DimShuffles. Basically, applies the following consecutive DimShuffles. Basically, applies the following
transformations on the whole graph: transformations on the whole graph:
...@@ -83,8 +82,6 @@ class DimShuffleLifter(gof.LocalOptimizer): ...@@ -83,8 +82,6 @@ class DimShuffleLifter(gof.LocalOptimizer):
After this transform, clusters of Elemwise operations are After this transform, clusters of Elemwise operations are
void of DimShuffle operations. void of DimShuffle operations.
""" """
def transform(self, node):
op = node.op op = node.op
if not isinstance(op, DimShuffle): if not isinstance(op, DimShuffle):
return False return False
...@@ -104,167 +101,805 @@ class DimShuffleLifter(gof.LocalOptimizer): ...@@ -104,167 +101,805 @@ class DimShuffleLifter(gof.LocalOptimizer):
else: else:
return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs
lift_dimshuffle = gof.TopoOptimizer(DimShuffleLifter(), order = 'out_to_in') dimshuffle_lift = gof.TopoOptimizer(local_dimshuffle_lift, order = 'out_to_in')
#################
# Shape lifters #
#################
class Canonizer(gof.Optimizer): @gof.local_optimizer
def local_shape_lift_elemwise(node):
""" """
Simplification tool. shape(elemwise_op(..., x, ...)) -> shape(x)
Usage: Canonizer(main, inverse, reciprocal, mainfn, invfn, recfn, transform) Where x contains the maximal shape information.
"""
if not opt.check_chain(node, T.shape, T.Elemwise):
return False
* main: a suitable Op class that is commutative, associative and takes output = node.inputs[0]
one to an arbitrary number of inputs, e.g. Add or Mul parent = output.owner
* inverse: an Op class such that inverse(main(x, y), y) == x
e.g. Sub or Div
* reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
e.g. Neg or Inv
* mainfn, invfn, recfn: functions that behave just like the previous three for input in parent.inputs:
Ops, but on true scalars (e.g. their impl) if input.type.broadcastable == output.type.broadcastable:
return T.shape.make_node(input).outputs
* transform: a function that maps (numerator, denominatur) where numerator return False
and denominator are lists of Result instances, to new lists
where further simplifications may have been applied.
Examples: @gof.local_optimizer
add_canonizer = Canonizer(Add, Sub, Neg, lambda *inputs: sum(inputs), ...) def local_shape_lift_sum(node):
mul_canonizer = Canonizer(Mul, Div, Inv, lambda *inputs: product(inputs), ...) """
shape(sum{n}(x)) -> [shape(x)[0], ..., shape(x)[n-1], shape(x)[n+1], ...]
Examples of optimizations mul_canonizer can perform: """
x / x -> 1 if not opt.check_chain(node, T.shape, T.Sum):
(x * y) / x -> y return False
x / y / x -> 1 / y
x / y / z -> x / (y * z) input = node.inputs[0].owner.inputs[0]
x / (y / z) -> (x * z) / y axis = node.inputs[0].owner.op.axis
(a / b) * (b / c) * (c / d) -> a / d if axis is None:# or len(axis) != 1:
(2.0 * x) / (4.0 * y) -> (0.5 * x) / y axis = range(input.type.ndim)
2 * x / 2 -> x
ish = T.shape(input)
return T.make_lvector.make_node(*(ish[i] for i in xrange(input.type.ndim) if i not in axis)).outputs
# return T.vertical_stack.make_node(ish[:axis], ish[axis+1:]).outputs
@gof.local_optimizer
def local_shape_lift_dot(node):
"""
shape(dot(a, b)) -> [shape(a)[0], shape(b)[1]]
"""
if not opt.check_chain(node, T.shape, T.dot):
return False
a, b = node.inputs[0].owner.inputs
return T.make_lvector.make_node(T.shape(a)[0], T.shape(b)[1]).outputs
local_shape_lift = opt.LocalOptGroup(local_shape_lift_elemwise,
local_shape_lift_sum,
local_shape_lift_dot)
################
# Fill lifters #
################
def encompasses_broadcastable(b1, b2):
if len(b1) < len(b2):
return False
b1 = b1[-len(b2):]
return not any(v1 and not v2 for v1, v2 in zip(b1, b2))
def merge_broadcastables(broadcastables):
return [all(bcast) for bcast in zip(*broadcastables)]
@gof.local_optimizer
def local_fill_lift(node):
"""
fill(f(a), b) -> fill(a, b)
If a.type == f(a).type.
fill(a, b) -> b
If a.type == b.type.
"""
if not opt.check_chain(node, T.fill):
return False
model, filling = node.inputs
mb, fb = model.type.broadcastable, filling.type.broadcastable
if model.type.dtype == filling.type.dtype and encompasses_broadcastable(fb, mb):
return [filling]
parent = model.owner
if parent is None:
return False
for input in parent.inputs:
if input.type == model.type:
return [T.fill(input, filling)]
return False
##################
# Subtensor opts #
##################
@gof.local_optimizer
def local_subtensor_make_vector(node):
"""
[a,b,c][0] -> a
[a,b,c][0:2] -> [a,b]
If the index or slice is constant.
"""
if not opt.check_chain(node, T.Subtensor, T.MakeVector):
return False
idxlist = node.op.idx_list
if len(idxlist) != 1:
return False
idx = idxlist[0]
if isinstance(idx, int):
return [node.inputs[0].owner.inputs[idx]]
try:
return T.make_vector(*(node.owner.inputs[0].owner.inputs.__getslice__(idx)))
except TypeError:
return False
##################
# Middleman cuts #
##################
@gof.local_optimizer
def local_fill_cut(node):
"""
f(fill(a,b), c) -> f(b, c)
If c.type == a.type.
"""
if not opt.check_chain(node, T.Elemwise):
return False
output = node.outputs[0]
try:
reference = [input
for input in node.inputs
if input.type == output.type and (not input.owner or input.owner.op != T.fill)][0]
except IndexError:
return False
new_inputs = []
for input in node.inputs:
if opt.check_chain(input, T.fill):
model, filling = input.owner.inputs
if encompasses_broadcastable(reference.type.broadcastable,
filling.type.broadcastable):
new_inputs.append(filling)
continue
new_inputs.append(input)
if new_inputs == node.inputs:
return False
return node.op.make_node(*new_inputs).outputs
@gof.local_optimizer
def local_fill_sink(node):
"""
f(fill(a, b), fill(c, d), e) -> fill(a, fill(c, f(b, d, e)))
""" """
if not (node.op and isinstance(node.op, T.Elemwise) and node.op != T.fill):
return False
models = []
inputs = []
for input in node.inputs:
if input.owner and input.owner.op == T.fill:
models.append(input.owner.inputs[0])
inputs.append(input.owner.inputs[1])
else:
inputs.append(input)
if inputs == node.inputs:
return False
c = node.op(*inputs)
for model in models:
c = T.fill(model, c)
return [c]
################
# Canonization #
################
class Canonizer(gof.LocalOptimizer):
def __init__(self, main, inverse, reciprocal, mainfn, invfn, recfn, transform = None): def __init__(self, main, inverse, reciprocal, calculate):
self.main = main self.main = main
self.inverse = inverse self.inverse = inverse
self.reciprocal = reciprocal self.reciprocal = reciprocal
self.mainfn = mainfn self.calculate = calculate
self.invfn = invfn
self.recfn = recfn def get_num_denum(self, input):
self.neutral = mainfn() if input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]:
self.transform = transform return [input], []
num = []
denum = []
parent = input.owner
pairs = [self.get_num_denum(input) for input in parent.inputs]
if parent.op == self.main:
num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs))
denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs))
elif parent.op == self.inverse:
num = pairs[0][0] + pairs[1][1]
denum = pairs[0][1] + pairs[1][0]
elif parent.op == self.reciprocal:
num = pairs[0][1]
denum = pairs[0][0]
return num, denum
def merge_num_denum(self, num, denum):
ln, ld = len(num), len(denum)
if not ln and not ld:
return T.as_tensor(self.calculate([], []))
if not ln:
return self.reciprocal(self.merge_num_denum(denum, []))
if not ld:
if ln == 1:
return num[0]
else:
return self.main(*num)
return self.inverse(self.merge_num_denum(num, []),
self.merge_num_denum(denum, []))
def get_constant(self, v):
if isinstance(v, gof.Constant):
return v.data
if v.owner and isinstance(v.owner.op, DimShuffle):
return self.get_constant(v.owner.inputs[0])
return None
def apply(self, env): def simplify(self, num, denum):
return self.simplify_constants(*self.simplify_factors(num, denum))
def simplify_factors(self, num, denum):
for v in list(num):
if v in denum:
num.remove(v)
denum.remove(v)
return num, denum
def simplify_constants(self, orig_num, orig_denum):
num, denum = list(orig_num), list(orig_denum)
numct, denumct = [], []
ncc, dcc = 0, 0
for v in orig_num:
ct = self.get_constant(v)
if ct is not None:
ncc += 1
num.remove(v)
numct.append(ct)
for v in orig_denum:
ct = self.get_constant(v)
if ct is not None:
dcc += 1
denum.remove(v)
denumct.append(ct)
ct = self.calculate(numct, denumct, aslist = True)
if len(ct) and ncc == 1 and dcc == 0:
return orig_num, orig_denum
return ct + num, denum
def edge(r): def transform(self, node):
return r.owner is None op = node.op
def follow(r): inputs = node.inputs
return None if r.owner is None else r.owner.inputs out = node.outputs[0]
if op not in [self.main, self.inverse, self.reciprocal]:
return False
def canonize(r): iops = set(input.owner.op for input in inputs if input.owner)
reorg = False
if op == self.main:
reorg = len(iops.intersection([self.main, self.inverse, self.reciprocal])) != 0
elif op == self.inverse:
reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0
elif op == self.reciprocal:
reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0
next = follow(r) orig_num, orig_denum = self.get_num_denum(node.outputs[0])
if next is None: num, denum = list(orig_num), list(orig_denum)
return num, denum = self.simplify(num, denum)
def flatten(r, nclients_check = True): if not reorg and orig_num == num and orig_denum == denum:
# Collapses a tree of main/inverse/reciprocal Ops (aka Mul/Div/Inv or Add/Sub/Neg) return False
# into a list of numerators and a list of denominators
# e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, (Inv, y)), Div(x, Div(z, a))) -> [x, x, a], [z, y]
if edge(r): new = self.merge_num_denum(num, denum)
return [r], [] if new.type != out.type:
node = r.owner new = T.fill(out, new)
op = node.op return [new]
results = [r2.type == r.type and flatten(r2) or ([r2], []) for r2 in node.inputs] def mul_calculate(num, denum, aslist = False):
if op == self.main and (not nclients_check or env.nclients(r) == 1): v = reduce(N.multiply, num, 1.0) / reduce(N.multiply, denum, 1.0)
nums = [x[0] for x in results] if aslist:
denums = [x[1] for x in results] if N.all(v == 1):
elif op == self.inverse and (not nclients_check or env.nclients(r) == 1): return []
# num, denum of the second argument are added to the denum, num respectively
nums = [results[0][0], results[1][1]]
denums = [results[0][1], results[1][0]]
elif op == self.reciprocal and (not nclients_check or env.nclients(r) == 1):
# num, denum of the sole argument are added to the denum, num respectively
nums = [results[0][1]]
denums = [results[0][0]]
else: else:
return [r], [] return [v]
return v
return reduce(list.__add__, nums), reduce(list.__add__, denums)
num, denum = flatten(r, False)
if (num, denum) == ([r], []):
for input in (follow(r) or []):
canonize(input)
return
# Terms that are both in the num and denum lists cancel each other
for d in list(denum):
if d in list(num):
# list.remove only removes the element once
num.remove(d)
denum.remove(d)
# We identify the constants in num and denum
numct, num = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, num)
denumct, denum = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, denum)
#print numct, num
#print denumct, denum
print num, denum
# All constants in num and denum are combined into a single constant which we add to num (unless it's a neutral constant)
v = self.invfn(self.mainfn(*[x.data for x in numct]), self.mainfn(*[x.data for x in denumct]))
if v != self.neutral:
num.insert(0, C(v))
# We optimize the num and denum lists further if requested
if self.transform is not None:
num, denum = self.transform(env, num, denum)
def make(factors):
# Combines the factors using self.main (aka Mul) depending
# on the number of elements.
n = len(factors)
if n == 0:
return None
elif n == 1:
return factors[0]
else:
return self.main(*factors)
numr, denumr = make(num), make(denum) local_mul_canonizer = Canonizer(T.mul, T.div, T.inv, mul_calculate)
mul_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_mul_canonizer, local_fill_sink), order = 'in_to_out')
if numr is None: def add_calculate(num, denum, aslist = False):
if denumr is None: v = reduce(N.add, num, 0.0) - reduce(N.add, denum, 0.0)
# Everything cancelled each other so we're left with if aslist:
# the neutral element. if N.all(v == 0):
new_r = gof.Constant(r.type, self.neutral) return []
else:
# There's no numerator so we use reciprocal
new_r = self.reciprocal(denumr)
else:
if denumr is None:
new_r = numr
else: else:
new_r = self.inverse(numr, denumr) return [v]
return v
local_add_canonizer = Canonizer(T.add, T.sub, T.neg, add_calculate)
add_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_add_canonizer, local_fill_sink), order = 'in_to_out')
##################
# Distributivity #
##################
def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore = 0):
score = len(num) + len(denum) # score is number of operations saved, higher is better
new_pos_pairs = itertools.starmap(local_mul_canonizer.simplify,
[(n+num, d+denum) for (n, d) in plus_pairs])
new_neg_pairs = itertools.starmap(local_mul_canonizer.simplify,
[(n+num, d+denum) for (n, d) in plus_pairs])
for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs + new_neg_pairs):
# We calculate how many operations we are saving with the new num and denum
score += len(n) + len(d) - len(nn) - len(dd)
if score < minscore:
return False, pos_pairs, neg_pairs
return True, new_pos_pairs, new_neg_pairs
@gof.local_optimizer
def local_greedy_distributor(node):
"""
The following expressions are simplified:
((a/x + b/y) * x * y) --> a*y + b*x
((a/x + b) * x) --> a + b*x
The following expressions are not:
((a + b) * x) -X-> a*x + b*x
"""
out = node.outputs[0]
num, denum = local_mul_canonizer.get_num_denum(out)
if len(num) == 1 and not denum:
return False
new_num = []
for entry in num:
pos, neg = local_add_canonizer.get_num_denum(entry)
if len(pos) == 1 and not neg:
new_num.append(entry)
continue
pos_pairs = map(local_mul_canonizer.get_num_denum, pos)
neg_pairs = map(local_mul_canonizer.get_num_denum, neg)
# class Canonizer(gof.LocalOptimizer):
# Hopefully this won't complain! # def __init__(self, main, inverse, reciprocal, simplify_constants, constant_op):
env.replace(r, new_r) # self.main = main
# self.inverse = inverse
# self.reciprocal = reciprocal
# self.simplify_constants = simplify_constants
# self.constant_op = constant_op
for factor in num + denum: # def get_num_denum(self, input, depth):
canonize(factor) # if depth == 0 or input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]:
# return [input], []
# num = []
# denum = []
# parent = input.owner
# pairs = [self.get_num_denum(input, depth - 1) for input in parent.inputs]
# if parent.op == self.main:
# num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs))
# denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs))
# elif parent.op == self.inverse:
# num = pairs[0][0] + pairs[1][1]
# denum = pairs[0][1] + pairs[1][0]
# elif parent.op == self.reciprocal:
# num = pairs[0][1]
# denum = pairs[0][0]
# return num, denum
for output in env.outputs: # def deep_num_denum(self, node):
canonize(output) # op = node.op
# if op == self.main:
# num, denum = self.get_num_denum(inputs)
# elif op == self.inverse:
# assert len(inputs) == 2
# n1, d1 = self.get_num_denum(inputs[:1])
# n2, d2 = self.get_num_denum(inputs[1:])
# num, denum = n1+d2, d1+n2
# elif op == self.reciprocal:
# denum, num = self.get_num_denum(inputs)
# else:
# num, denum = [node.outputs[0]], []
# return num, denum
# def get_num_denum(self, inputs):
# num = []
# denum = []
# for input in inputs:
# if input.owner is None:
# num.append(input)
# continue
# parent = input.owner
# if parent.op == self.main:
# num += parent.inputs
# elif parent.op == self.inverse:
# num += parent.inputs[:1]
# denum += parent.inputs[1:]
# elif parent.op == self.reciprocal:
# denum += parent.inputs
# else:
# num.append(input)
# return num, denum
# def merge_num_denum(self, num, denum, outtype):
# ln, ld = len(num), len(denum)
# if not ln and not ld:
# return outtype.filter(self.simplify_constants([], []))
# if not ln:
# return self.reciprocal(self.merge_num_denum(denum, [], outtype))
# if not ld:
# if ln == 1:
# return num[0]
# else:
# return self.main(*num)
# return self.inverse(self.merge_num_denum(num, [], outtype),
# self.merge_num_denum(denum, [], outtype))
# def get_constant(self, v):
# if isinstance(v, gof.Constant):
# return v.data
# if v.owner and isinstance(v.owner.op, DimShuffle):
# return self.get_constant(v.owner.inputs[0])
# return None
# def simplify(self, num, denum):
# numct, denumct = [], []
# ncc, dcc = 0, 0
# for v in list(num):
# if v in denum:
# num.remove(v)
# denum.remove(v)
# continue
# ct = self.get_constant(v)
# if ct is not None:
# ncc += 1
# num.remove(v)
# numct.append(ct)
# for v in list(denum):
# ct = self.get_constant(v)
# if ct is not None:
# dcc += 1
# denum.remove(v)
# denumct.append(ct)
# ct = self.simplify_constants(numct, denumct)
# if ct is None:
# return ncc+dcc>0, None, num, denum
# ctop = self.constant_op.get(ct)
# if ctop is not None:
# return True, ctop, num, denum
# return not (ncc==1 and dcc==0), None, [ct]+num, denum
# def transform(self, node):
# op = node.op
# inputs = node.inputs
# if op == self.main:
# num, denum = self.get_num_denum(inputs)
# elif op == self.inverse:
# assert len(inputs) == 2
# n1, d1 = self.get_num_denum(inputs[:1])
# n2, d2 = self.get_num_denum(inputs[1:])
# num, denum = n1+d2, d1+n2
# elif op == self.reciprocal:
# denum, num = self.get_num_denum(inputs)
# else:
# return False
# change, ctop, num2, denum2 = self.simplify(num, denum)
# if change:
# num, denum = num2, denum2
# # print node, ct, num, denum
# # ctop = ct != [] and self.constant_op.get(ct[0], None)
# # if not ctop:
# # num = ct + num
# new = self.merge_num_denum(num, denum, node.outputs[0].type)
# if ctop:
# new = ctop(new)
# print new.owner.op, op, new.owner.inputs, inputs
# if new.owner and new.owner.op == op and all((new_input.owner new.owner.inputs == inputs:
# return False
# return [new]
# @gof.local_optimizer
# def local_cut_middlemen(node):
# op = node.op
# if isinstance(op, Elemwise):
# aaaaaaa
# # @gof.local_optimizer
# # def local_merge_mul(node):
# # op = node.op
# # if op != mul:
# # return False
# # num, denum = _get_num_denum(node.inputs)
# # if num == node.inputs and denum == []:
# # return False
# # return _
# class Lift(gof.LocalOptimizer):
# def __init__(self, op, lifters, chooser):
# self.op = op
# self.lifters = lifters
# self.chooser = chooser
# def op_key(self):
# return self.op
# def transform(self, node):
# if not node.op == self.op:
# return False
# candidates = [node.inputs[0]]
# seen = set(candidates)
# while True:
# candidate = candidates.pop()
# for lifter in self.lifters:
# new_candidates = lifter(candidate)
# if not new_candidates:
# break
# else:
# candidates.append(candidate)
# new_op = self.op(self.chooser(candidates))
# return new_op
# class Canonizer(gof.Optimizer):
# """
# Simplification tool.
# Usage: Canonizer(main, inverse, reciprocal, mainfn, invfn, recfn, transform)
# * main: a suitable Op class that is commutative, associative and takes
# one to an arbitrary number of inputs, e.g. Add or Mul
# * inverse: an Op class such that inverse(main(x, y), y) == x
# e.g. Sub or Div
# * reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
# e.g. Neg or Inv
# * mainfn, invfn, recfn: functions that behave just like the previous three
# Ops, but on true scalars (e.g. their impl)
# * transform: a function that maps (numerator, denominatur) where numerator
# and denominator are lists of Result instances, to new lists
# where further simplifications may have been applied.
# Examples:
# add_canonizer = Canonizer(Add, Sub, Neg, lambda *inputs: sum(inputs), ...)
# mul_canonizer = Canonizer(Mul, Div, Inv, lambda *inputs: product(inputs), ...)
# Examples of optimizations mul_canonizer can perform:
# x / x -> 1
# (x * y) / x -> y
# x / y / x -> 1 / y
# x / y / z -> x / (y * z)
# x / (y / z) -> (x * z) / y
# (a / b) * (b / c) * (c / d) -> a / d
# (2.0 * x) / (4.0 * y) -> (0.5 * x) / y
# 2 * x / 2 -> x
# """
# def __init__(self, main, inverse, reciprocal, mainfn, invfn, recfn, transform = None):
# self.main = main
# self.inverse = inverse
# self.reciprocal = reciprocal
# self.mainfn = mainfn
# self.invfn = invfn
# self.recfn = recfn
# self.neutral = mainfn()
# self.transform = transform
# def apply(self, env):
# def edge(r):
# return r.owner is None
# def follow(r):
# return None if r.owner is None else r.owner.inputs
# def canonize(r):
# next = follow(r)
# if next is None:
# return
# def flatten(r, nclients_check = True):
# # Collapses a tree of main/inverse/reciprocal Ops (aka Mul/Div/Inv or Add/Sub/Neg)
# # into a list of numerators and a list of denominators
# # e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, (Inv, y)), Div(x, Div(z, a))) -> [x, x, a], [z, y]
# if edge(r):
# return [r], []
# node = r.owner
# op = node.op
# results = [r2.type == r.type and flatten(r2) or ([r2], []) for r2 in node.inputs]
# if op == self.main and (not nclients_check or env.nclients(r) == 1):
# nums = [x[0] for x in results]
# denums = [x[1] for x in results]
# elif op == self.inverse and (not nclients_check or env.nclients(r) == 1):
# # num, denum of the second argument are added to the denum, num respectively
# nums = [results[0][0], results[1][1]]
# denums = [results[0][1], results[1][0]]
# elif op == self.reciprocal and (not nclients_check or env.nclients(r) == 1):
# # num, denum of the sole argument are added to the denum, num respectively
# nums = [results[0][1]]
# denums = [results[0][0]]
# else:
# return [r], []
# return reduce(list.__add__, nums), reduce(list.__add__, denums)
# num, denum = flatten(r, False)
# if (num, denum) == ([r], []):
# for input in (follow(r) or []):
# canonize(input)
# return
# # Terms that are both in the num and denum lists cancel each other
# for d in list(denum):
# if d in list(num):
# # list.remove only removes the element once
# num.remove(d)
# denum.remove(d)
# # We identify the constants in num and denum
# numct, num = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, num)
# denumct, denum = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, denum)
# #print numct, num
# #print denumct, denum
# print num, denum
# # All constants in num and denum are combined into a single constant which we add to num (unless it's a neutral constant)
# v = self.invfn(self.mainfn(*[x.data for x in numct]), self.mainfn(*[x.data for x in denumct]))
# if v != self.neutral:
# num.insert(0, C(v))
# # We optimize the num and denum lists further if requested
# if self.transform is not None:
# num, denum = self.transform(env, num, denum)
# def make(factors):
# # Combines the factors using self.main (aka Mul) depending
# # on the number of elements.
# n = len(factors)
# if n == 0:
# return None
# elif n == 1:
# return factors[0]
# else:
# return self.main(*factors)
# numr, denumr = make(num), make(denum)
# if numr is None:
# if denumr is None:
# # Everything cancelled each other so we're left with
# # the neutral element.
# new_r = gof.Constant(r.type, self.neutral)
# else:
# # There's no numerator so we use reciprocal
# new_r = self.reciprocal(denumr)
# else:
# if denumr is None:
# new_r = numr
# else:
# new_r = self.inverse(numr, denumr)
# # Hopefully this won't complain!
# env.replace(r, new_r)
# for factor in num + denum:
# canonize(factor)
# for output in env.outputs:
# canonize(output)
_mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs) # _mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
_divfn = lambda x, y: x / y # _divfn = lambda x, y: x / y
_invfn = lambda x: 1 / x # _invfn = lambda x: 1 / x
mul_canonizer = Canonizer(T.mul, T.div, T.inv, _mulfn, _divfn, _invfn) # mul_canonizer = Canonizer(T.mul, T.div, T.inv, _mulfn, _divfn, _invfn)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论