提交 374afcb4，作者：Olivier Breuleux

opt.CliqueOptimizer is working

上级 51276b41
......@@ -8,6 +8,7 @@ from tensor import Tensor
from gof import Env
from elemwise import DimShuffle
import numpy
import scalar_opt
def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
......@@ -95,8 +96,8 @@ class _test_cliques(unittest.TestCase):
cliques = find_cliques(g)
assert len(cliques) == 2
(i1, o1), (i2, o2) = cliques
assert str(Env(i1, [o1])) == "[Broadcast{Add}(Broadcast{Add}(x, y), d)]"
assert str(Env(i2, [o2])) == "[Broadcast{Mul}(y, z)]"
assert str(Env(i1, o1)) == "[Broadcast{Add}(Broadcast{Add}(x, y), d)]"
assert str(Env(i2, o2)) == "[Broadcast{Mul}(y, z)]"
# print g
# for i, o in find_cliques(g):
# print "-->", Env(i, [o])
......@@ -112,6 +113,64 @@ class _test_cliques(unittest.TestCase):
# for i, o in find_cliques(g, True):
# print "-->", Env(i, [o])
# class _test_clique_opt(unittest.TestCase):
# def test_straightforward(self):
# x, y, z = inputs()
# e = x ** 2.0 #x * x
# g = Env([x], [e])
# gof.ConstantFinder().optimize(g)
# opt = CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = scalar_opt.opt2,
# make_composite = False)
# print g
# opt.optimize(g)
# print g
# def test_inplace(self):
# x, y, z = inputs()
# #e = tensor.add_inplace(x, y + z)
# e = x + tensor.add_inplace(y, z)
# g = Env([x, y, z], [e])
# opt = CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = None,
# make_composite = True)
# print g
# opt.optimize(g)
# print g
# # print g.outputs[0].owner.c_code(['x', 'y', 'z'], ['e'], dict(fail = "FAIL;", id = 0))
# print gof.OpWiseCLinker(g).make_function()(numpy.ones((5, 5)), numpy.ones((5, 5)), numpy.ones((5, 5)))
# def test_straightforward(self):
# x, y, z = inputs()
# e = x + y + z
# g = Env([x, y, z], [e])
# opt = CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = None,
# make_composite = True)
# print g
# opt.optimize(g)
# print g
# # print g.outputs[0].owner.c_code(['x', 'y', 'z'], ['e'], dict(fail = "FAIL;", id = 0))
# print gof.OpWiseCLinker(g).make_function()(numpy.ones((5, 5)), numpy.ones((5, 5)), numpy.ones((5, 5)))
# def test_straightforward2(self):
# x, y, z = inputs()
# m = y * z
# d = tensor.dot(x, m)
# d.name = 'd'
# e = x + y + d
# g = Env([x, y, z], [e])
# opt = CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = None,
# make_composite = True)
# print g
# opt.optimize(g)
# print g
# # print g.outputs[0].owner.c_code(['x', 'y', 'z'], ['e'], dict(fail = "FAIL;", id = 0))
# print gof.OpWiseCLinker(g).make_function()(numpy.ones((5, 5)), numpy.ones((5, 5)), numpy.ones((5, 5)))
......
from gof import opt
from gof import opt, Env
import gof
from elemwise import Broadcast, DimShuffle
from gof.python25 import any, all
import scalar
class InplaceOptimizer(opt.OpSpecificOptimizer):
......@@ -86,12 +88,16 @@ lift_dimshuffle = DimShuffleLifter()
def find_cliques(env, through_broadcast = False):
def seek_from(r):
op = r.owner
if r in env.inputs \
or r in env.orphans() \
or op is None \
or not isinstance(op, Broadcast):
or not isinstance(op, Broadcast) \
or len(op.outputs) > 1:
# todo: handle multiple-output broadcast ops
# (needs to update the clique's outputs)
return None
ret = set()
......@@ -123,7 +129,7 @@ def find_cliques(env, through_broadcast = False):
for input in op.inputs:
find_cliques_helper(input)
else:
cliques.append((clique_inputs, r))
cliques.append((clique_inputs, [r]))
for input in clique_inputs:
find_cliques_helper(input)
......@@ -135,30 +141,124 @@ def find_cliques(env, through_broadcast = False):
return cliques
class CliqueOptimizer(opt.Optimizer):
    """Optimizer that processes the cliques reported by find_cliques.

    For each clique a scalar-level copy of the clique's graph is built.
    Depending on the configuration, a scalar optimizer is run on that copy
    and/or the clique is replaced in the env by one Broadcast wrapping a
    composite scalar op.
    """

    def __init__(self, through_broadcast = False, scalar_optimizer = None, make_composite = False):
        # Forwarded unchanged to find_cliques.
        self.through_broadcast = through_broadcast
        # Object whose optimize() is run on each scalar graph; None disables it.
        self.scalar_optimizer = scalar_optimizer
        # When true, each clique is replaced by a Broadcast of a composite.
        self.make_composite = make_composite

    def apply(self, env):
        """Apply the configured clique treatment to every clique of `env`.

        Returns None. Does nothing at all when there is neither a scalar
        optimizer nor a request to build composites.
        """
        scalar_optimizer = self.scalar_optimizer
        if scalar_optimizer is None and not self.make_composite:
            # Nothing to do with the cliques in this configuration.
            return

        def to_scalar(result, clique_env, mapping):
            # Return the scalar counterpart of `result`, creating it (and
            # those of everything it depends on) in `mapping` if needed.
            if result in mapping:
                return mapping[result]
            owner = result.owner
            if result in clique_env.inputs or result in clique_env.orphans():
                leaf = scalar.Scalar(dtype = result.dtype)
                data_source = result
                # A DimShuffle whose new_order is all 'x' is looked through
                # so the data of its input can be used instead.
                if isinstance(owner, DimShuffle) and all(x == 'x' for x in owner.new_order):
                    data_source = owner.inputs[0]
                is_constant = (getattr(result, 'constant', False)
                               or getattr(data_source, 'constant', False))
                if is_constant and data_source.broadcastable == ():
                    leaf.data = data_source.data
                    leaf.constant = True
                mapping[result] = leaf
                return leaf
            scalar_inputs = [to_scalar(i, clique_env, mapping) for i in owner.inputs]
            scalar_op = owner.scalar_opclass(*scalar_inputs)
            # Both the op and each of its outputs are recorded in the mapping.
            mapping[owner] = scalar_op
            for out, scalar_out in zip(owner.outputs, scalar_op.outputs):
                mapping[out] = scalar_out
            return mapping[result]

        def rebroadcast(scalar_op, mapping):
            # Build the Broadcast matching a scalar op, using the already
            # mapped counterparts of its inputs.
            return Broadcast(scalar_op.__class__, [mapping[i] for i in scalar_op.inputs])

        def destroyed_root(result, clique_env):
            # Follow destroy_map links starting at `result` for as long as
            # the current result is computed inside the clique. Returns the
            # last result reached, or None if no link was followed.
            reached = None
            current = result
            while True:
                owner = current.owner
                if owner is None or current in clique_env.inputs or current in clique_env.orphans():
                    break
                assert isinstance(owner, Broadcast)
                destroyed = owner.destroy_map().get(current, None)
                if destroyed is None:
                    break
                current = destroyed[0]
                reached = current
            return reached

        for clique_inputs, clique_outputs in find_cliques(env, self.through_broadcast):
            mapping = {}
            clique_env = Env(clique_inputs, clique_outputs)
            for out in clique_outputs:
                to_scalar(out, clique_env, mapping)
            scalar_env = Env([mapping[r] for r in clique_env.inputs],
                             [mapping[r] for r in clique_env.outputs])

            if scalar_optimizer is not None:
                # sync_to receives the inverse mapping (scalar -> original).
                reverse = dict((v, k) for k, v in mapping.items())
                scalar_env.add_feature(sync_to(env, reverse, rebroadcast))
                scalar_optimizer.optimize(scalar_env)

            if self.make_composite:
                # Output position -> input position, for each output whose
                # chain of destroyed results ends on a clique input.
                inplace_pattern = {}
                for position, out in enumerate(clique_env.outputs):
                    root = destroyed_root(out, clique_env)
                    if root is not None and root in clique_env.inputs:
                        inplace_pattern[position] = clique_env.inputs.index(root)
                composite = scalar.composite(scalar_env.inputs, scalar_env.outputs)
                fused = Broadcast(composite, clique_env.inputs, inplace_pattern = inplace_pattern)
                env.replace_all(dict(zip(clique_outputs, fused.outputs)))
