提交 374afcb4 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

opt.CliqueOptimizer is working

上级 51276b41
...@@ -8,6 +8,7 @@ from tensor import Tensor ...@@ -8,6 +8,7 @@ from tensor import Tensor
from gof import Env from gof import Env
from elemwise import DimShuffle from elemwise import DimShuffle
import numpy import numpy
import scalar_opt
def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)): def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
...@@ -95,8 +96,8 @@ class _test_cliques(unittest.TestCase): ...@@ -95,8 +96,8 @@ class _test_cliques(unittest.TestCase):
cliques = find_cliques(g) cliques = find_cliques(g)
assert len(cliques) == 2 assert len(cliques) == 2
(i1, o1), (i2, o2) = cliques (i1, o1), (i2, o2) = cliques
assert str(Env(i1, [o1])) == "[Broadcast{Add}(Broadcast{Add}(x, y), d)]" assert str(Env(i1, o1)) == "[Broadcast{Add}(Broadcast{Add}(x, y), d)]"
assert str(Env(i2, [o2])) == "[Broadcast{Mul}(y, z)]" assert str(Env(i2, o2)) == "[Broadcast{Mul}(y, z)]"
# print g # print g
# for i, o in find_cliques(g): # for i, o in find_cliques(g):
# print "-->", Env(i, [o]) # print "-->", Env(i, [o])
...@@ -113,6 +114,64 @@ class _test_cliques(unittest.TestCase): ...@@ -113,6 +114,64 @@ class _test_cliques(unittest.TestCase):
# print "-->", Env(i, [o]) # print "-->", Env(i, [o])
# class _test_clique_opt(unittest.TestCase):
# def test_straightforward(self):
# x, y, z = inputs()
# e = x ** 2.0 #x * x
# g = Env([x], [e])
# gof.ConstantFinder().optimize(g)
# opt = CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = scalar_opt.opt2,
# make_composite = False)
# print g
# opt.optimize(g)
# print g
# def test_inplace(self):
# x, y, z = inputs()
# #e = tensor.add_inplace(x, y + z)
# e = x + tensor.add_inplace(y, z)
# g = Env([x, y, z], [e])
# opt = CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = None,
# make_composite = True)
# print g
# opt.optimize(g)
# print g
# # print g.outputs[0].owner.c_code(['x', 'y', 'z'], ['e'], dict(fail = "FAIL;", id = 0))
# print gof.OpWiseCLinker(g).make_function()(numpy.ones((5, 5)), numpy.ones((5, 5)), numpy.ones((5, 5)))
# def test_straightforward(self):
# x, y, z = inputs()
# e = x + y + z
# g = Env([x, y, z], [e])
# opt = CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = None,
# make_composite = True)
# print g
# opt.optimize(g)
# print g
# # print g.outputs[0].owner.c_code(['x', 'y', 'z'], ['e'], dict(fail = "FAIL;", id = 0))
# print gof.OpWiseCLinker(g).make_function()(numpy.ones((5, 5)), numpy.ones((5, 5)), numpy.ones((5, 5)))
# def test_straightforward2(self):
# x, y, z = inputs()
# m = y * z
# d = tensor.dot(x, m)
# d.name = 'd'
# e = x + y + d
# g = Env([x, y, z], [e])
# opt = CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = None,
# make_composite = True)
# print g
# opt.optimize(g)
# print g
# # print g.outputs[0].owner.c_code(['x', 'y', 'z'], ['e'], dict(fail = "FAIL;", id = 0))
# print gof.OpWiseCLinker(g).make_function()(numpy.ones((5, 5)), numpy.ones((5, 5)), numpy.ones((5, 5)))
if __name__ == '__main__': if __name__ == '__main__':
......
from gof import opt from gof import opt, Env
import gof
from elemwise import Broadcast, DimShuffle from elemwise import Broadcast, DimShuffle
from gof.python25 import any, all from gof.python25 import any, all
import scalar
class InplaceOptimizer(opt.OpSpecificOptimizer): class InplaceOptimizer(opt.OpSpecificOptimizer):
...@@ -86,12 +88,16 @@ lift_dimshuffle = DimShuffleLifter() ...@@ -86,12 +88,16 @@ lift_dimshuffle = DimShuffleLifter()
def find_cliques(env, through_broadcast = False): def find_cliques(env, through_broadcast = False):
def seek_from(r): def seek_from(r):
op = r.owner op = r.owner
if r in env.inputs \ if r in env.inputs \
or r in env.orphans() \ or r in env.orphans() \
or op is None \ or op is None \
or not isinstance(op, Broadcast): or not isinstance(op, Broadcast) \
or len(op.outputs) > 1:
# todo: handle multiple-output broadcast ops
# (needs to update the clique's outputs)
return None return None
ret = set() ret = set()
...@@ -123,7 +129,7 @@ def find_cliques(env, through_broadcast = False): ...@@ -123,7 +129,7 @@ def find_cliques(env, through_broadcast = False):
for input in op.inputs: for input in op.inputs:
find_cliques_helper(input) find_cliques_helper(input)
else: else:
cliques.append((clique_inputs, r)) cliques.append((clique_inputs, [r]))
for input in clique_inputs: for input in clique_inputs:
find_cliques_helper(input) find_cliques_helper(input)
...@@ -135,30 +141,124 @@ def find_cliques(env, through_broadcast = False): ...@@ -135,30 +141,124 @@ def find_cliques(env, through_broadcast = False):
return cliques return cliques
class CliqueOptimizer(opt.Optimizer):
def __init__(self, through_broadcast = False, scalar_optimizer = None, make_composite = False):
self.through_broadcast = through_broadcast
self.scalar_optimizer = scalar_optimizer
self.make_composite = make_composite
# class ElemwisePatternOptimizer(opt.Optimizer): def apply(self, env):
if self.scalar_optimizer is None and not self.make_composite:
# def __init__(self, scalar_opt): # there's nothing to do with the cliques...
# self. return
cliques = find_cliques(env, self.through_broadcast)
opt = self.scalar_optimizer
# def synchronize(env1, env2, equiv, transform): def build_scalar_clique(r, env, equiv):
if r in equiv:
return equiv[r]
op = r.owner
if r in env.inputs or r in env.orphans():
s = scalar.Scalar(dtype = r.dtype)
_r = r
if isinstance(r.owner, DimShuffle) and all(x == 'x' for x in r.owner.new_order):
_r = r.owner.inputs[0]
if (getattr(r, 'constant', False) or getattr(_r, 'constant', False)) \
and _r.broadcastable == ():
s.data = _r.data
s.constant = True
equiv[r] = s
return s
s_op = op.scalar_opclass(*[build_scalar_clique(input, env, equiv) for input in op.inputs])
equiv[op] = s_op
for output, s_output in zip(op.outputs, s_op.outputs):
equiv[output] = s_output
return equiv[r]
for c_in, c_out in cliques:
equiv = dict()
g = Env(c_in, c_out)
for output in c_out:
build_scalar_clique(output, g, equiv)
s_g = Env([equiv[r] for r in g.inputs],
[equiv[r] for r in g.outputs])
if opt is not None:
equiv2 = dict()
for k, v in equiv.items():
equiv2[v] = k
def transform(op, equiv):
return Broadcast(op.__class__, [equiv[input] for input in op.inputs])
s_g.add_feature(sync_to(env, equiv2, transform))
opt.optimize(s_g)
if self.make_composite:
def follow_inplace(r):
op = r.owner
if op is None or r in g.inputs or r in g.orphans():
return None
assert isinstance(op, Broadcast)
destroyed = op.destroy_map().get(r, None)
if destroyed is None:
return None
else:
r2 = destroyed[0]
ret = follow_inplace(r2)
if ret is None:
return r2
else:
return ret
inplace_pattern = {}
for i, output in enumerate(g.outputs):
destroyed = follow_inplace(output)
if destroyed is not None and destroyed in g.inputs:
inplace_pattern[i] = g.inputs.index(destroyed)
C = scalar.composite(s_g.inputs, s_g.outputs)
ec = Broadcast(C, g.inputs, inplace_pattern = inplace_pattern)
env.replace_all(dict((o, eco) for o, eco in zip(c_out, ec.outputs)))
def sync_to(target, equiv, transform):
class Synchronize(gof.Listener, gof.Constraint):
def __init__(self, source):
self.source = source
self.target = target
self.equiv = equiv
self.transform = transform
self.inconsistencies = []
def on_import(self, op1):
if op1 not in self.equiv:
op2 = self.transform(op1, self.equiv)
self.equiv[op1] = op2
for o1, o2 in zip(op1.outputs, op2.outputs):
self.equiv[o1] = o2
def on_prune(self, op1):
if op1 in self.equiv:
op2 = self.equiv[op1]
del self.equiv[op1]
for o1, o2 in zip(op1.outputs, op2.outputs):
del self.equiv[o1]
def on_rewire(self, clients1, r1, new_r1):
if (new_r1, r1) in self.inconsistencies:
self.inconsistencies.remove((new_r1, r1))
return
if not self.source.clients(r1):
try:
target.replace(self.equiv[r1], self.equiv[new_r1])
except:
self.inconsistencies.append((r1, new_r1))
# class Synchronize(Listener, Constraint): def validate(self):
if self.inconsistencies:
raise InconsistencyError("Could not synchronize when replacing the following pairs: %s" % self.inconsistencies)
return True
# def on_import(self, op1): return Synchronize
# if op1 not in equiv:
# equiv[op1] = transform(op1)
# def on_prune(self, op1):
# if op1 in equiv:
# del equiv[op1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论