提交 11be7be2 作者: Olivier Breuleux

optimizer fiesta

上级 deff95dc
......@@ -3,12 +3,12 @@
import unittest
import gof
from tensor_opt import *
import tensor
from tensor import Tensor
from gof import Env
from elemwise import DimShuffle
from theano import gof
from theano.tensor_opt import *
from theano import tensor
from theano.tensor import Tensor
from theano.gof import Env
from theano.elemwise import DimShuffle
import numpy
#import scalar_opt
......@@ -68,7 +68,7 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(x, (1, 0)), (1, 0))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{1,0}(x))]")
lift_dimshuffle.optimize(g)
dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[x]")
def test_merge2(self):
......@@ -76,7 +76,7 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]", str(g))
lift_dimshuffle.optimize(g)
dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g))
def test_elim3(self):
......@@ -84,7 +84,7 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}(DimShuffle{0,x,1}(x)))]", str(g))
lift_dimshuffle.optimize(g)
dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[x]", str(g))
def test_lift(self):
......@@ -92,10 +92,134 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = x + y + z
g = Env([x, y, z], [e])
self.failUnless(str(g) == "[add(InplaceDimShuffle{x,0,1}(add(InplaceDimShuffle{x,0}(x), y)), z)]", str(g))
lift_dimshuffle.optimize(g)
dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g))
from theano.tensor import *
from theano.sandbox import pprint
class _test_canonize(unittest.TestCase):
    """Exploratory checks for the multiplication/division canonizer.

    The single active test makes no assertions: it prints the pretty-printed
    output expression before and after the optimizers run, so the result has
    to be inspected by eye.
    """

    def test_muldiv(self):
        # Seven matrix variables; only x and y appear in the active
        # expression, but all seven are declared as inputs of the Env.
        x, y, z = matrices('xyz')
        a, b, c, d = matrices('abcd')
        # Alternative expressions that were tried; swap one in for `e` below
        # to look at how it gets rewritten.
        # e = (2.0 * x) / (2.0 * y)
        # e = (2.0 * x) / (4.0 * y)
        # e = x / (y / z)
        # e = (x * y) / x
        # e = (x / y) * (y / z) * (z / x)
        # e = (a / b) * (b / c) * (c / d)
        # e = (a * b) / (b * c) / (c * d)
        # e = 2 * x / 2
        # e = x / y / x
        e = (x / x) * (y / y)
        g = Env([x, y, z, a, b, c, d], [e])
        # Graph output before any optimization.
        print pprint.pp.process(g.outputs[0])
        mul_canonizer.optimize(g)
        # Second pass: apply local_fill_cut / local_fill_lift as one group,
        # visiting the graph from outputs towards inputs.
        gof.TopoOptimizer(gof.LocalOptGroup(local_fill_cut, local_fill_lift), order = 'out_to_in').optimize(g)
        # Graph output after both passes.
        print pprint.pp.process(g.outputs[0])
