提交 8577d9ac authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed EquilibriumOptimizer

上级 16ca0db2
......@@ -237,6 +237,8 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
return self._tracks
def add_requirements(self, env):
env.extend(toolbox.ReplaceValidate())
def __str__(self):
return getattr(self, 'name', '<FromFunctionLocalOptimizer instance>')
def local_optimizer(*tracks):
def decorator(f):
......@@ -492,6 +494,9 @@ class PatternSub(LocalOptimizer):
return str(pattern)
return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern))
def __repr__(self):
return str(self)
##################
......@@ -512,7 +517,7 @@ class NavigatorOptimizer(Optimizer):
self.ignore_newtrees = ignore_newtrees
self.failure_callback = failure_callback
def attach_updater(self, env, importer, pruner):
def attach_updater(self, env, importer, pruner, chin = None):
if self.ignore_newtrees:
importer = None
......@@ -526,6 +531,10 @@ class NavigatorOptimizer(Optimizer):
if pruner is not None:
def on_prune(self, env, node):
pruner(node)
if chin is not None:
def on_change_input(self, env, node, i, r, new_r):
chin(node, i, r, new_r)
u = Updater()
env.extend(u)
return u
......@@ -728,9 +737,10 @@ class EquilibriumOptimizer(NavigatorOptimizer):
for node in nodes:
candidates = filter(node, depth)
depth += 1
_nodes = nodes
nodes = reduce(list.__iadd__,
[reduce(list.__iadd__,
[[n for n, i in out.clients] for out in node.outputs],
[[n for n, i in out.clients if not isinstance(n, str)] for out in node.outputs],
[]) for node in nodes],
[])
candidates = tracks
......@@ -746,12 +756,16 @@ class EquilibriumOptimizer(NavigatorOptimizer):
runs = None
def importer(node):
#print 'IMPORTING', node
self.backtrack(node, tasks)
def pruner(node):
try:
del tasks[node]
except KeyError:
pass
def chin(node, i, r, new_r):
if new_r.owner and not r.clients:
self.backtrack(new_r.owner, tasks)
# # == NOT IDEAL == #
# for node in env.nodes:
......@@ -761,7 +775,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
for node in env.nodes:
tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))
u = self.attach_updater(env, importer, pruner)
u = self.attach_updater(env, importer, pruner, chin)
while tasks:
for node in tasks.iterkeys():
todo = tasks.pop(node)
......@@ -806,6 +820,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def keep_going(exc, nav, repl_pairs):
"""WRITEME"""
print exc, nav, repl_pairs
pass
......
......@@ -44,16 +44,22 @@ class MyOp(Op):
def __str__(self):
return self.name
def __repr__(self):
return self.name
def __eq__(self, other):
return self is other or isinstance(other, MyOp) and self.x is not None and self.x == other.x
def __hash__(self):
return self.x if self.x is not None else id(self)
op1 = MyOp('Op1')
op2 = MyOp('Op2')
op3 = MyOp('Op3')
op4 = MyOp('Op4')
op5 = MyOp('Op5')
op6 = MyOp('Op6')
op_d = MyOp('OpD', {0: [0]})
op_y = MyOp('OpY', x = 1)
......@@ -349,6 +355,22 @@ class TestEquilibrium(object):
print g
assert str(g) == '[Op2(x, y)]'
def test_2(self):
x, y, z = map(MyResult, 'xyz')
e = op1(op1(op3(x, y)))
g = Env([x, y, z], [e])
print g
opt = EquilibriumOptimizer(
[PatternSub((op1, (op2, 'x', 'y')), (op4, 'x', 'y')),
PatternSub((op3, 'x', 'y'), (op4, 'x', 'y')),
PatternSub((op4, 'x', 'y'), (op5, 'x', 'y')),
PatternSub((op5, 'x', 'y'), (op6, 'x', 'y')),
PatternSub((op6, 'x', 'y'), (op2, 'x', 'y'))
],
max_use_ratio = 10)
opt.optimize(g)
assert str(g) == '[Op2(x, y)]'
def test_low_use_ratio(self):
x, y, z = map(MyResult, 'xyz')
e = op3(op4(x, y))
......
......@@ -189,8 +189,8 @@ class DimShufflePrinter:
def __p(self, new_order, pstate, r):
if new_order != () and new_order[0] == 'x':
# return "%s" % self.__p(new_order[1:], pstate, r)
return "[%s]" % self.__p(new_order[1:], pstate, r)
return "%s" % self.__p(new_order[1:], pstate, r)
# return "[%s]" % self.__p(new_order[1:], pstate, r)
if list(new_order) == range(r.type.ndim):
return pstate.pprinter.process(r)
if list(new_order) == list(reversed(range(r.type.ndim))):
......
......@@ -92,7 +92,6 @@ insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer)
inplace_optimizer = gof.InplaceOptimizer(
gof.SeqOptimizer(out2in(gemm_pattern_1),
#out2in(dot_to_gemm),
insert_inplace_optimizer,
failure_callback = gof.keep_going))
compile.optdb.register('inplace', inplace_optimizer, 99, 'fast_run')
......@@ -491,7 +490,7 @@ class Canonizer(gof.LocalOptimizer):
ct = [self.calculate(numct, denumct, aslist = False)]
# if len(ct) and ncc == 1 and dcc == 0:
# return orig_num, orig_denum
if orig_num and N.all(ct == self.get_constant(orig_num[0])):
if orig_num and len(numct) == 1 and ct and N.all(ct == self.get_constant(orig_num[0])):
return orig_num, orig_denum
return ct + num, denum
......@@ -710,6 +709,20 @@ register_canonicalize(local_greedy_distributor)
@gof.local_optimizer([None])
def constant_folding(node):
for input in node.inputs:
if not isinstance(input, gof.Constant):
return False
storage = [[None] for output in node.outputs]
node.op.perform(node, [x.data for x in node.inputs], storage)
return [gof.Constant(output.type, s[0]) for s, output in zip(storage, node.outputs)]
register_canonicalize(constant_folding)
# def _math_optimizer():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论