提交 cca46fc2 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

wrote the NavigatorOptimizer interface to apply local optimizations to nodes…

wrote the NavigatorOptimizer interface to apply local optimizations to nodes following various traversal orders - fixes #79 with TopoOptimizer
上级 b801ba52
from cc import CLinker, OpWiseCLinker, DualLinker from cc import \
from env import InconsistencyError, Env CLinker, OpWiseCLinker, DualLinker
from ext import DestroyHandler, view_roots
from graph import Apply, Result, Constant, Value from env import \
from link import Linker, LocalLinker, PerformLinker, MetaLinker, Profiler InconsistencyError, Env
from op import Op, Macro
from opt import Optimizer, DummyOpt, SeqOptimizer, LocalOptimizer, OpSpecificOptimizer, OpSubOptimizer, OpRemover, PatternOptimizer, MergeOptimizer, MergeOptMerge from ext import \
from toolbox import Bookkeeper, History, Validator, ReplaceValidate, NodeFinder, PrintListener DestroyHandler, view_roots
from type import Type, Generic, generic
from utils import object2, AbstractFunctionError from graph import \
Apply, Result, Constant, Value
from link import \
Linker, LocalLinker, PerformLinker, MetaLinker, Profiler
from op import \
Op, Macro
from opt import \
Optimizer, SeqOptimizer, \
MergeOptimizer, MergeOptMerge, \
LocalOptimizer, LocalOptGroup, LocalOpKeyOptGroup, \
ExpandMacro, OpSub, OpRemove, PatternSub, \
NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer, \
expand_macros
from toolbox import \
Bookkeeper, History, Validator, ReplaceValidate, NodeFinder, PrintListener
from type import \
Type, Generic, generic
from utils import \
object2, AbstractFunctionError
...@@ -10,7 +10,8 @@ from toolbox import * ...@@ -10,7 +10,8 @@ from toolbox import *
def as_result(x): def as_result(x):
assert isinstance(x, Result) if not isinstance(x, Result):
raise TypeError("not a Result", x)
return x return x
...@@ -69,6 +70,9 @@ def inputs(): ...@@ -69,6 +70,9 @@ def inputs():
return x, y, z return x, y, z
PatternOptimizer = lambda p1, p2, ign=False: OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign)
TopoPatternOptimizer = lambda p1, p2, ign=True: TopoOptimizer(PatternSub(p1, p2), ignore_newtrees=ign)
class _test_PatternOptimizer(unittest.TestCase): class _test_PatternOptimizer(unittest.TestCase):
def test_replace_output(self): def test_replace_output(self):
...@@ -116,13 +120,14 @@ class _test_PatternOptimizer(unittest.TestCase): ...@@ -116,13 +120,14 @@ class _test_PatternOptimizer(unittest.TestCase):
assert str(g) == "[Op1(Op1(y, x), z)]" assert str(g) == "[Op1(Op1(y, x), z)]"
def test_no_recurse(self): def test_no_recurse(self):
# if the out pattern is an acceptable in pattern, # if the out pattern is an acceptable in pattern
# and that the ignore_newtrees flag is True,
# it should do the replacement and stop # it should do the replacement and stop
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), z) e = op1(op2(x, y), z)
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
PatternOptimizer((op2, '1', '2'), PatternOptimizer((op2, '1', '2'),
(op2, '2', '1')).optimize(g) (op2, '2', '1'), ign=True).optimize(g)
assert str(g) == "[Op1(Op2(y, x), z)]" assert str(g) == "[Op1(Op2(y, x), z)]"
def test_multiple(self): def test_multiple(self):
...@@ -157,20 +162,19 @@ class _test_PatternOptimizer(unittest.TestCase): ...@@ -157,20 +162,19 @@ class _test_PatternOptimizer(unittest.TestCase):
e = op1(op1(op1(x))) e = op1(op1(op1(x)))
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
PatternOptimizer((op1, '1'), PatternOptimizer((op1, '1'),
(op2, (op1, '1'))).optimize(g) (op2, (op1, '1')), ign=True).optimize(g)
assert str(g) == "[Op2(Op1(Op2(Op1(Op2(Op1(x))))))]" assert str(g) == "[Op2(Op1(Op2(Op1(Op2(Op1(x))))))]"
# def test_ambiguous(self): def test_ambiguous(self):
# # this test is known to fail most of the time # this test should always work with TopoOptimizer and the
# # the reason is that PatternOptimizer doesn't go through # ignore_newtrees flag set to False. Behavior with ignore_newtrees
# # the ops in topological order. The order is random and # = True or with other NavigatorOptimizers may differ.
# # it does not visit ops that it creates. x, y, z = inputs()
# x, y, z = inputs() e = op1(op1(op1(op1(op1(x)))))
# e = op1(op1(op1(op1(op1(x))))) g = Env([x, y, z], [e])
# g = Env([x, y, z], [e]) TopoPatternOptimizer((op1, (op1, '1')),
# PatternOptimizer((op1, (op1, '1')), (op1, '1'), ign=False).optimize(g)
# (op1, '1')).optimize(g) assert str(g) == "[Op1(x)]"
# assert str(g) == "[Op1(x)]"
def test_constant_unification(self): def test_constant_unification(self):
x = Constant(MyType(), 2, name = 'x') x = Constant(MyType(), 2, name = 'x')
...@@ -186,7 +190,7 @@ class _test_PatternOptimizer(unittest.TestCase): ...@@ -186,7 +190,7 @@ class _test_PatternOptimizer(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e = op4(op1(op2(x, y)), op1(op1(x, y))) e = op4(op1(op2(x, y)), op1(op1(x, y)))
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
def constraint(env, r): def constraint(r):
# Only replacing if the input is an instance of Op2 # Only replacing if the input is an instance of Op2
return r.owner.op == op2 return r.owner.op == op2
PatternOptimizer((op1, {'pattern': '1', PatternOptimizer((op1, {'pattern': '1',
...@@ -206,7 +210,7 @@ class _test_PatternOptimizer(unittest.TestCase): ...@@ -206,7 +210,7 @@ class _test_PatternOptimizer(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e = op2(op1(x, x), op1(x, y)) e = op2(op1(x, x), op1(x, y))
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
def constraint(env, r): def constraint(r):
# Only replacing if the input is an instance of Op2 # Only replacing if the input is an instance of Op2
return r.owner.inputs[0] is not r.owner.inputs[1] return r.owner.inputs[0] is not r.owner.inputs[1]
PatternOptimizer({'pattern': (op1, 'x', 'y'), PatternOptimizer({'pattern': (op1, 'x', 'y'),
...@@ -263,6 +267,9 @@ class _test_PatternOptimizer(unittest.TestCase): ...@@ -263,6 +267,9 @@ class _test_PatternOptimizer(unittest.TestCase):
# assert str(g) == "[Op1(Op3(y, x), Op2(OpZ(y, z)))]" # assert str(g) == "[Op1(Op3(y, x), Op2(OpZ(y, z)))]"
OpSubOptimizer = lambda op1, op2: TopoOptimizer(OpSub(op1, op2))
OpSubOptimizer = lambda op1, op2: OpKeyOptimizer(OpSub(op1, op2))
class _test_OpSubOptimizer(unittest.TestCase): class _test_OpSubOptimizer(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
...@@ -413,7 +420,20 @@ class _test_MergeOptimizer(unittest.TestCase): ...@@ -413,7 +420,20 @@ class _test_MergeOptimizer(unittest.TestCase):
# assert not getattr(x, 'constant', False) and z.constant # assert not getattr(x, 'constant', False) and z.constant
# MergeOptimizer().optimize(g) # MergeOptimizer().optimize(g)
reenter = Exception("Re-Entered")
class LoopyMacro(Macro):
def __init__(self):
self.counter = 0
def make_node(self, x, y):
return Apply(self, [x, y], [MyType()()])
def expand(self, node):
x, y = node.inputs
if self.counter > 0:
raise reenter
self.counter += 1
return [self(y, x)]
def __str__(self):
return "loopy_macro"
class _test_ExpandMacro(unittest.TestCase): class _test_ExpandMacro(unittest.TestCase):
...@@ -423,26 +443,31 @@ class _test_ExpandMacro(unittest.TestCase): ...@@ -423,26 +443,31 @@ class _test_ExpandMacro(unittest.TestCase):
return Apply(self, [x, y], [MyType()()]) return Apply(self, [x, y], [MyType()()])
def expand(self, node): def expand(self, node):
return [op1(y, x)] return [op1(y, x)]
def __str__(self):
return "macro"
x, y, z = inputs() x, y, z = inputs()
e = Macro1()(x, y) e = Macro1()(x, y)
g = Env([x, y], [e]) g = Env([x, y], [e])
print g
expand_macros.optimize(g) expand_macros.optimize(g)
print g assert str(g) == "[Op1(y, x)]"
def test_loopy(self): def test_loopy_1(self):
class Macro1(Macro):
def make_node(self, x, y):
return Apply(self, [x, y], [MyType()()])
def expand(self, node):
return [Macro1()(y, x)]
x, y, z = inputs() x, y, z = inputs()
e = Macro1()(x, y) e = LoopyMacro()(x, y)
g = Env([x, y], [e])
TopoOptimizer(ExpandMacro(), ignore_newtrees = True).optimize(g)
assert str(g) == "[loopy_macro(y, x)]"
def test_loopy_2(self):
x, y, z = inputs()
e = LoopyMacro()(x, y)
g = Env([x, y], [e]) g = Env([x, y], [e])
print g try:
#expand_macros.optimize(g) TopoOptimizer(ExpandMacro(), ignore_newtrees = False).optimize(g)
TopDownOptimizer(ExpandMacro(), ignore_newtrees = True).optimize(g) self.fail("should not arrive here")
print g except Exception, e:
if e is not reenter:
raise
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from copy import copy from copy import copy
import graph import graph
import utils import utils
import toolbox
class InconsistencyError(Exception): class InconsistencyError(Exception):
...@@ -273,14 +274,13 @@ class Env(utils.object2): ...@@ -273,14 +274,13 @@ class Env(utils.object2):
""" """
if feature in self._features: if feature in self._features:
return # the feature is already present return # the feature is already present
self._features.append(feature)
attach = getattr(feature, 'on_attach', None) attach = getattr(feature, 'on_attach', None)
if attach is not None: if attach is not None:
try: try:
attach(self) attach(self)
except: except toolbox.AlreadyThere:
self._features.pop() return
raise self._features.append(feature)
def remove_feature(self, feature): def remove_feature(self, feature):
""" """
......
...@@ -72,7 +72,7 @@ class DestroyHandlerHelper(toolbox.Bookkeeper): ...@@ -72,7 +72,7 @@ class DestroyHandlerHelper(toolbox.Bookkeeper):
raise Exception("A DestroyHandler instance can only serve one Env.") raise Exception("A DestroyHandler instance can only serve one Env.")
for attr in ('destroyers', 'destroy_handler'): for attr in ('destroyers', 'destroy_handler'):
if hasattr(env, attr): if hasattr(env, attr):
raise Exception("DestroyHandler feature is already present or in conflict with another plugin.") raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.")
def __destroyers(r): def __destroyers(r):
ret = self.destroyers.get(r, {}) ret = self.destroyers.get(r, {})
......
差异被折叠。
...@@ -3,6 +3,10 @@ from functools import partial ...@@ -3,6 +3,10 @@ from functools import partial
import graph import graph
class AlreadyThere(Exception):
pass
class Bookkeeper: class Bookkeeper:
def on_attach(self, env): def on_attach(self, env):
...@@ -21,7 +25,7 @@ class History: ...@@ -21,7 +25,7 @@ class History:
def on_attach(self, env): def on_attach(self, env):
if hasattr(env, 'checkpoint') or hasattr(env, 'revert'): if hasattr(env, 'checkpoint') or hasattr(env, 'revert'):
raise Exception("History feature is already present or in conflict with another plugin.") raise AlreadyThere("History feature is already present or in conflict with another plugin.")
self.history[env] = [] self.history[env] = []
env.checkpoint = lambda: len(self.history[env]) env.checkpoint = lambda: len(self.history[env])
env.revert = partial(self.revert, env) env.revert = partial(self.revert, env)
...@@ -55,7 +59,7 @@ class Validator: ...@@ -55,7 +59,7 @@ class Validator:
def on_attach(self, env): def on_attach(self, env):
if hasattr(env, 'validate'): if hasattr(env, 'validate'):
raise Exception("Validator feature is already present or in conflict with another plugin.") raise AlreadyThere("Validator feature is already present or in conflict with another plugin.")
env.validate = lambda: env.execute_callbacks('validate') env.validate = lambda: env.execute_callbacks('validate')
def consistent(): def consistent():
try: try:
...@@ -77,7 +81,7 @@ class ReplaceValidate(History, Validator): ...@@ -77,7 +81,7 @@ class ReplaceValidate(History, Validator):
Validator.on_attach(self, env) Validator.on_attach(self, env)
for attr in ('replace_validate', 'replace_all_validate'): for attr in ('replace_validate', 'replace_all_validate'):
if hasattr(env, attr): if hasattr(env, attr):
raise Exception("ReplaceValidate feature is already present or in conflict with another plugin.") raise AlreadyThere("ReplaceValidate feature is already present or in conflict with another plugin.")
env.replace_validate = partial(self.replace_validate, env) env.replace_validate = partial(self.replace_validate, env)
env.replace_all_validate = partial(self.replace_all_validate, env) env.replace_all_validate = partial(self.replace_all_validate, env)
...@@ -110,7 +114,7 @@ class NodeFinder(dict, Bookkeeper): ...@@ -110,7 +114,7 @@ class NodeFinder(dict, Bookkeeper):
if self.env is not None: if self.env is not None:
raise Exception("A NodeFinder instance can only serve one Env.") raise Exception("A NodeFinder instance can only serve one Env.")
if hasattr(env, 'get_nodes'): if hasattr(env, 'get_nodes'):
raise Exception("NodeFinder is already present or in conflict with another plugin.") raise AlreadyThere("NodeFinder is already present or in conflict with another plugin.")
self.env = env self.env = env
env.get_nodes = partial(self.query, env) env.get_nodes = partial(self.query, env)
Bookkeeper.on_attach(self, env) Bookkeeper.on_attach(self, env)
...@@ -143,10 +147,7 @@ class NodeFinder(dict, Bookkeeper): ...@@ -143,10 +147,7 @@ class NodeFinder(dict, Bookkeeper):
except TypeError: except TypeError:
raise TypeError("%s in unhashable and cannot be queried by the optimizer" % op) raise TypeError("%s in unhashable and cannot be queried by the optimizer" % op)
all = list(all) all = list(all)
while all: return all
next = all.pop()
if next in env.nodes:
yield next
class PrintListener(object): class PrintListener(object):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论