# def test_plusmin(self):
# x, y, z = inputs()
# a, b, c, d = more_inputs()
# # e = x - x
# # e = (2.0 + x) - (2.0 + y)
# # e = (2.0 + x) - (4.0 + y)
# # e = x - (y - z)
# # e = (x + y) - x
# # e = (x - y) + (y - z) + (z - x)
# # e = (a - b) + (b - c) + (c - d)
# # e = x + -y
# # e = a - b - b + a + b + c + b - c
# # e = x + log(y) - x + y
# e = 2.0 + x + 4.0
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# gof.ConstantFinder().optimize(g)
# addfn = lambda *inputs: sum(inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
# print g
# def test_both(self):
# x, y, z = inputs()
# a, b, c, d = more_inputs()
# e0 = (x * y / x)
# e = e0 + e0 - e0
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# gof.ConstantFinder().optimize(g)
# mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# divfn = lambda x, y: x / y
# invfn = lambda x: 1 / x
# Canonizer(Mul, Div, Inv, mulfn, divfn, invfn).optimize(g)
# addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
# print g
# def test_group_powers(self):
# x, y, z, a, b, c, d = floats('xyzabcd')
###################
# c1, c2 = constant(1.), constant(2.)
# #e = pow(x, c1) * pow(x, y) / pow(x, 7.0) # <-- fucked
# #f = -- moving from div(mul.out, pow.out) to pow(x, sub.out)
# e = div(mul(pow(x, 2.0), pow(x, y)), pow(x, 7.0))
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# print g.inputs, g.outputs, g.orphans
# f = sub(add(2.0, y), add(7.0))
# g.replace(e, pow(x, f))
# print g
# print g.inputs, g.outputs, g.orphans
# g.replace(f, sub(add(2.0, y), add(7.0))) # -- moving from sub(add.out, add.out) to sub(add.out, add.out)
# print g
# print g.inputs, g.outputs, g.orphans
###################
# # e = x * exp(y) * exp(z)
# # e = x * pow(x, y) * pow(x, z)
# # e = pow(x, y) / pow(x, z)
# e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) # <-- fucked
# # e = pow(x - x, y)
# # e = pow(x, 2.0 + y - 7.0)
# # e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) / pow(x, z)
# # e = pow(x, 2.0 + y - 7.0 - z)
# # e = x ** y / x ** y
# # e = x ** y / x ** (y - 1.0)
# # e = exp(x) * a * exp(y) / exp(z)
# g = Env([x, y, z, a, b, c, d], [e])
# g.extend(gof.PrintListener(g))
# print g, g.orphans
# mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# divfn = lambda x, y: x / y
# invfn = lambda x: 1 / x
# Canonizer(mul, div, inv, mulfn, divfn, invfn, group_powers).optimize(g)
# print g, g.orphans
# addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(add, sub, neg, addfn, subfn, negfn).optimize(g)
# print g, g.orphans
# pow2one_float.optimize(g)
# pow2x_float.optimize(g)
# print g, g.orphans
# class _test_cliques(unittest.TestCase):
# def test_straightforward(self):
......
......@@ -237,6 +237,7 @@ class Elemwise(Op):
is left-completed to the greatest number of dimensions with 1s
using DimShuffle.
"""
inputs = map(as_tensor, inputs)
shadow = self.scalar_op.make_node(*[Scalar(dtype = t.type.dtype)() for t in inputs])
......@@ -303,11 +304,10 @@ class Elemwise(Op):
if node is None:
# the gradient contains a constant, translate it as
# an equivalent Tensor of size 1 and proper number of dimensions
b = [1] * nd
res = TensorConstant(Tensor(dtype = r.type.dtype,
broadcastable = b),
numpy.asarray(r.data).reshape(b))
return res
broadcastable = ()),
numpy.asarray(r.data)) # .reshape(b)
return DimShuffle((), ['x']*nd, inplace = True)(res)
new_r = Elemwise(node.op, {})(*[transform(input) for input in node.inputs])
return new_r
ret = []
......
......@@ -18,9 +18,9 @@ from op import \
Op
from opt import \
Optimizer, SeqOptimizer, \
Optimizer, optimizer, SeqOptimizer, \
MergeOptimizer, MergeOptMerge, \
LocalOptimizer, LocalOptGroup, LocalOpKeyOptGroup, \
LocalOptimizer, local_optimizer, LocalOptGroup, LocalOpKeyOptGroup, \
OpSub, OpRemove, PatternSub, \
NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer
......
......@@ -382,17 +382,17 @@ class Env(utils.object2):
"Same as len(self.clients(r))."
return len(self.clients(r))
def edge(self, r):
return r in self.inputs or r in self.orphans
def follow(self, r):
node = r.owner
if self.edge(r):
return None
else:
if node is None:
raise Exception("what the fuck")
return node.inputs
# def edge(self, r):
# return r in self.inputs or r in self.orphans
# def follow(self, r):
# node = r.owner
# if self.edge(r):
# return None
# else:
# if node is None:
# raise Exception("what the fuck")
# return node.inputs
def check_integrity(self):
"""
......
......@@ -56,6 +56,16 @@ class Optimizer:
pass
class FromFunctionOptimizer(Optimizer):
    """Optimizer whose `apply` step is a callable supplied at construction."""

    def __init__(self, fn):
        """Store `fn` as this instance's `apply` attribute.

        Because the callable is set on the instance rather than the class,
        a plain function is invoked without an implicit `self` argument.
        """
        self.apply = fn

    def add_requirements(self, env):
        """Pass `gof.toolbox.ReplaceValidate` to `env.extend`."""
        requirement = gof.toolbox.ReplaceValidate
        env.extend(requirement)
def optimizer(f):
    """Return a `FromFunctionOptimizer` built around the callable `f`.

    Takes the function as its only argument, so it can also be applied with
    decorator syntax.
    """
    wrapped = FromFunctionOptimizer(f)
    return wrapped
