提交 cca46fc2 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

wrote the NavigatorOptimizer interface to apply local optimizations to nodes…

wrote the NavigatorOptimizer interface to apply local optimizations to nodes following various traversal orders - fixes #79 with TopoOptimizer
上级 b801ba52
from cc import CLinker, OpWiseCLinker, DualLinker
from env import InconsistencyError, Env
from ext import DestroyHandler, view_roots
from graph import Apply, Result, Constant, Value
from link import Linker, LocalLinker, PerformLinker, MetaLinker, Profiler
from op import Op, Macro
from opt import Optimizer, DummyOpt, SeqOptimizer, LocalOptimizer, OpSpecificOptimizer, OpSubOptimizer, OpRemover, PatternOptimizer, MergeOptimizer, MergeOptMerge
from toolbox import Bookkeeper, History, Validator, ReplaceValidate, NodeFinder, PrintListener
from type import Type, Generic, generic
from utils import object2, AbstractFunctionError
from cc import \
CLinker, OpWiseCLinker, DualLinker
from env import \
InconsistencyError, Env
from ext import \
DestroyHandler, view_roots
from graph import \
Apply, Result, Constant, Value
from link import \
Linker, LocalLinker, PerformLinker, MetaLinker, Profiler
from op import \
Op, Macro
from opt import \
Optimizer, SeqOptimizer, \
MergeOptimizer, MergeOptMerge, \
LocalOptimizer, LocalOptGroup, LocalOpKeyOptGroup, \
ExpandMacro, OpSub, OpRemove, PatternSub, \
NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer, \
expand_macros
from toolbox import \
Bookkeeper, History, Validator, ReplaceValidate, NodeFinder, PrintListener
from type import \
Type, Generic, generic
from utils import \
object2, AbstractFunctionError
......@@ -10,7 +10,8 @@ from toolbox import *
def as_result(x):
assert isinstance(x, Result)
if not isinstance(x, Result):
raise TypeError("not a Result", x)
return x
......@@ -69,6 +70,9 @@ def inputs():
return x, y, z
PatternOptimizer = lambda p1, p2, ign=False: OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign)
TopoPatternOptimizer = lambda p1, p2, ign=True: TopoOptimizer(PatternSub(p1, p2), ignore_newtrees=ign)
class _test_PatternOptimizer(unittest.TestCase):
def test_replace_output(self):
......@@ -116,13 +120,14 @@ class _test_PatternOptimizer(unittest.TestCase):
assert str(g) == "[Op1(Op1(y, x), z)]"
def test_no_recurse(self):
# if the out pattern is an acceptable in pattern,
# if the out pattern is an acceptable in pattern
# and that the ignore_newtrees flag is True,
# it should do the replacement and stop
x, y, z = inputs()
e = op1(op2(x, y), z)
g = Env([x, y, z], [e])
PatternOptimizer((op2, '1', '2'),
(op2, '2', '1')).optimize(g)
(op2, '2', '1'), ign=True).optimize(g)
assert str(g) == "[Op1(Op2(y, x), z)]"
def test_multiple(self):
......@@ -157,20 +162,19 @@ class _test_PatternOptimizer(unittest.TestCase):
e = op1(op1(op1(x)))
g = Env([x, y, z], [e])
PatternOptimizer((op1, '1'),
(op2, (op1, '1'))).optimize(g)
(op2, (op1, '1')), ign=True).optimize(g)
assert str(g) == "[Op2(Op1(Op2(Op1(Op2(Op1(x))))))]"
# def test_ambiguous(self):
# # this test is known to fail most of the time
# # the reason is that PatternOptimizer doesn't go through
# # the ops in topological order. The order is random and
# # it does not visit ops that it creates.
# x, y, z = inputs()
# e = op1(op1(op1(op1(op1(x)))))
# g = Env([x, y, z], [e])
# PatternOptimizer((op1, (op1, '1')),
# (op1, '1')).optimize(g)
# assert str(g) == "[Op1(x)]"
def test_ambiguous(self):
# this test should always work with TopoOptimizer and the
# ignore_newtrees flag set to False. Behavior with ignore_newtrees
# = True or with other NavigatorOptimizers may differ.
x, y, z = inputs()
e = op1(op1(op1(op1(op1(x)))))
g = Env([x, y, z], [e])
TopoPatternOptimizer((op1, (op1, '1')),
(op1, '1'), ign=False).optimize(g)
assert str(g) == "[Op1(x)]"
def test_constant_unification(self):
x = Constant(MyType(), 2, name = 'x')
......@@ -186,7 +190,7 @@ class _test_PatternOptimizer(unittest.TestCase):
x, y, z = inputs()
e = op4(op1(op2(x, y)), op1(op1(x, y)))
g = Env([x, y, z], [e])
def constraint(env, r):
def constraint(r):
# Only replacing if the input is an instance of Op2
return r.owner.op == op2
PatternOptimizer((op1, {'pattern': '1',
......@@ -206,7 +210,7 @@ class _test_PatternOptimizer(unittest.TestCase):
x, y, z = inputs()
e = op2(op1(x, x), op1(x, y))
g = Env([x, y, z], [e])
def constraint(env, r):
def constraint(r):
# Only replacing if the input is an instance of Op2
return r.owner.inputs[0] is not r.owner.inputs[1]
PatternOptimizer({'pattern': (op1, 'x', 'y'),
......@@ -263,6 +267,9 @@ class _test_PatternOptimizer(unittest.TestCase):
# assert str(g) == "[Op1(Op3(y, x), Op2(OpZ(y, z)))]"
OpSubOptimizer = lambda op1, op2: TopoOptimizer(OpSub(op1, op2))
OpSubOptimizer = lambda op1, op2: OpKeyOptimizer(OpSub(op1, op2))
class _test_OpSubOptimizer(unittest.TestCase):
def test_straightforward(self):
......@@ -413,7 +420,20 @@ class _test_MergeOptimizer(unittest.TestCase):
# assert not getattr(x, 'constant', False) and z.constant
# MergeOptimizer().optimize(g)
reenter = Exception("Re-Entered")
class LoopyMacro(Macro):
def __init__(self):
self.counter = 0
def make_node(self, x, y):
return Apply(self, [x, y], [MyType()()])
def expand(self, node):
x, y = node.inputs
if self.counter > 0:
raise reenter
self.counter += 1
return [self(y, x)]
def __str__(self):
return "loopy_macro"
class _test_ExpandMacro(unittest.TestCase):
......@@ -423,26 +443,31 @@ class _test_ExpandMacro(unittest.TestCase):
return Apply(self, [x, y], [MyType()()])
def expand(self, node):
return [op1(y, x)]
def __str__(self):
return "macro"
x, y, z = inputs()
e = Macro1()(x, y)
g = Env([x, y], [e])
print g
expand_macros.optimize(g)
print g
assert str(g) == "[Op1(y, x)]"
def test_loopy(self):
class Macro1(Macro):
def make_node(self, x, y):
return Apply(self, [x, y], [MyType()()])
def expand(self, node):
return [Macro1()(y, x)]
def test_loopy_1(self):
x, y, z = inputs()
e = Macro1()(x, y)
e = LoopyMacro()(x, y)
g = Env([x, y], [e])
TopoOptimizer(ExpandMacro(), ignore_newtrees = True).optimize(g)
assert str(g) == "[loopy_macro(y, x)]"
def test_loopy_2(self):
x, y, z = inputs()
e = LoopyMacro()(x, y)
g = Env([x, y], [e])
print g
#expand_macros.optimize(g)
TopDownOptimizer(ExpandMacro(), ignore_newtrees = True).optimize(g)
print g
try:
TopoOptimizer(ExpandMacro(), ignore_newtrees = False).optimize(g)
self.fail("should not arrive here")
except Exception, e:
if e is not reenter:
raise
......
......@@ -2,6 +2,7 @@
from copy import copy
import graph
import utils
import toolbox
class InconsistencyError(Exception):
......@@ -273,14 +274,13 @@ class Env(utils.object2):
"""
if feature in self._features:
return # the feature is already present
self._features.append(feature)
attach = getattr(feature, 'on_attach', None)
if attach is not None:
try:
attach(self)
except:
self._features.pop()
raise
except toolbox.AlreadyThere:
return
self._features.append(feature)
def remove_feature(self, feature):
"""
......
......@@ -72,7 +72,7 @@ class DestroyHandlerHelper(toolbox.Bookkeeper):
raise Exception("A DestroyHandler instance can only serve one Env.")
for attr in ('destroyers', 'destroy_handler'):
if hasattr(env, attr):
raise Exception("DestroyHandler feature is already present or in conflict with another plugin.")
raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.")
def __destroyers(r):
ret = self.destroyers.get(r, {})
......
差异被折叠。
......@@ -3,6 +3,10 @@ from functools import partial
import graph
class AlreadyThere(Exception):
pass
class Bookkeeper:
def on_attach(self, env):
......@@ -21,7 +25,7 @@ class History:
def on_attach(self, env):
if hasattr(env, 'checkpoint') or hasattr(env, 'revert'):
raise Exception("History feature is already present or in conflict with another plugin.")
raise AlreadyThere("History feature is already present or in conflict with another plugin.")
self.history[env] = []
env.checkpoint = lambda: len(self.history[env])
env.revert = partial(self.revert, env)
......@@ -55,7 +59,7 @@ class Validator:
def on_attach(self, env):
if hasattr(env, 'validate'):
raise Exception("Validator feature is already present or in conflict with another plugin.")
raise AlreadyThere("Validator feature is already present or in conflict with another plugin.")
env.validate = lambda: env.execute_callbacks('validate')
def consistent():
try:
......@@ -77,7 +81,7 @@ class ReplaceValidate(History, Validator):
Validator.on_attach(self, env)
for attr in ('replace_validate', 'replace_all_validate'):
if hasattr(env, attr):
raise Exception("ReplaceValidate feature is already present or in conflict with another plugin.")
raise AlreadyThere("ReplaceValidate feature is already present or in conflict with another plugin.")
env.replace_validate = partial(self.replace_validate, env)
env.replace_all_validate = partial(self.replace_all_validate, env)
......@@ -110,7 +114,7 @@ class NodeFinder(dict, Bookkeeper):
if self.env is not None:
raise Exception("A NodeFinder instance can only serve one Env.")
if hasattr(env, 'get_nodes'):
raise Exception("NodeFinder is already present or in conflict with another plugin.")
raise AlreadyThere("NodeFinder is already present or in conflict with another plugin.")
self.env = env
env.get_nodes = partial(self.query, env)
Bookkeeper.on_attach(self, env)
......@@ -143,10 +147,7 @@ class NodeFinder(dict, Bookkeeper):
except TypeError:
raise TypeError("%s in unhashable and cannot be queried by the optimizer" % op)
all = list(all)
while all:
next = all.pop()
if next in env.nodes:
yield next
return all
class PrintListener(object):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论