提交 8577d9ac authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed EquilibriumOptimizer

上级 16ca0db2
...@@ -237,6 +237,8 @@ class FromFunctionLocalOptimizer(LocalOptimizer): ...@@ -237,6 +237,8 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
return self._tracks return self._tracks
def add_requirements(self, env): def add_requirements(self, env):
env.extend(toolbox.ReplaceValidate()) env.extend(toolbox.ReplaceValidate())
def __str__(self):
return getattr(self, 'name', '<FromFunctionLocalOptimizer instance>')
def local_optimizer(*tracks): def local_optimizer(*tracks):
def decorator(f): def decorator(f):
...@@ -492,6 +494,9 @@ class PatternSub(LocalOptimizer): ...@@ -492,6 +494,9 @@ class PatternSub(LocalOptimizer):
return str(pattern) return str(pattern)
return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern)) return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern))
def __repr__(self):
return str(self)
################## ##################
...@@ -512,7 +517,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -512,7 +517,7 @@ class NavigatorOptimizer(Optimizer):
self.ignore_newtrees = ignore_newtrees self.ignore_newtrees = ignore_newtrees
self.failure_callback = failure_callback self.failure_callback = failure_callback
def attach_updater(self, env, importer, pruner): def attach_updater(self, env, importer, pruner, chin = None):
if self.ignore_newtrees: if self.ignore_newtrees:
importer = None importer = None
...@@ -526,6 +531,10 @@ class NavigatorOptimizer(Optimizer): ...@@ -526,6 +531,10 @@ class NavigatorOptimizer(Optimizer):
if pruner is not None: if pruner is not None:
def on_prune(self, env, node): def on_prune(self, env, node):
pruner(node) pruner(node)
if chin is not None:
def on_change_input(self, env, node, i, r, new_r):
chin(node, i, r, new_r)
u = Updater() u = Updater()
env.extend(u) env.extend(u)
return u return u
...@@ -728,9 +737,10 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -728,9 +737,10 @@ class EquilibriumOptimizer(NavigatorOptimizer):
for node in nodes: for node in nodes:
candidates = filter(node, depth) candidates = filter(node, depth)
depth += 1 depth += 1
_nodes = nodes
nodes = reduce(list.__iadd__, nodes = reduce(list.__iadd__,
[reduce(list.__iadd__, [reduce(list.__iadd__,
[[n for n, i in out.clients] for out in node.outputs], [[n for n, i in out.clients if not isinstance(n, str)] for out in node.outputs],
[]) for node in nodes], []) for node in nodes],
[]) [])
candidates = tracks candidates = tracks
...@@ -746,12 +756,16 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -746,12 +756,16 @@ class EquilibriumOptimizer(NavigatorOptimizer):
runs = None runs = None
def importer(node): def importer(node):
#print 'IMPORTING', node
self.backtrack(node, tasks) self.backtrack(node, tasks)
def pruner(node): def pruner(node):
try: try:
del tasks[node] del tasks[node]
except KeyError: except KeyError:
pass pass
def chin(node, i, r, new_r):
if new_r.owner and not r.clients:
self.backtrack(new_r.owner, tasks)
# # == NOT IDEAL == # # # == NOT IDEAL == #
# for node in env.nodes: # for node in env.nodes:
...@@ -761,7 +775,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -761,7 +775,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
for node in env.nodes: for node in env.nodes:
tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op)) tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))
u = self.attach_updater(env, importer, pruner) u = self.attach_updater(env, importer, pruner, chin)
while tasks: while tasks:
for node in tasks.iterkeys(): for node in tasks.iterkeys():
todo = tasks.pop(node) todo = tasks.pop(node)
...@@ -806,6 +820,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -806,6 +820,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def keep_going(exc, nav, repl_pairs): def keep_going(exc, nav, repl_pairs):
"""WRITEME""" """WRITEME"""
print exc, nav, repl_pairs
pass pass
......
...@@ -44,16 +44,22 @@ class MyOp(Op): ...@@ -44,16 +44,22 @@ class MyOp(Op):
def __str__(self): def __str__(self):
return self.name return self.name
def __repr__(self):
return self.name
def __eq__(self, other): def __eq__(self, other):
return self is other or isinstance(other, MyOp) and self.x is not None and self.x == other.x return self is other or isinstance(other, MyOp) and self.x is not None and self.x == other.x
def __hash__(self): def __hash__(self):
return self.x if self.x is not None else id(self) return self.x if self.x is not None else id(self)
op1 = MyOp('Op1') op1 = MyOp('Op1')
op2 = MyOp('Op2') op2 = MyOp('Op2')
op3 = MyOp('Op3') op3 = MyOp('Op3')
op4 = MyOp('Op4') op4 = MyOp('Op4')
op5 = MyOp('Op5')
op6 = MyOp('Op6')
op_d = MyOp('OpD', {0: [0]}) op_d = MyOp('OpD', {0: [0]})
op_y = MyOp('OpY', x = 1) op_y = MyOp('OpY', x = 1)
...@@ -349,6 +355,22 @@ class TestEquilibrium(object): ...@@ -349,6 +355,22 @@ class TestEquilibrium(object):
print g print g
assert str(g) == '[Op2(x, y)]' assert str(g) == '[Op2(x, y)]'
def test_2(self):
x, y, z = map(MyResult, 'xyz')
e = op1(op1(op3(x, y)))
g = Env([x, y, z], [e])
print g
opt = EquilibriumOptimizer(
[PatternSub((op1, (op2, 'x', 'y')), (op4, 'x', 'y')),
PatternSub((op3, 'x', 'y'), (op4, 'x', 'y')),
PatternSub((op4, 'x', 'y'), (op5, 'x', 'y')),
PatternSub((op5, 'x', 'y'), (op6, 'x', 'y')),
PatternSub((op6, 'x', 'y'), (op2, 'x', 'y'))
],
max_use_ratio = 10)
opt.optimize(g)
assert str(g) == '[Op2(x, y)]'
def test_low_use_ratio(self): def test_low_use_ratio(self):
x, y, z = map(MyResult, 'xyz') x, y, z = map(MyResult, 'xyz')
e = op3(op4(x, y)) e = op3(op4(x, y))
......
...@@ -189,8 +189,8 @@ class DimShufflePrinter: ...@@ -189,8 +189,8 @@ class DimShufflePrinter:
def __p(self, new_order, pstate, r): def __p(self, new_order, pstate, r):
if new_order != () and new_order[0] == 'x': if new_order != () and new_order[0] == 'x':
# return "%s" % self.__p(new_order[1:], pstate, r) return "%s" % self.__p(new_order[1:], pstate, r)
return "[%s]" % self.__p(new_order[1:], pstate, r) # return "[%s]" % self.__p(new_order[1:], pstate, r)
if list(new_order) == range(r.type.ndim): if list(new_order) == range(r.type.ndim):
return pstate.pprinter.process(r) return pstate.pprinter.process(r)
if list(new_order) == list(reversed(range(r.type.ndim))): if list(new_order) == list(reversed(range(r.type.ndim))):
......
...@@ -92,7 +92,6 @@ insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer) ...@@ -92,7 +92,6 @@ insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer)
inplace_optimizer = gof.InplaceOptimizer( inplace_optimizer = gof.InplaceOptimizer(
gof.SeqOptimizer(out2in(gemm_pattern_1), gof.SeqOptimizer(out2in(gemm_pattern_1),
#out2in(dot_to_gemm),
insert_inplace_optimizer, insert_inplace_optimizer,
failure_callback = gof.keep_going)) failure_callback = gof.keep_going))
compile.optdb.register('inplace', inplace_optimizer, 99, 'fast_run') compile.optdb.register('inplace', inplace_optimizer, 99, 'fast_run')
...@@ -491,7 +490,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -491,7 +490,7 @@ class Canonizer(gof.LocalOptimizer):
ct = [self.calculate(numct, denumct, aslist = False)] ct = [self.calculate(numct, denumct, aslist = False)]
# if len(ct) and ncc == 1 and dcc == 0: # if len(ct) and ncc == 1 and dcc == 0:
# return orig_num, orig_denum # return orig_num, orig_denum
if orig_num and N.all(ct == self.get_constant(orig_num[0])): if orig_num and len(numct) == 1 and ct and N.all(ct == self.get_constant(orig_num[0])):
return orig_num, orig_denum return orig_num, orig_denum
return ct + num, denum return ct + num, denum
...@@ -710,6 +709,20 @@ register_canonicalize(local_greedy_distributor) ...@@ -710,6 +709,20 @@ register_canonicalize(local_greedy_distributor)
@gof.local_optimizer([None])
def constant_folding(node):
for input in node.inputs:
if not isinstance(input, gof.Constant):
return False
storage = [[None] for output in node.outputs]
node.op.perform(node, [x.data for x in node.inputs], storage)
return [gof.Constant(output.type, s[0]) for s, output in zip(storage, node.outputs)]
register_canonicalize(constant_folding)
# def _math_optimizer(): # def _math_optimizer():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论