提交 11be7be2 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

optimizer fiesta

上级 deff95dc
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
import unittest import unittest
import gof from theano import gof
from tensor_opt import * from theano.tensor_opt import *
import tensor from theano import tensor
from tensor import Tensor from theano.tensor import Tensor
from gof import Env from theano.gof import Env
from elemwise import DimShuffle from theano.elemwise import DimShuffle
import numpy import numpy
#import scalar_opt #import scalar_opt
...@@ -68,7 +68,7 @@ class _test_dimshuffle_lift(unittest.TestCase): ...@@ -68,7 +68,7 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(x, (1, 0)), (1, 0)) e = ds(ds(x, (1, 0)), (1, 0))
g = Env([x], [e]) g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{1,0}(x))]") self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{1,0}(x))]")
lift_dimshuffle.optimize(g) dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[x]") self.failUnless(str(g) == "[x]")
def test_merge2(self): def test_merge2(self):
...@@ -76,7 +76,7 @@ class _test_dimshuffle_lift(unittest.TestCase): ...@@ -76,7 +76,7 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1)) e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
g = Env([x], [e]) g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]", str(g)) self.failUnless(str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]", str(g))
lift_dimshuffle.optimize(g) dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g)) self.failUnless(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g))
def test_elim3(self): def test_elim3(self):
...@@ -84,7 +84,7 @@ class _test_dimshuffle_lift(unittest.TestCase): ...@@ -84,7 +84,7 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0)) e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
g = Env([x], [e]) g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}(DimShuffle{0,x,1}(x)))]", str(g)) self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}(DimShuffle{0,x,1}(x)))]", str(g))
lift_dimshuffle.optimize(g) dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[x]", str(g)) self.failUnless(str(g) == "[x]", str(g))
def test_lift(self): def test_lift(self):
...@@ -92,10 +92,134 @@ class _test_dimshuffle_lift(unittest.TestCase): ...@@ -92,10 +92,134 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = x + y + z e = x + y + z
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
self.failUnless(str(g) == "[add(InplaceDimShuffle{x,0,1}(add(InplaceDimShuffle{x,0}(x), y)), z)]", str(g)) self.failUnless(str(g) == "[add(InplaceDimShuffle{x,0,1}(add(InplaceDimShuffle{x,0}(x), y)), z)]", str(g))
lift_dimshuffle.optimize(g) dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g)) self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g))
from theano.tensor import *
from theano.sandbox import pprint
class _test_canonize(unittest.TestCase):
def test_muldiv(self):
x, y, z = matrices('xyz')
a, b, c, d = matrices('abcd')
# e = (2.0 * x) / (2.0 * y)
# e = (2.0 * x) / (4.0 * y)
# e = x / (y / z)
# e = (x * y) / x
# e = (x / y) * (y / z) * (z / x)
# e = (a / b) * (b / c) * (c / d)
# e = (a * b) / (b * c) / (c * d)
# e = 2 * x / 2
# e = x / y / x
e = (x / x) * (y / y)
g = Env([x, y, z, a, b, c, d], [e])
print pprint.pp.process(g.outputs[0])
mul_canonizer.optimize(g)
gof.TopoOptimizer(gof.LocalOptGroup(local_fill_cut, local_fill_lift), order = 'out_to_in').optimize(g)
print pprint.pp.process(g.outputs[0])
# def test_plusmin(self):
# x, y, z = inputs()
# a, b, c, d = more_inputs()
# # e = x - x
# # e = (2.0 + x) - (2.0 + y)
# # e = (2.0 + x) - (4.0 + y)
# # e = x - (y - z)
# # e = (x + y) - x
# # e = (x - y) + (y - z) + (z - x)
# # e = (a - b) + (b - c) + (c - d)
# # e = x + -y
# # e = a - b - b + a + b + c + b - c
# # e = x + log(y) - x + y
# e = 2.0 + x + 4.0
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# gof.ConstantFinder().optimize(g)
# addfn = lambda *inputs: sum(inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
# print g
# def test_both(self):
# x, y, z = inputs()
# a, b, c, d = more_inputs()
# e0 = (x * y / x)
# e = e0 + e0 - e0
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# gof.ConstantFinder().optimize(g)
# mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# divfn = lambda x, y: x / y
# invfn = lambda x: 1 / x
# Canonizer(Mul, Div, Inv, mulfn, divfn, invfn).optimize(g)
# addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
# print g
# def test_group_powers(self):
# x, y, z, a, b, c, d = floats('xyzabcd')
###################
# c1, c2 = constant(1.), constant(2.)
# #e = pow(x, c1) * pow(x, y) / pow(x, 7.0) # <-- fucked
# #f = -- moving from div(mul.out, pow.out) to pow(x, sub.out)
# e = div(mul(pow(x, 2.0), pow(x, y)), pow(x, 7.0))
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# print g.inputs, g.outputs, g.orphans
# f = sub(add(2.0, y), add(7.0))
# g.replace(e, pow(x, f))
# print g
# print g.inputs, g.outputs, g.orphans
# g.replace(f, sub(add(2.0, y), add(7.0))) # -- moving from sub(add.out, add.out) to sub(add.out, add.out)
# print g
# print g.inputs, g.outputs, g.orphans
###################
# # e = x * exp(y) * exp(z)
# # e = x * pow(x, y) * pow(x, z)
# # e = pow(x, y) / pow(x, z)
# e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) # <-- fucked
# # e = pow(x - x, y)
# # e = pow(x, 2.0 + y - 7.0)
# # e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) / pow(x, z)
# # e = pow(x, 2.0 + y - 7.0 - z)
# # e = x ** y / x ** y
# # e = x ** y / x ** (y - 1.0)
# # e = exp(x) * a * exp(y) / exp(z)
# g = Env([x, y, z, a, b, c, d], [e])
# g.extend(gof.PrintListener(g))
# print g, g.orphans
# mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# divfn = lambda x, y: x / y
# invfn = lambda x: 1 / x
# Canonizer(mul, div, inv, mulfn, divfn, invfn, group_powers).optimize(g)
# print g, g.orphans
# addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(add, sub, neg, addfn, subfn, negfn).optimize(g)
# print g, g.orphans
# pow2one_float.optimize(g)
# pow2x_float.optimize(g)
# print g, g.orphans
# class _test_cliques(unittest.TestCase): # class _test_cliques(unittest.TestCase):
# def test_straightforward(self): # def test_straightforward(self):
......
...@@ -237,6 +237,7 @@ class Elemwise(Op): ...@@ -237,6 +237,7 @@ class Elemwise(Op):
is left-completed to the greatest number of dimensions with 1s is left-completed to the greatest number of dimensions with 1s
using DimShuffle. using DimShuffle.
""" """
inputs = map(as_tensor, inputs) inputs = map(as_tensor, inputs)
shadow = self.scalar_op.make_node(*[Scalar(dtype = t.type.dtype)() for t in inputs]) shadow = self.scalar_op.make_node(*[Scalar(dtype = t.type.dtype)() for t in inputs])
...@@ -303,11 +304,10 @@ class Elemwise(Op): ...@@ -303,11 +304,10 @@ class Elemwise(Op):
if node is None: if node is None:
# the gradient contains a constant, translate it as # the gradient contains a constant, translate it as
# an equivalent Tensor of size 1 and proper number of dimensions # an equivalent Tensor of size 1 and proper number of dimensions
b = [1] * nd
res = TensorConstant(Tensor(dtype = r.type.dtype, res = TensorConstant(Tensor(dtype = r.type.dtype,
broadcastable = b), broadcastable = ()),
numpy.asarray(r.data).reshape(b)) numpy.asarray(r.data)) # .reshape(b)
return res return DimShuffle((), ['x']*nd, inplace = True)(res)
new_r = Elemwise(node.op, {})(*[transform(input) for input in node.inputs]) new_r = Elemwise(node.op, {})(*[transform(input) for input in node.inputs])
return new_r return new_r
ret = [] ret = []
......
...@@ -18,9 +18,9 @@ from op import \ ...@@ -18,9 +18,9 @@ from op import \
Op Op
from opt import \ from opt import \
Optimizer, SeqOptimizer, \ Optimizer, optimizer, SeqOptimizer, \
MergeOptimizer, MergeOptMerge, \ MergeOptimizer, MergeOptMerge, \
LocalOptimizer, LocalOptGroup, LocalOpKeyOptGroup, \ LocalOptimizer, local_optimizer, LocalOptGroup, LocalOpKeyOptGroup, \
OpSub, OpRemove, PatternSub, \ OpSub, OpRemove, PatternSub, \
NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer
......
...@@ -382,17 +382,17 @@ class Env(utils.object2): ...@@ -382,17 +382,17 @@ class Env(utils.object2):
"Same as len(self.clients(r))." "Same as len(self.clients(r))."
return len(self.clients(r)) return len(self.clients(r))
def edge(self, r): # def edge(self, r):
return r in self.inputs or r in self.orphans # return r in self.inputs or r in self.orphans
def follow(self, r): # def follow(self, r):
node = r.owner # node = r.owner
if self.edge(r): # if self.edge(r):
return None # return None
else: # else:
if node is None: # if node is None:
raise Exception("what the fuck") # raise Exception("what the fuck")
return node.inputs # return node.inputs
def check_integrity(self): def check_integrity(self):
""" """
......
...@@ -56,6 +56,16 @@ class Optimizer: ...@@ -56,6 +56,16 @@ class Optimizer:
pass pass
class FromFunctionOptimizer(Optimizer):
def __init__(self, fn):
self.apply = fn
def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate)
def optimizer(f):
return FromFunctionOptimizer(f)
class SeqOptimizer(Optimizer, list): class SeqOptimizer(Optimizer, list):
""" """
...@@ -137,6 +147,7 @@ class MergeOptimizer(Optimizer): ...@@ -137,6 +147,7 @@ class MergeOptimizer(Optimizer):
sig = r.signature() sig = r.signature()
other_r = inv_cid.get(sig, None) other_r = inv_cid.get(sig, None)
if other_r is not None: if other_r is not None:
if r.name: other_r.name = r.name
env.replace_validate(r, other_r) env.replace_validate(r, other_r)
else: else:
cid[r] = sig cid[r] = sig
...@@ -155,8 +166,12 @@ class MergeOptimizer(Optimizer): ...@@ -155,8 +166,12 @@ class MergeOptimizer(Optimizer):
success = False success = False
if dup is not None: if dup is not None:
success = True success = True
pairs = zip(node.outputs, dup.outputs)
for output, new_output in pairs:
if output.name and not new_output.name:
new_output.name = output.name
try: try:
env.replace_all_validate(zip(node.outputs, dup.outputs)) env.replace_all_validate(pairs)
except InconsistencyError, e: except InconsistencyError, e:
success = False success = False
if not success: if not success:
...@@ -189,17 +204,27 @@ class LocalOptimizer(utils.object2): ...@@ -189,17 +204,27 @@ class LocalOptimizer(utils.object2):
raise utils.AbstractFunctionError() raise utils.AbstractFunctionError()
class FromFunctionLocalOptimizer(LocalOptimizer):
def __init__(self, fn):
self.transform = fn
def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate)
def local_optimizer(f):
return FromFunctionLocalOptimizer(f)
class LocalOptGroup(LocalOptimizer): class LocalOptGroup(LocalOptimizer):
def __init__(self, optimizers): def __init__(self, *optimizers):
self.opts = optimizers self.opts = optimizers
self.reentrant = any(getattr(opt, 'reentrant', True), optimizers) self.reentrant = any(getattr(opt, 'reentrant', True) for opt in optimizers)
self.retains_inputs = all(getattr(opt, 'retains_inputs', False), optimizers) self.retains_inputs = all(getattr(opt, 'retains_inputs', False) for opt in optimizers)
def transform(self, node): def transform(self, node):
for opt in self.opts: for opt in self.opts:
repl = opt.transform(node) repl = opt.transform(node)
if repl is not False: if repl:
return repl return repl
...@@ -547,3 +572,43 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -547,3 +572,43 @@ class OpKeyOptimizer(NavigatorOptimizer):
def keep_going(exc, nav, repl_pairs): def keep_going(exc, nav, repl_pairs):
pass pass
#################
### Utilities ###
#################
def _check_chain(r, chain):
chain = list(reversed(chain))
while chain:
elem = chain.pop()
if elem is None:
if not r.owner is None:
return False
elif r.owner is None:
return False
elif isinstance(elem, op.Op):
if not r.owner.op == elem:
return False
else:
try:
if issubclass(elem, op.Op) and not isinstance(r.owner.op, elem):
return False
except TypeError:
return False
if chain:
r = r.owner.inputs[chain.pop()]
return r
def check_chain(r, *chain):
if isinstance(r, graph.Apply):
r = r.outputs[0]
return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain)))
import gof import gof
from gof import opt
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
import scalar import scalar
import tensor as T import tensor as T
import numpy as N
import operator
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c) # gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0) # Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
...@@ -26,7 +29,8 @@ dot_to_gemm = gof.PatternSub((T.dot, 'a', 'b'), ...@@ -26,7 +29,8 @@ dot_to_gemm = gof.PatternSub((T.dot, 'a', 'b'),
allow_multiple_clients = False) allow_multiple_clients = False)
class InplaceOptimizer(gof.Optimizer): @gof.optimizer
def inplace_optimizer(self, env):
""" """
Usage: inplace_optimizer.optimize(env) Usage: inplace_optimizer.optimize(env)
...@@ -40,39 +44,34 @@ class InplaceOptimizer(gof.Optimizer): ...@@ -40,39 +44,34 @@ class InplaceOptimizer(gof.Optimizer):
x + y + z -> x += y += z x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y) (x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
""" """
for node in list(env.nodes):
def apply(self, env): op = node.op
for node in list(env.nodes): if not isinstance(op, Elemwise):
op = node.op continue
if not isinstance(op, Elemwise): baseline = op.inplace_pattern
continue candidate_outputs = [i for i in xrange(len(node.outputs)) if i not in baseline]
baseline = op.inplace_pattern candidate_inputs = [i for i in xrange(len(node.inputs)) if i not in baseline.values()]
candidate_outputs = [i for i in xrange(len(node.outputs)) if i not in baseline] for candidate_output in candidate_outputs:
candidate_inputs = [i for i in xrange(len(node.inputs)) if i not in baseline.values()] for candidate_input in candidate_inputs:
for candidate_output in candidate_outputs: inplace_pattern = dict(baseline, **{candidate_output: candidate_input})
for candidate_input in candidate_inputs: try:
inplace_pattern = dict(baseline, **{candidate_output: candidate_input}) new = Elemwise(op.scalar_op, inplace_pattern).make_node(op.inputs)
try: env.replace_all_validate(dict(zip(node.outputs, new.outputs)))
new = Elemwise(op.scalar_op, inplace_pattern).make_node(op.inputs) except:
env.replace_all_validate(dict(zip(node.outputs, new.outputs))) continue
except: candidate_inputs.remove(candidate_input)
continue node = new
candidate_inputs.remove(candidate_input) baseline = inplace_pattern
node = new break
baseline = inplace_pattern
break
######################
def add_requirements(self, env): # DimShuffle lifters #
env.extend(gof.toolbox.ReplaceValidate) ######################
inplace_optimizer = InplaceOptimizer() @gof.local_optimizer
def local_dimshuffle_lift(node):
class DimShuffleLifter(gof.LocalOptimizer):
""" """
Usage: lift_dimshuffle.optimize(env)
"Lifts" DimShuffle through Elemwise operations and merges "Lifts" DimShuffle through Elemwise operations and merges
consecutive DimShuffles. Basically, applies the following consecutive DimShuffles. Basically, applies the following
transformations on the whole graph: transformations on the whole graph:
...@@ -83,188 +82,824 @@ class DimShuffleLifter(gof.LocalOptimizer): ...@@ -83,188 +82,824 @@ class DimShuffleLifter(gof.LocalOptimizer):
After this transform, clusters of Elemwise operations are After this transform, clusters of Elemwise operations are
void of DimShuffle operations. void of DimShuffle operations.
""" """
op = node.op
if not isinstance(op, DimShuffle):
return False
input = node.inputs[0]
inode = input.owner
if inode and isinstance(inode.op, Elemwise):
return inode.op.make_node(*[DimShuffle(input.type.broadcastable,
op.new_order,
op.inplace)(input) for input in inode.inputs]).outputs
if inode and isinstance(inode.op, DimShuffle):
new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in op.new_order]
inplace = op.inplace and inode.op.inplace
iinput = inode.inputs[0]
if new_order == range(len(new_order)):
return [iinput]
else:
return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs
dimshuffle_lift = gof.TopoOptimizer(local_dimshuffle_lift, order = 'out_to_in')
#################
# Shape lifters #
#################
@gof.local_optimizer
def local_shape_lift_elemwise(node):
"""
shape(elemwise_op(..., x, ...)) -> shape(x)
def transform(self, node): Where x contains the maximal shape information.
op = node.op """
if not isinstance(op, DimShuffle): if not opt.check_chain(node, T.shape, T.Elemwise):
return False return False
output = node.inputs[0]
parent = output.owner
input = node.inputs[0] for input in parent.inputs:
inode = input.owner if input.type.broadcastable == output.type.broadcastable:
if inode and isinstance(inode.op, Elemwise): return T.shape.make_node(input).outputs
return inode.op.make_node(*[DimShuffle(input.type.broadcastable,
op.new_order, return False
op.inplace)(input) for input in inode.inputs]).outputs
if inode and isinstance(inode.op, DimShuffle):
new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in op.new_order]
inplace = op.inplace and inode.op.inplace
iinput = inode.inputs[0]
if new_order == range(len(new_order)):
return [iinput]
else:
return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs
lift_dimshuffle = gof.TopoOptimizer(DimShuffleLifter(), order = 'out_to_in') @gof.local_optimizer
def local_shape_lift_sum(node):
"""
shape(sum{n}(x)) -> [shape(x)[0], ..., shape(x)[n-1], shape(x)[n+1], ...]
"""
if not opt.check_chain(node, T.shape, T.Sum):
return False
input = node.inputs[0].owner.inputs[0]
axis = node.inputs[0].owner.op.axis
if axis is None:# or len(axis) != 1:
axis = range(input.type.ndim)
ish = T.shape(input)
return T.make_lvector.make_node(*(ish[i] for i in xrange(input.type.ndim) if i not in axis)).outputs
# return T.vertical_stack.make_node(ish[:axis], ish[axis+1:]).outputs
class Canonizer(gof.Optimizer): @gof.local_optimizer
def local_shape_lift_dot(node):
""" """
Simplification tool. shape(dot(a, b)) -> [shape(a)[0], shape(b)[1]]
"""
if not opt.check_chain(node, T.shape, T.dot):
return False
a, b = node.inputs[0].owner.inputs
return T.make_lvector.make_node(T.shape(a)[0], T.shape(b)[1]).outputs
Usage: Canonizer(main, inverse, reciprocal, mainfn, invfn, recfn, transform) local_shape_lift = opt.LocalOptGroup(local_shape_lift_elemwise,
local_shape_lift_sum,
* main: a suitable Op class that is commutative, associative and takes local_shape_lift_dot)
one to an arbitrary number of inputs, e.g. Add or Mul
* inverse: an Op class such that inverse(main(x, y), y) == x
e.g. Sub or Div
* reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
e.g. Neg or Inv
* mainfn, invfn, recfn: functions that behave just like the previous three
Ops, but on true scalars (e.g. their impl)
* transform: a function that maps (numerator, denominatur) where numerator ################
and denominator are lists of Result instances, to new lists # Fill lifters #
where further simplifications may have been applied. ################
Examples: def encompasses_broadcastable(b1, b2):
add_canonizer = Canonizer(Add, Sub, Neg, lambda *inputs: sum(inputs), ...) if len(b1) < len(b2):
mul_canonizer = Canonizer(Mul, Div, Inv, lambda *inputs: product(inputs), ...) return False
b1 = b1[-len(b2):]
Examples of optimizations mul_canonizer can perform: return not any(v1 and not v2 for v1, v2 in zip(b1, b2))
x / x -> 1
(x * y) / x -> y def merge_broadcastables(broadcastables):
x / y / x -> 1 / y return [all(bcast) for bcast in zip(*broadcastables)]
x / y / z -> x / (y * z)
x / (y / z) -> (x * z) / y @gof.local_optimizer
(a / b) * (b / c) * (c / d) -> a / d def local_fill_lift(node):
(2.0 * x) / (4.0 * y) -> (0.5 * x) / y
2 * x / 2 -> x
""" """
fill(f(a), b) -> fill(a, b)
If a.type == f(a).type.
fill(a, b) -> b
If a.type == b.type.
"""
if not opt.check_chain(node, T.fill):
return False
model, filling = node.inputs
def __init__(self, main, inverse, reciprocal, mainfn, invfn, recfn, transform = None): mb, fb = model.type.broadcastable, filling.type.broadcastable
if model.type.dtype == filling.type.dtype and encompasses_broadcastable(fb, mb):
return [filling]
parent = model.owner
if parent is None:
return False
for input in parent.inputs:
if input.type == model.type:
return [T.fill(input, filling)]
return False
##################
# Subtensor opts #
##################
@gof.local_optimizer
def local_subtensor_make_vector(node):
"""
[a,b,c][0] -> a
[a,b,c][0:2] -> [a,b]
If the index or slice is constant.
"""
if not opt.check_chain(node, T.Subtensor, T.MakeVector):
return False
idxlist = node.op.idx_list
if len(idxlist) != 1:
return False
idx = idxlist[0]
if isinstance(idx, int):
return [node.inputs[0].owner.inputs[idx]]
try:
return T.make_vector(*(node.owner.inputs[0].owner.inputs.__getslice__(idx)))
except TypeError:
return False
##################
# Middleman cuts #
##################
@gof.local_optimizer
def local_fill_cut(node):
"""
f(fill(a,b), c) -> f(b, c)
If c.type == a.type.
"""
if not opt.check_chain(node, T.Elemwise):
return False
output = node.outputs[0]
try:
reference = [input
for input in node.inputs
if input.type == output.type and (not input.owner or input.owner.op != T.fill)][0]
except IndexError:
return False
new_inputs = []
for input in node.inputs:
if opt.check_chain(input, T.fill):
model, filling = input.owner.inputs
if encompasses_broadcastable(reference.type.broadcastable,
filling.type.broadcastable):
new_inputs.append(filling)
continue
new_inputs.append(input)
if new_inputs == node.inputs:
return False
return node.op.make_node(*new_inputs).outputs
@gof.local_optimizer
def local_fill_sink(node):
"""
f(fill(a, b), fill(c, d), e) -> fill(a, fill(c, f(b, d, e)))
"""
if not (node.op and isinstance(node.op, T.Elemwise) and node.op != T.fill):
return False
models = []
inputs = []
for input in node.inputs:
if input.owner and input.owner.op == T.fill:
models.append(input.owner.inputs[0])
inputs.append(input.owner.inputs[1])
else:
inputs.append(input)
if inputs == node.inputs:
return False
c = node.op(*inputs)
for model in models:
c = T.fill(model, c)
return [c]
################
# Canonization #
################
class Canonizer(gof.LocalOptimizer):
def __init__(self, main, inverse, reciprocal, calculate):
self.main = main self.main = main
self.inverse = inverse self.inverse = inverse
self.reciprocal = reciprocal self.reciprocal = reciprocal
self.mainfn = mainfn self.calculate = calculate
self.invfn = invfn
self.recfn = recfn def get_num_denum(self, input):
self.neutral = mainfn() if input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]:
self.transform = transform return [input], []
num = []
denum = []
parent = input.owner
pairs = [self.get_num_denum(input) for input in parent.inputs]
if parent.op == self.main:
num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs))
denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs))
elif parent.op == self.inverse:
num = pairs[0][0] + pairs[1][1]
denum = pairs[0][1] + pairs[1][0]
elif parent.op == self.reciprocal:
num = pairs[0][1]
denum = pairs[0][0]
return num, denum
def merge_num_denum(self, num, denum):
ln, ld = len(num), len(denum)
if not ln and not ld:
return T.as_tensor(self.calculate([], []))
if not ln:
return self.reciprocal(self.merge_num_denum(denum, []))
if not ld:
if ln == 1:
return num[0]
else:
return self.main(*num)
return self.inverse(self.merge_num_denum(num, []),
self.merge_num_denum(denum, []))
def get_constant(self, v):
if isinstance(v, gof.Constant):
return v.data
if v.owner and isinstance(v.owner.op, DimShuffle):
return self.get_constant(v.owner.inputs[0])
return None
def simplify(self, num, denum):
return self.simplify_constants(*self.simplify_factors(num, denum))
def simplify_factors(self, num, denum):
for v in list(num):
if v in denum:
num.remove(v)
denum.remove(v)
return num, denum
def simplify_constants(self, orig_num, orig_denum):
num, denum = list(orig_num), list(orig_denum)
numct, denumct = [], []
ncc, dcc = 0, 0
for v in orig_num:
ct = self.get_constant(v)
if ct is not None:
ncc += 1
num.remove(v)
numct.append(ct)
for v in orig_denum:
ct = self.get_constant(v)
if ct is not None:
dcc += 1
denum.remove(v)
denumct.append(ct)
ct = self.calculate(numct, denumct, aslist = True)
if len(ct) and ncc == 1 and dcc == 0:
return orig_num, orig_denum
return ct + num, denum
def apply(self, env): def transform(self, node):
op = node.op
inputs = node.inputs
out = node.outputs[0]
if op not in [self.main, self.inverse, self.reciprocal]:
return False
iops = set(input.owner.op for input in inputs if input.owner)
reorg = False
if op == self.main:
reorg = len(iops.intersection([self.main, self.inverse, self.reciprocal])) != 0
elif op == self.inverse:
reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0
elif op == self.reciprocal:
reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0
orig_num, orig_denum = self.get_num_denum(node.outputs[0])
num, denum = list(orig_num), list(orig_denum)
num, denum = self.simplify(num, denum)
if not reorg and orig_num == num and orig_denum == denum:
return False
def edge(r): new = self.merge_num_denum(num, denum)
return r.owner is None if new.type != out.type:
def follow(r): new = T.fill(out, new)
return None if r.owner is None else r.owner.inputs return [new]
def mul_calculate(num, denum, aslist = False):
v = reduce(N.multiply, num, 1.0) / reduce(N.multiply, denum, 1.0)
if aslist:
if N.all(v == 1):
return []
else:
return [v]
return v
local_mul_canonizer = Canonizer(T.mul, T.div, T.inv, mul_calculate)
mul_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_mul_canonizer, local_fill_sink), order = 'in_to_out')
def add_calculate(num, denum, aslist = False):
v = reduce(N.add, num, 0.0) - reduce(N.add, denum, 0.0)
if aslist:
if N.all(v == 0):
return []
else:
return [v]
return v
local_add_canonizer = Canonizer(T.add, T.sub, T.neg, add_calculate)
add_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_add_canonizer, local_fill_sink), order = 'in_to_out')
##################
# Distributivity #
##################
def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore = 0):
score = len(num) + len(denum) # score is number of operations saved, higher is better
new_pos_pairs = itertools.starmap(local_mul_canonizer.simplify,
[(n+num, d+denum) for (n, d) in plus_pairs])
new_neg_pairs = itertools.starmap(local_mul_canonizer.simplify,
[(n+num, d+denum) for (n, d) in plus_pairs])
for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs + new_neg_pairs):
# We calculate how many operations we are saving with the new num and denum
score += len(n) + len(d) - len(nn) - len(dd)
if score < minscore:
return False, pos_pairs, neg_pairs
return True, new_pos_pairs, new_neg_pairs
@gof.local_optimizer
def local_greedy_distributor(node):
"""
The following expressions are simplified:
((a/x + b/y) * x * y) --> a*y + b*x
((a/x + b) * x) --> a + b*x
def canonize(r): The following expressions are not:
((a + b) * x) -X-> a*x + b*x
"""
out = node.outputs[0]
num, denum = local_mul_canonizer.get_num_denum(out)
if len(num) == 1 and not denum:
return False
new_num = []
for entry in num:
pos, neg = local_add_canonizer.get_num_denum(entry)
if len(pos) == 1 and not neg:
new_num.append(entry)
continue
pos_pairs = map(local_mul_canonizer.get_num_denum, pos)
neg_pairs = map(local_mul_canonizer.get_num_denum, neg)
next = follow(r)
if next is None:
return
# class Canonizer(gof.LocalOptimizer):
# def __init__(self, main, inverse, reciprocal, simplify_constants, constant_op):
# self.main = main
# self.inverse = inverse
# self.reciprocal = reciprocal
# self.simplify_constants = simplify_constants
# self.constant_op = constant_op
# def get_num_denum(self, input, depth):
# if depth == 0 or input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]:
# return [input], []
# num = []
# denum = []
# parent = input.owner
# pairs = [self.get_num_denum(input, depth - 1) for input in parent.inputs]
# if parent.op == self.main:
# num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs))
# denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs))
# elif parent.op == self.inverse:
# num = pairs[0][0] + pairs[1][1]
# denum = pairs[0][1] + pairs[1][0]
# elif parent.op == self.reciprocal:
# num = pairs[0][1]
# denum = pairs[0][0]
# return num, denum
# def deep_num_denum(self, node):
# op = node.op
# if op == self.main:
# num, denum = self.get_num_denum(inputs)
# elif op == self.inverse:
# assert len(inputs) == 2
# n1, d1 = self.get_num_denum(inputs[:1])
# n2, d2 = self.get_num_denum(inputs[1:])
# num, denum = n1+d2, d1+n2
# elif op == self.reciprocal:
# denum, num = self.get_num_denum(inputs)
# else:
# num, denum = [node.outputs[0]], []
# return num, denum
# def get_num_denum(self, inputs):
# num = []
# denum = []
# for input in inputs:
# if input.owner is None:
# num.append(input)
# continue
# parent = input.owner
# if parent.op == self.main:
# num += parent.inputs
# elif parent.op == self.inverse:
# num += parent.inputs[:1]
# denum += parent.inputs[1:]
# elif parent.op == self.reciprocal:
# denum += parent.inputs
# else:
# num.append(input)
# return num, denum
# def merge_num_denum(self, num, denum, outtype):
# ln, ld = len(num), len(denum)
# if not ln and not ld:
# return outtype.filter(self.simplify_constants([], []))
# if not ln:
# return self.reciprocal(self.merge_num_denum(denum, [], outtype))
# if not ld:
# if ln == 1:
# return num[0]
# else:
# return self.main(*num)
# return self.inverse(self.merge_num_denum(num, [], outtype),
# self.merge_num_denum(denum, [], outtype))
# def get_constant(self, v):
# if isinstance(v, gof.Constant):
# return v.data
# if v.owner and isinstance(v.owner.op, DimShuffle):
# return self.get_constant(v.owner.inputs[0])
# return None
# def simplify(self, num, denum):
# numct, denumct = [], []
# ncc, dcc = 0, 0
# for v in list(num):
# if v in denum:
# num.remove(v)
# denum.remove(v)
# continue
# ct = self.get_constant(v)
# if ct is not None:
# ncc += 1
# num.remove(v)
# numct.append(ct)
# for v in list(denum):
# ct = self.get_constant(v)
# if ct is not None:
# dcc += 1
# denum.remove(v)
# denumct.append(ct)
# ct = self.simplify_constants(numct, denumct)
# if ct is None:
# return ncc+dcc>0, None, num, denum
# ctop = self.constant_op.get(ct)
# if ctop is not None:
# return True, ctop, num, denum
# return not (ncc==1 and dcc==0), None, [ct]+num, denum
# def transform(self, node):
# op = node.op
# inputs = node.inputs
# if op == self.main:
# num, denum = self.get_num_denum(inputs)
# elif op == self.inverse:
# assert len(inputs) == 2
# n1, d1 = self.get_num_denum(inputs[:1])
# n2, d2 = self.get_num_denum(inputs[1:])
# num, denum = n1+d2, d1+n2
# elif op == self.reciprocal:
# denum, num = self.get_num_denum(inputs)
# else:
# return False
# change, ctop, num2, denum2 = self.simplify(num, denum)
# if change:
# num, denum = num2, denum2
# # print node, ct, num, denum
# # ctop = ct != [] and self.constant_op.get(ct[0], None)
# # if not ctop:
# # num = ct + num
# new = self.merge_num_denum(num, denum, node.outputs[0].type)
# if ctop:
# new = ctop(new)
# print new.owner.op, op, new.owner.inputs, inputs
# if new.owner and new.owner.op == op and all((new_input.owner new.owner.inputs == inputs:
# return False
# return [new]
# @gof.local_optimizer
# def local_cut_middlemen(node):
# op = node.op
# if isinstance(op, Elemwise):
# aaaaaaa
# # @gof.local_optimizer
# # def local_merge_mul(node):
# # op = node.op
# # if op != mul:
# # return False
# # num, denum = _get_num_denum(node.inputs)
# # if num == node.inputs and denum == []:
# # return False
# # return _
# class Lift(gof.LocalOptimizer):
# def __init__(self, op, lifters, chooser):
# self.op = op
# self.lifters = lifters
# self.chooser = chooser
# def op_key(self):
# return self.op
# def transform(self, node):
# if not node.op == self.op:
# return False
# candidates = [node.inputs[0]]
# seen = set(candidates)
# while True:
# candidate = candidates.pop()
# for lifter in self.lifters:
# new_candidates = lifter(candidate)
# if not new_candidates:
# break
# else:
# candidates.append(candidate)
# new_op = self.op(self.chooser(candidates))
# return new_op
# class Canonizer(gof.Optimizer):
# """
# Simplification tool.
# Usage: Canonizer(main, inverse, reciprocal, mainfn, invfn, recfn, transform)
# * main: a suitable Op class that is commutative, associative and takes
# one to an arbitrary number of inputs, e.g. Add or Mul
# * inverse: an Op class such that inverse(main(x, y), y) == x
# e.g. Sub or Div
# * reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
# e.g. Neg or Inv
# * mainfn, invfn, recfn: functions that behave just like the previous three
# Ops, but on true scalars (e.g. their impl)
# * transform: a function that maps (numerator, denominatur) where numerator
# and denominator are lists of Result instances, to new lists
# where further simplifications may have been applied.
# Examples:
# add_canonizer = Canonizer(Add, Sub, Neg, lambda *inputs: sum(inputs), ...)
# mul_canonizer = Canonizer(Mul, Div, Inv, lambda *inputs: product(inputs), ...)
# Examples of optimizations mul_canonizer can perform:
# x / x -> 1
# (x * y) / x -> y
# x / y / x -> 1 / y
# x / y / z -> x / (y * z)
# x / (y / z) -> (x * z) / y
# (a / b) * (b / c) * (c / d) -> a / d
# (2.0 * x) / (4.0 * y) -> (0.5 * x) / y
# 2 * x / 2 -> x
# """
# def __init__(self, main, inverse, reciprocal, mainfn, invfn, recfn, transform = None):
# self.main = main
# self.inverse = inverse
# self.reciprocal = reciprocal
# self.mainfn = mainfn
# self.invfn = invfn
# self.recfn = recfn
# self.neutral = mainfn()
# self.transform = transform
# def apply(self, env):
# def edge(r):
# return r.owner is None
# def follow(r):
# return None if r.owner is None else r.owner.inputs
# def canonize(r):
# next = follow(r)
# if next is None:
# return
def flatten(r, nclients_check = True): # def flatten(r, nclients_check = True):
# Collapses a tree of main/inverse/reciprocal Ops (aka Mul/Div/Inv or Add/Sub/Neg) # # Collapses a tree of main/inverse/reciprocal Ops (aka Mul/Div/Inv or Add/Sub/Neg)
# into a list of numerators and a list of denominators # # into a list of numerators and a list of denominators
# e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, (Inv, y)), Div(x, Div(z, a))) -> [x, x, a], [z, y] # # e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, (Inv, y)), Div(x, Div(z, a))) -> [x, x, a], [z, y]
if edge(r): # if edge(r):
return [r], [] # return [r], []
node = r.owner # node = r.owner
op = node.op # op = node.op
results = [r2.type == r.type and flatten(r2) or ([r2], []) for r2 in node.inputs] # results = [r2.type == r.type and flatten(r2) or ([r2], []) for r2 in node.inputs]
if op == self.main and (not nclients_check or env.nclients(r) == 1): # if op == self.main and (not nclients_check or env.nclients(r) == 1):
nums = [x[0] for x in results] # nums = [x[0] for x in results]
denums = [x[1] for x in results] # denums = [x[1] for x in results]
elif op == self.inverse and (not nclients_check or env.nclients(r) == 1): # elif op == self.inverse and (not nclients_check or env.nclients(r) == 1):
# num, denum of the second argument are added to the denum, num respectively # # num, denum of the second argument are added to the denum, num respectively
nums = [results[0][0], results[1][1]] # nums = [results[0][0], results[1][1]]
denums = [results[0][1], results[1][0]] # denums = [results[0][1], results[1][0]]
elif op == self.reciprocal and (not nclients_check or env.nclients(r) == 1): # elif op == self.reciprocal and (not nclients_check or env.nclients(r) == 1):
# num, denum of the sole argument are added to the denum, num respectively # # num, denum of the sole argument are added to the denum, num respectively
nums = [results[0][1]] # nums = [results[0][1]]
denums = [results[0][0]] # denums = [results[0][0]]
else: # else:
return [r], [] # return [r], []
return reduce(list.__add__, nums), reduce(list.__add__, denums) # return reduce(list.__add__, nums), reduce(list.__add__, denums)
num, denum = flatten(r, False) # num, denum = flatten(r, False)
if (num, denum) == ([r], []): # if (num, denum) == ([r], []):
for input in (follow(r) or []): # for input in (follow(r) or []):
canonize(input) # canonize(input)
return # return
# Terms that are both in the num and denum lists cancel each other # # Terms that are both in the num and denum lists cancel each other
for d in list(denum): # for d in list(denum):
if d in list(num): # if d in list(num):
# list.remove only removes the element once # # list.remove only removes the element once
num.remove(d) # num.remove(d)
denum.remove(d) # denum.remove(d)
# We identify the constants in num and denum # # We identify the constants in num and denum
numct, num = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, num) # numct, num = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, num)
denumct, denum = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, denum) # denumct, denum = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, denum)
#print numct, num # #print numct, num
#print denumct, denum # #print denumct, denum
print num, denum # print num, denum
# All constants in num and denum are combined into a single constant which we add to num (unless it's a neutral constant) # # All constants in num and denum are combined into a single constant which we add to num (unless it's a neutral constant)
v = self.invfn(self.mainfn(*[x.data for x in numct]), self.mainfn(*[x.data for x in denumct])) # v = self.invfn(self.mainfn(*[x.data for x in numct]), self.mainfn(*[x.data for x in denumct]))
if v != self.neutral: # if v != self.neutral:
num.insert(0, C(v)) # num.insert(0, C(v))
# We optimize the num and denum lists further if requested # # We optimize the num and denum lists further if requested
if self.transform is not None: # if self.transform is not None:
num, denum = self.transform(env, num, denum) # num, denum = self.transform(env, num, denum)
def make(factors): # def make(factors):
# Combines the factors using self.main (aka Mul) depending # # Combines the factors using self.main (aka Mul) depending
# on the number of elements. # # on the number of elements.
n = len(factors) # n = len(factors)
if n == 0: # if n == 0:
return None # return None
elif n == 1: # elif n == 1:
return factors[0] # return factors[0]
else: # else:
return self.main(*factors) # return self.main(*factors)
numr, denumr = make(num), make(denum) # numr, denumr = make(num), make(denum)
if numr is None: # if numr is None:
if denumr is None: # if denumr is None:
# Everything cancelled each other so we're left with # # Everything cancelled each other so we're left with
# the neutral element. # # the neutral element.
new_r = gof.Constant(r.type, self.neutral) # new_r = gof.Constant(r.type, self.neutral)
else: # else:
# There's no numerator so we use reciprocal # # There's no numerator so we use reciprocal
new_r = self.reciprocal(denumr) # new_r = self.reciprocal(denumr)
else: # else:
if denumr is None: # if denumr is None:
new_r = numr # new_r = numr
else: # else:
new_r = self.inverse(numr, denumr) # new_r = self.inverse(numr, denumr)
# Hopefully this won't complain! # # Hopefully this won't complain!
env.replace(r, new_r) # env.replace(r, new_r)
for factor in num + denum: # for factor in num + denum:
canonize(factor) # canonize(factor)
for output in env.outputs: # for output in env.outputs:
canonize(output) # canonize(output)
_mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs) # _mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
_divfn = lambda x, y: x / y # _divfn = lambda x, y: x / y
_invfn = lambda x: 1 / x # _invfn = lambda x: 1 / x
mul_canonizer = Canonizer(T.mul, T.div, T.inv, _mulfn, _divfn, _invfn) # mul_canonizer = Canonizer(T.mul, T.div, T.inv, _mulfn, _divfn, _invfn)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论