提交 11be7be2 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

optimizer fiesta

上级 deff95dc
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
import unittest import unittest
import gof from theano import gof
from tensor_opt import * from theano.tensor_opt import *
import tensor from theano import tensor
from tensor import Tensor from theano.tensor import Tensor
from gof import Env from theano.gof import Env
from elemwise import DimShuffle from theano.elemwise import DimShuffle
import numpy import numpy
#import scalar_opt #import scalar_opt
...@@ -68,7 +68,7 @@ class _test_dimshuffle_lift(unittest.TestCase): ...@@ -68,7 +68,7 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(x, (1, 0)), (1, 0)) e = ds(ds(x, (1, 0)), (1, 0))
g = Env([x], [e]) g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{1,0}(x))]") self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{1,0}(x))]")
lift_dimshuffle.optimize(g) dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[x]") self.failUnless(str(g) == "[x]")
def test_merge2(self): def test_merge2(self):
...@@ -76,7 +76,7 @@ class _test_dimshuffle_lift(unittest.TestCase): ...@@ -76,7 +76,7 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1)) e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
g = Env([x], [e]) g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]", str(g)) self.failUnless(str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]", str(g))
lift_dimshuffle.optimize(g) dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g)) self.failUnless(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g))
def test_elim3(self): def test_elim3(self):
...@@ -84,7 +84,7 @@ class _test_dimshuffle_lift(unittest.TestCase): ...@@ -84,7 +84,7 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0)) e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
g = Env([x], [e]) g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}(DimShuffle{0,x,1}(x)))]", str(g)) self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}(DimShuffle{0,x,1}(x)))]", str(g))
lift_dimshuffle.optimize(g) dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[x]", str(g)) self.failUnless(str(g) == "[x]", str(g))
def test_lift(self): def test_lift(self):
...@@ -92,10 +92,134 @@ class _test_dimshuffle_lift(unittest.TestCase): ...@@ -92,10 +92,134 @@ class _test_dimshuffle_lift(unittest.TestCase):
e = x + y + z e = x + y + z
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
self.failUnless(str(g) == "[add(InplaceDimShuffle{x,0,1}(add(InplaceDimShuffle{x,0}(x), y)), z)]", str(g)) self.failUnless(str(g) == "[add(InplaceDimShuffle{x,0,1}(add(InplaceDimShuffle{x,0}(x), y)), z)]", str(g))
lift_dimshuffle.optimize(g) dimshuffle_lift.optimize(g)
self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g)) self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g))
from theano.tensor import *
from theano.sandbox import pprint
class _test_canonize(unittest.TestCase):
def test_muldiv(self):
x, y, z = matrices('xyz')
a, b, c, d = matrices('abcd')
# e = (2.0 * x) / (2.0 * y)
# e = (2.0 * x) / (4.0 * y)
# e = x / (y / z)
# e = (x * y) / x
# e = (x / y) * (y / z) * (z / x)
# e = (a / b) * (b / c) * (c / d)
# e = (a * b) / (b * c) / (c * d)
# e = 2 * x / 2
# e = x / y / x
e = (x / x) * (y / y)
g = Env([x, y, z, a, b, c, d], [e])
print pprint.pp.process(g.outputs[0])
mul_canonizer.optimize(g)
gof.TopoOptimizer(gof.LocalOptGroup(local_fill_cut, local_fill_lift), order = 'out_to_in').optimize(g)
print pprint.pp.process(g.outputs[0])
# def test_plusmin(self):
# x, y, z = inputs()
# a, b, c, d = more_inputs()
# # e = x - x
# # e = (2.0 + x) - (2.0 + y)
# # e = (2.0 + x) - (4.0 + y)
# # e = x - (y - z)
# # e = (x + y) - x
# # e = (x - y) + (y - z) + (z - x)
# # e = (a - b) + (b - c) + (c - d)
# # e = x + -y
# # e = a - b - b + a + b + c + b - c
# # e = x + log(y) - x + y
# e = 2.0 + x + 4.0
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# gof.ConstantFinder().optimize(g)
# addfn = lambda *inputs: sum(inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
# print g
# def test_both(self):
# x, y, z = inputs()
# a, b, c, d = more_inputs()
# e0 = (x * y / x)
# e = e0 + e0 - e0
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# gof.ConstantFinder().optimize(g)
# mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# divfn = lambda x, y: x / y
# invfn = lambda x: 1 / x
# Canonizer(Mul, Div, Inv, mulfn, divfn, invfn).optimize(g)
# addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
# print g
# def test_group_powers(self):
# x, y, z, a, b, c, d = floats('xyzabcd')
###################
# c1, c2 = constant(1.), constant(2.)
# #e = pow(x, c1) * pow(x, y) / pow(x, 7.0) # <-- fucked
# #f = -- moving from div(mul.out, pow.out) to pow(x, sub.out)
# e = div(mul(pow(x, 2.0), pow(x, y)), pow(x, 7.0))
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# print g.inputs, g.outputs, g.orphans
# f = sub(add(2.0, y), add(7.0))
# g.replace(e, pow(x, f))
# print g
# print g.inputs, g.outputs, g.orphans
# g.replace(f, sub(add(2.0, y), add(7.0))) # -- moving from sub(add.out, add.out) to sub(add.out, add.out)
# print g
# print g.inputs, g.outputs, g.orphans
###################
# # e = x * exp(y) * exp(z)
# # e = x * pow(x, y) * pow(x, z)
# # e = pow(x, y) / pow(x, z)
# e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) # <-- fucked
# # e = pow(x - x, y)
# # e = pow(x, 2.0 + y - 7.0)
# # e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) / pow(x, z)
# # e = pow(x, 2.0 + y - 7.0 - z)
# # e = x ** y / x ** y
# # e = x ** y / x ** (y - 1.0)
# # e = exp(x) * a * exp(y) / exp(z)
# g = Env([x, y, z, a, b, c, d], [e])
# g.extend(gof.PrintListener(g))
# print g, g.orphans
# mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# divfn = lambda x, y: x / y
# invfn = lambda x: 1 / x
# Canonizer(mul, div, inv, mulfn, divfn, invfn, group_powers).optimize(g)
# print g, g.orphans
# addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(add, sub, neg, addfn, subfn, negfn).optimize(g)
# print g, g.orphans
# pow2one_float.optimize(g)
# pow2x_float.optimize(g)
# print g, g.orphans
# class _test_cliques(unittest.TestCase): # class _test_cliques(unittest.TestCase):
# def test_straightforward(self): # def test_straightforward(self):
......
...@@ -237,6 +237,7 @@ class Elemwise(Op): ...@@ -237,6 +237,7 @@ class Elemwise(Op):
is left-completed to the greatest number of dimensions with 1s is left-completed to the greatest number of dimensions with 1s
using DimShuffle. using DimShuffle.
""" """
inputs = map(as_tensor, inputs) inputs = map(as_tensor, inputs)
shadow = self.scalar_op.make_node(*[Scalar(dtype = t.type.dtype)() for t in inputs]) shadow = self.scalar_op.make_node(*[Scalar(dtype = t.type.dtype)() for t in inputs])
...@@ -303,11 +304,10 @@ class Elemwise(Op): ...@@ -303,11 +304,10 @@ class Elemwise(Op):
if node is None: if node is None:
# the gradient contains a constant, translate it as # the gradient contains a constant, translate it as
# an equivalent Tensor of size 1 and proper number of dimensions # an equivalent Tensor of size 1 and proper number of dimensions
b = [1] * nd
res = TensorConstant(Tensor(dtype = r.type.dtype, res = TensorConstant(Tensor(dtype = r.type.dtype,
broadcastable = b), broadcastable = ()),
numpy.asarray(r.data).reshape(b)) numpy.asarray(r.data)) # .reshape(b)
return res return DimShuffle((), ['x']*nd, inplace = True)(res)
new_r = Elemwise(node.op, {})(*[transform(input) for input in node.inputs]) new_r = Elemwise(node.op, {})(*[transform(input) for input in node.inputs])
return new_r return new_r
ret = [] ret = []
......
...@@ -18,9 +18,9 @@ from op import \ ...@@ -18,9 +18,9 @@ from op import \
Op Op
from opt import \ from opt import \
Optimizer, SeqOptimizer, \ Optimizer, optimizer, SeqOptimizer, \
MergeOptimizer, MergeOptMerge, \ MergeOptimizer, MergeOptMerge, \
LocalOptimizer, LocalOptGroup, LocalOpKeyOptGroup, \ LocalOptimizer, local_optimizer, LocalOptGroup, LocalOpKeyOptGroup, \
OpSub, OpRemove, PatternSub, \ OpSub, OpRemove, PatternSub, \
NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer
......
...@@ -382,17 +382,17 @@ class Env(utils.object2): ...@@ -382,17 +382,17 @@ class Env(utils.object2):
"Same as len(self.clients(r))." "Same as len(self.clients(r))."
return len(self.clients(r)) return len(self.clients(r))
def edge(self, r): # def edge(self, r):
return r in self.inputs or r in self.orphans # return r in self.inputs or r in self.orphans
def follow(self, r): # def follow(self, r):
node = r.owner # node = r.owner
if self.edge(r): # if self.edge(r):
return None # return None
else: # else:
if node is None: # if node is None:
raise Exception("what the fuck") # raise Exception("what the fuck")
return node.inputs # return node.inputs
def check_integrity(self): def check_integrity(self):
""" """
......
...@@ -56,6 +56,16 @@ class Optimizer: ...@@ -56,6 +56,16 @@ class Optimizer:
pass pass
class FromFunctionOptimizer(Optimizer):
def __init__(self, fn):
self.apply = fn
def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate)
def optimizer(f):
return FromFunctionOptimizer(f)
class SeqOptimizer(Optimizer, list): class SeqOptimizer(Optimizer, list):
""" """
...@@ -137,6 +147,7 @@ class MergeOptimizer(Optimizer): ...@@ -137,6 +147,7 @@ class MergeOptimizer(Optimizer):
sig = r.signature() sig = r.signature()
other_r = inv_cid.get(sig, None) other_r = inv_cid.get(sig, None)
if other_r is not None: if other_r is not None:
if r.name: other_r.name = r.name
env.replace_validate(r, other_r) env.replace_validate(r, other_r)
else: else:
cid[r] = sig cid[r] = sig
...@@ -155,8 +166,12 @@ class MergeOptimizer(Optimizer): ...@@ -155,8 +166,12 @@ class MergeOptimizer(Optimizer):
success = False success = False
if dup is not None: if dup is not None:
success = True success = True
pairs = zip(node.outputs, dup.outputs)
for output, new_output in pairs:
if output.name and not new_output.name:
new_output.name = output.name
try: try:
env.replace_all_validate(zip(node.outputs, dup.outputs)) env.replace_all_validate(pairs)
except InconsistencyError, e: except InconsistencyError, e:
success = False success = False
if not success: if not success:
...@@ -189,17 +204,27 @@ class LocalOptimizer(utils.object2): ...@@ -189,17 +204,27 @@ class LocalOptimizer(utils.object2):
raise utils.AbstractFunctionError() raise utils.AbstractFunctionError()
class FromFunctionLocalOptimizer(LocalOptimizer):
def __init__(self, fn):
self.transform = fn
def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate)
def local_optimizer(f):
return FromFunctionLocalOptimizer(f)
class LocalOptGroup(LocalOptimizer): class LocalOptGroup(LocalOptimizer):
def __init__(self, optimizers): def __init__(self, *optimizers):
self.opts = optimizers self.opts = optimizers
self.reentrant = any(getattr(opt, 'reentrant', True), optimizers) self.reentrant = any(getattr(opt, 'reentrant', True) for opt in optimizers)
self.retains_inputs = all(getattr(opt, 'retains_inputs', False), optimizers) self.retains_inputs = all(getattr(opt, 'retains_inputs', False) for opt in optimizers)
def transform(self, node): def transform(self, node):
for opt in self.opts: for opt in self.opts:
repl = opt.transform(node) repl = opt.transform(node)
if repl is not False: if repl:
return repl return repl
...@@ -547,3 +572,43 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -547,3 +572,43 @@ class OpKeyOptimizer(NavigatorOptimizer):
def keep_going(exc, nav, repl_pairs): def keep_going(exc, nav, repl_pairs):
pass pass
#################
### Utilities ###
#################
def _check_chain(r, chain):
chain = list(reversed(chain))
while chain:
elem = chain.pop()
if elem is None:
if not r.owner is None:
return False
elif r.owner is None:
return False
elif isinstance(elem, op.Op):
if not r.owner.op == elem:
return False
else:
try:
if issubclass(elem, op.Op) and not isinstance(r.owner.op, elem):
return False
except TypeError:
return False
if chain:
r = r.owner.inputs[chain.pop()]
return r
def check_chain(r, *chain):
if isinstance(r, graph.Apply):
r = r.outputs[0]
return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain)))
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论