class SeqOptimizer(Optimizer, list):
"""
......@@ -137,6 +147,7 @@ class MergeOptimizer(Optimizer):
sig = r.signature()
other_r = inv_cid.get(sig, None)
if other_r is not None:
if r.name: other_r.name = r.name
env.replace_validate(r, other_r)
else:
cid[r] = sig
......@@ -155,8 +166,12 @@ class MergeOptimizer(Optimizer):
success = False
if dup is not None:
success = True
pairs = zip(node.outputs, dup.outputs)
for output, new_output in pairs:
if output.name and not new_output.name:
new_output.name = output.name
try:
env.replace_all_validate(zip(node.outputs, dup.outputs))
env.replace_all_validate(pairs)
except InconsistencyError, e:
success = False
if not success:
......@@ -189,17 +204,27 @@ class LocalOptimizer(utils.object2):
raise utils.AbstractFunctionError()
class FromFunctionLocalOptimizer(LocalOptimizer):
    """Local optimizer whose `transform` step is a callable supplied at construction."""

    def __init__(self, fn):
        """Store `fn` as this instance's `transform` attribute.

        Because the callable is set on the instance rather than the class,
        a plain function is invoked without an implicit `self` argument.
        """
        self.transform = fn

    def add_requirements(self, env):
        """Pass `gof.toolbox.ReplaceValidate` to `env.extend`."""
        requirement = gof.toolbox.ReplaceValidate
        env.extend(requirement)
def local_optimizer(f):
    """Return a `FromFunctionLocalOptimizer` built around the callable `f`.

    Takes the function as its only argument, so it can also be applied with
    decorator syntax.
    """
    wrapped = FromFunctionLocalOptimizer(f)
    return wrapped
class LocalOptGroup(LocalOptimizer):
def __init__(self, optimizers):
def __init__(self, *optimizers):
self.opts = optimizers
self.reentrant = any(getattr(opt, 'reentrant', True), optimizers)
self.retains_inputs = all(getattr(opt, 'retains_inputs', False), optimizers)
self.reentrant = any(getattr(opt, 'reentrant', True) for opt in optimizers)
self.retains_inputs = all(getattr(opt, 'retains_inputs', False) for opt in optimizers)
def transform(self, node):
for opt in self.opts:
repl = opt.transform(node)
if repl is not False:
if repl:
return repl
......@@ -547,3 +572,43 @@ class OpKeyOptimizer(NavigatorOptimizer):
def keep_going(exc, nav, repl_pairs):
pass
#################
### Utilities ###
#################
def _check_chain(r, chain):
    """Match the ancestry of variable `r` against a flat pattern.

    `chain` alternates pattern elements and input indices,
    ``[elem0, idx0, elem1, idx1, ...]``. Each element is tested against the
    current variable; the index that follows selects which input of the
    variable's owner becomes the next current variable.

    A pattern element is interpreted as follows:
      * ``None``: the current variable must have no owner.
      * an ``op.Op`` instance: the owner's op must compare equal to it.
      * a class derived from ``op.Op``: the owner's op must be an instance
        of it. A class not derived from ``op.Op`` is accepted unchecked.
      * anything else (``issubclass`` raises TypeError): mismatch.

    Returns False on the first mismatch, otherwise the variable reached at
    the end of the walk.

    When ``None`` has matched an ownerless variable there are no inputs left
    to follow. A single trailing index (the form `check_chain` produces) then
    ends the match successfully and the ownerless variable is returned; any
    further pattern element cannot match and yields False. Previously this
    case raised AttributeError on ``r.owner.inputs``.
    """
    # Reversed copy so that pop() consumes the pattern front to back.
    pending = list(reversed(chain))
    while pending:
        elem = pending.pop()
        owner = r.owner
        if elem is None:
            if owner is not None:
                return False
        elif owner is None:
            return False
        elif isinstance(elem, op.Op):
            # Keep `not ... == ...`: the op may define __eq__ without __ne__.
            if not owner.op == elem:
                return False
        else:
            try:
                if issubclass(elem, op.Op) and not isinstance(owner.op, elem):
                    return False
            except TypeError:
                # `elem` is not a class at all, so it cannot describe an op.
                return False
        if pending:
            index = pending.pop()
            if owner is None:
                # Only reachable after a successful `None` match: nothing to
                # descend into, so the walk stops here.
                if pending:
                    return False
                return r
            r = owner.inputs[index]
    return r
def check_chain(r, *chain):
    """Match `r` against a sequence of pattern elements, following input 0.

    `r` may be a variable or a `graph.Apply`; for an Apply, its first output
    is used as the starting variable. Each element of `chain` is paired with
    input index 0 and the resulting flat pattern is handed to `_check_chain`,
    whose result is returned unchanged.

    An empty `chain` is passed on as an empty pattern. The earlier
    ``reduce(list.__iadd__, ...)`` form had no initial value and raised
    TypeError in that case.
    """
    if isinstance(r, graph.Apply):
        r = r.outputs[0]
    # Interleave every pattern element with input index 0.
    flat = []
    for elem in chain:
        flat.append(elem)
        flat.append(0)
    return _check_chain(r, flat)
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论