def sync_to(target, equiv, transform):
    """Return a listener class that mirrors rewrites of one env onto `target`.

    target: the env that receives the mirrored replacements.
    equiv: dict mapping ops/results of the watched env to their
        counterparts in `target`; it is updated in place.
    transform: callable `transform(op, equiv)` returning the counterpart op
        for an op that has no entry in `equiv` yet.

    The class, not an instance, is returned; its constructor takes the
    watched env as `source` (presumably supplied by add_feature -- confirm
    against gof).
    """

    class Synchronize(gof.Listener, gof.Constraint):

        def __init__(self, source):
            self.source = source
            self.target = target
            self.equiv = equiv
            self.transform = transform
            # (r1, new_r1) pairs whose replacement could not be mirrored.
            self.inconsistencies = []

        def on_import(self, op1):
            # Create and record a counterpart for an op not yet known.
            if op1 not in self.equiv:
                op2 = self.transform(op1, self.equiv)
                self.equiv[op1] = op2
                for o1, o2 in zip(op1.outputs, op2.outputs):
                    self.equiv[o1] = o2

        def on_prune(self, op1):
            # Forget the op and the outputs that were paired with it.
            if op1 in self.equiv:
                op2 = self.equiv[op1]
                del self.equiv[op1]
                for o1, o2 in zip(op1.outputs, op2.outputs):
                    del self.equiv[o1]

        def on_rewire(self, clients1, r1, new_r1):
            # A rewire that is the exact reverse of a recorded failure
            # cancels that failure instead of being mirrored.
            if (new_r1, r1) in self.inconsistencies:
                self.inconsistencies.remove((new_r1, r1))
                return
            # Mirror only once r1 has no clients left in the watched env.
            if not self.source.clients(r1):
                try:
                    self.target.replace(self.equiv[r1], self.equiv[new_r1])
                except Exception:
                    # Any failure (including a missing equiv entry) is kept
                    # for validate() to report. Exception rather than a bare
                    # except, so KeyboardInterrupt/SystemExit propagate.
                    self.inconsistencies.append((r1, new_r1))

        def validate(self):
            """Return True, or raise if some replacement was not mirrored."""
            if self.inconsistencies:
                # The unqualified name InconsistencyError is not among this
                # module's visible imports, so it is reached through gof.
                # NOTE(review): confirm gof exports InconsistencyError.
                raise gof.InconsistencyError("Could not synchronize when replacing the following pairs: %s" % self.inconsistencies)
            return True

    return Synchronize
# def on_prune(self, op1):
# if op1 in equiv:
# del equiv[op1]
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论