提交 d9e50f16 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added DescFinder and PatternDescOptimizer

上级 a8b455fc
......@@ -52,6 +52,14 @@ class OpD(MyOp, Destroyer):
def destroyed_inputs(self):
return [self.inputs[0]]
class OpZ(MyOp):
def __init__(self, x, y, a, b):
self.a = a
self.b = b
MyOp.__init__(self, x, y)
def desc(self):
return (self.a, self.b)
import modes
modes.make_constructors(globals())
......@@ -195,6 +203,25 @@ class _test_PatternOptimizer(unittest.TestCase):
assert str(g) == "[Op4(Op3(Op2(x, y)), Op1(Op1(x, y)))]"
class _test_PatternDescOptimizer(unittest.TestCase):
def test_replace_output(self):
# replacing the whole graph
x, y, z = inputs()
e = op1(op2(x, y), z)
g = env([x, y, z], [e])
PatternDescOptimizer((Op1, (Op2, '1', '2'), '3'),
(Op4, '3', '2')).optimize(g)
assert str(g) == "[Op4(z, y)]"
def test_desc(self):
x, y, z = inputs()
e = op1(op_z(x, y, 37, 88), op2(op_z(y, z, 1, 7)))
g = env([x, y, z], [e])
PatternDescOptimizer(((37, 88), '1', '2'),
(Op3, '2', '1')).optimize(g)
assert str(g) == "[Op1(Op3(y, x), Op2(OpZ(y, z)))]"
class _test_OpSubOptimizer(unittest.TestCase):
......
......@@ -338,16 +338,127 @@ class PatternOptimizer(OpSpecificOptimizer):
self.failure_callback(op.out, new, e)
pass
def __str__(self):
def pattern_to_str(pattern):
if isinstance(pattern, (list, tuple)):
return "%s(%s)" % (pattern[0].__name__, ", ".join([pattern_to_str(p) for p in pattern[1:]]))
elif isinstance(pattern, dict):
return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern['constraint']))
else:
return str(pattern)
return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern))
class PatternDescOptimizer(LocalOptimizer):
"""
"""
__env_require__ = toolbox.DescFinder
def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False, failure_callback = None):
"""
Sets in_pattern for replacement by out_pattern.
self.opclass is set to in_pattern[0] to accelerate the search.
"""
self.in_pattern = in_pattern
self.out_pattern = out_pattern
if isinstance(in_pattern, (list, tuple)):
self.desc = self.in_pattern[0]
elif isinstance(in_pattern, dict):
self.desc = self.in_pattern['pattern'][0]
else:
raise TypeError("The pattern to search for must start with a specific desc.")
self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n"
self.failure_callback = failure_callback
self.allow_multiple_clients = allow_multiple_clients
def candidates(self, env):
"""
Returns all instances of self.desc
"""
return env.get_from_desc(self.desc)
def apply_on_op(self, env, op):
"""
Checks if the graph from op corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement.
If self.failure_callback is not None, if there is a match but a
replacement fails to occur, the callback will be called with
arguments (results_to_replace, replacement, exception).
If self.allow_multiple_clients is False, he pattern matching will fail
if one of the subpatterns has more than one client.
"""
def match(pattern, expr, u, first = False):
if isinstance(pattern, (list, tuple)):
if not expr.owner.desc() == pattern[0] or (self.allow_multiple_clients and not first and env.nclients(expr.owner) > 1):
return False
if len(pattern) - 1 != len(expr.owner.inputs):
return False
for p, v in zip(pattern[1:], expr.owner.inputs):
u = match(p, v, u)
if not u:
return False
elif isinstance(pattern, dict):
try:
real_pattern = pattern['pattern']
constraint = pattern['constraint']
except KeyError:
raise KeyError("Malformed pattern: %s (expected keys pattern and constraint)" % pattern)
if constraint(env, expr):
return match(real_pattern, expr, u, False)
elif isinstance(pattern, str):
v = unify.Var(pattern)
if u[v] is not v and u[v] is not expr:
return False
else:
u = u.merge(expr, v)
elif isinstance(pattern, ResultBase) \
and getattr(pattern, 'constant', False) \
and isinstance(expr, ResultBase) \
and getattr(expr, 'constant', False) \
and pattern.desc() == expr.desc():
return u
else:
return False
return u
def build(pattern, u):
if isinstance(pattern, (list, tuple)):
args = [build(p, u) for p in pattern[1:]]
return pattern[0](*args).out
elif isinstance(pattern, str):
return u[unify.Var(pattern)]
else:
return pattern
u = match(self.in_pattern, op.out, unify.Unification(), True)
if u:
try:
# note: only replaces the default 'out' port if it exists
p = self.out_pattern
new = 'unassigned'
new = build(p, u)
env.replace(op.out, new)
except Exception, e:
if self.failure_callback is not None:
self.failure_callback(op.out, new, e)
pass
def __str__(self):
def pattern_to_str(pattern):
if isinstance(pattern, (list, tuple)):
return "%s(%s)" % (pattern[0], ", ".join([pattern_to_str(p) for p in pattern[1:]]))
elif isinstance(pattern, dict):
return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern['constraint']))
else:
return str(pattern)
return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern))
class ConstantFinder(Optimizer):
"""
Sets as constant every orphan that is not destroyed
......
......@@ -5,6 +5,7 @@ import utils
__all__ = ['EquivTool',
'InstanceFinder',
'DescFinder',
'PrintListener',
]
......@@ -89,6 +90,36 @@ class InstanceFinder(Listener, Tool, dict):
class DescFinder(Listener, Tool, dict):
def __init__(self, env):
self.env = env
def on_import(self, op):
self.setdefault(op.desc(), set()).add(op)
def on_prune(self, op):
desc = op.desc()
self[desc].remove(op)
if not self[desc]:
del self[desc]
def __query__(self, desc):
all = [x for x in self.get(desc, [])]
shuffle(all) # this helps for debugging because the order of the replacements will vary
while all:
next = all.pop()
if next in self.env.ops():
yield next
def query(self, desc):
return self.__query__(desc)
def publish(self):
self.env.get_from_desc = self.query
class PrintListener(Listener):
def __init__(self, env, active = True):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论