提交 8910b8cc authored 作者: bergstrj@iro.umontreal.ca's avatar bergstrj@iro.umontreal.ca

merged again

......@@ -59,6 +59,8 @@ class Function:
features - features to add to the env
optimizer - an optimizer to apply to the copied graph, before linking
linker_cls - a callable that takes an env and returns a Linker
profiler - a Profiler for the produced function (only valid if the
linker_cls's make_function takes a profiler argument)
unpack_single - unpack return value lists of length 1
- see Linker.make_function
keep_locals - add the local variables from __init__ to the class
......@@ -102,16 +104,35 @@ class Function:
self.__dict__.update(locals())
if profiler is None:
self.fn = linker.make_function(inplace=True,
self.fn = linker.make_function(inplace=True,
unpack_single=unpack_single)
else:
self.fn = linker.make_function(inplace=True,
self.fn = linker.make_function(inplace=True,
unpack_single=unpack_single,
profiler=profiler)
self.inputs = inputs
self.outputs = outputs
self.features = features
self.optimizer = optimizer
self.linker_cls = linker_cls
self.profiler = profiler
self.unpack_single = unpack_single
self.except_unreachable_input = except_unreachable_input
self.keep_locals = keep_locals
def __call__(self, *args):
return self.fn(*args)
def __copy__(self):
return Function(self.inputs, self.outputs,
features = self.features,
optimizer = self.optimizer,
linker_cls = self.linker_cls,
profiler = self.profiler,
unpack_single = self.unpack_single,
except_unreachable_input = self.except_unreachable_input,
keep_locals = self.keep_locals)
def eval_outputs(outputs,
features = [],
......
......@@ -157,7 +157,7 @@ class Broadcast(Op, Destroyer):
return Broadcast(self.scalar_opclass, new_inputs, self.inplace_pattern)
def desc(self):
return (self.__class__, self.scalar_opclass, tuple(self.inplace_pattern.items()))
return (Broadcast, self.scalar_opclass, tuple(self.inplace_pattern.items()))
def destroy_map(self):
ret = {}
......@@ -311,6 +311,9 @@ def make_broadcast(scalar_opclass, inplace_pattern = {}, name = None):
Broadcast.__init__(self, scalar_opclass, inputs, inplace_pattern)
def clone_with_new_inputs(self, *new_inputs):
return New(*new_inputs)
@classmethod
def desc(cls):
return (Broadcast, scalar_opclass, tuple(inplace_pattern.items()))
if name is not None:
New.__name__ = name
else:
......
......@@ -52,6 +52,14 @@ class OpD(MyOp, Destroyer):
def destroyed_inputs(self):
return [self.inputs[0]]
class OpZ(MyOp):
def __init__(self, x, y, a, b):
self.a = a
self.b = b
MyOp.__init__(self, x, y)
def desc(self):
return (self.a, self.b)
import modes
modes.make_constructors(globals())
......@@ -193,8 +201,65 @@ class _test_PatternOptimizer(unittest.TestCase):
'constraint': constraint}),
(Op3, '1')).optimize(g)
assert str(g) == "[Op4(Op3(Op2(x, y)), Op1(Op1(x, y)))]"
def test_match_same(self):
x, y, z = inputs()
e = op1(x, x)
g = env([x, y, z], [e])
PatternOptimizer((Op1, 'x', 'y'),
(Op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op3(x, x)]"
def test_match_same_illegal(self):
x, y, z = inputs()
e = op2(op1(x, x), op1(x, y))
g = env([x, y, z], [e])
def constraint(env, r):
# Only replacing if the input is an instance of Op2
return r.owner.inputs[0] is not r.owner.inputs[1]
PatternOptimizer({'pattern': (Op1, 'x', 'y'),
'constraint': constraint},
(Op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op2(Op1(x, x), Op3(x, y))]"
def test_multi(self):
x, y, z = inputs()
e0 = op1(x, y)
e = op3(op4(e0), e0)
g = env([x, y, z], [e])
PatternOptimizer((Op4, (Op1, 'x', 'y')),
(Op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op3(Op4(*1 -> Op1(x, y)), *1)]"
def test_multi_ingraph(self):
x, y, z = inputs()
e0 = op1(x, y)
e = op4(e0, e0)
g = env([x, y, z], [e])
PatternOptimizer((Op4, (Op1, 'x', 'y'), (Op1, 'x', 'y')),
(Op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op3(x, y)]"
class _test_PatternDescOptimizer(unittest.TestCase):
def test_replace_output(self):
# replacing the whole graph
x, y, z = inputs()
e = op1(op2(x, y), z)
g = env([x, y, z], [e])
PatternDescOptimizer((Op1, (Op2, '1', '2'), '3'),
(Op4, '3', '2')).optimize(g)
assert str(g) == "[Op4(z, y)]"
def test_desc(self):
x, y, z = inputs()
e = op1(op_z(x, y, 37, 88), op2(op_z(y, z, 1, 7)))
g = env([x, y, z], [e])
PatternDescOptimizer(((37, 88), '1', '2'),
(Op3, '2', '1')).optimize(g)
assert str(g) == "[Op1(Op3(y, x), Op2(OpZ(y, z)))]"
class _test_OpSubOptimizer(unittest.TestCase):
......
......@@ -90,7 +90,7 @@ class Env(graph.Graph):
# e.g. z for inputs=(x, y) and outputs=(x + (y - z),)
# We initialize them to the set of outputs; if an output depends on an input,
# it will be removed from the set of orphans.
self._orphans = set(outputs)
self._orphans = set(outputs).difference(inputs)
for feature_class in uniq_features(features):
self.add_feature(feature_class, False)
......
......@@ -284,7 +284,7 @@ class PatternOptimizer(OpSpecificOptimizer):
"""
def match(pattern, expr, u, first = False):
if isinstance(pattern, (list, tuple)):
if not issubclass(expr.owner.__class__, pattern[0]) or (self.allow_multiple_clients and not first and env.nclients(expr.owner) > 1):
if not issubclass(expr.owner.__class__, pattern[0]) or (not self.allow_multiple_clients and not first and env.nclients(expr) > 1):
return False
if len(pattern) - 1 != len(expr.owner.inputs):
return False
......@@ -338,16 +338,127 @@ class PatternOptimizer(OpSpecificOptimizer):
self.failure_callback(op.out, new, e)
pass
def __str__(self):
def pattern_to_str(pattern):
if isinstance(pattern, (list, tuple)):
return "%s(%s)" % (pattern[0].__name__, ", ".join([pattern_to_str(p) for p in pattern[1:]]))
elif isinstance(pattern, dict):
return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern['constraint']))
else:
return str(pattern)
return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern))
class PatternDescOptimizer(LocalOptimizer):
"""
"""
__env_require__ = toolbox.DescFinder
def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False, failure_callback = None):
"""
Sets in_pattern for replacement by out_pattern.
self.opclass is set to in_pattern[0] to accelerate the search.
"""
self.in_pattern = in_pattern
self.out_pattern = out_pattern
if isinstance(in_pattern, (list, tuple)):
self.desc = self.in_pattern[0]
elif isinstance(in_pattern, dict):
self.desc = self.in_pattern['pattern'][0]
else:
raise TypeError("The pattern to search for must start with a specific desc.")
self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n"
self.failure_callback = failure_callback
self.allow_multiple_clients = allow_multiple_clients
def candidates(self, env):
"""
Returns all instances of self.desc
"""
return env.get_from_desc(self.desc)
def apply_on_op(self, env, op):
"""
Checks if the graph from op corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement.
If self.failure_callback is not None, if there is a match but a
replacement fails to occur, the callback will be called with
arguments (results_to_replace, replacement, exception).
If self.allow_multiple_clients is False, he pattern matching will fail
if one of the subpatterns has more than one client.
"""
def match(pattern, expr, u, first = False):
if isinstance(pattern, (list, tuple)):
if not expr.owner or not expr.owner.desc() == pattern[0] or (self.allow_multiple_clients and not first and env.nclients(expr.owner) > 1):
return False
if len(pattern) - 1 != len(expr.owner.inputs):
return False
for p, v in zip(pattern[1:], expr.owner.inputs):
u = match(p, v, u)
if not u:
return False
elif isinstance(pattern, dict):
try:
real_pattern = pattern['pattern']
constraint = pattern['constraint']
except KeyError:
raise KeyError("Malformed pattern: %s (expected keys pattern and constraint)" % pattern)
if constraint(env, expr):
return match(real_pattern, expr, u, False)
elif isinstance(pattern, str):
v = unify.Var(pattern)
if u[v] is not v and u[v] is not expr:
return False
else:
u = u.merge(expr, v)
elif isinstance(pattern, ResultBase) \
and getattr(pattern, 'constant', False) \
and isinstance(expr, ResultBase) \
and getattr(expr, 'constant', False) \
and pattern.desc() == expr.desc():
return u
else:
return False
return u
def build(pattern, u):
if isinstance(pattern, (list, tuple)):
args = [build(p, u) for p in pattern[1:]]
return pattern[0](*args).out
elif isinstance(pattern, str):
return u[unify.Var(pattern)]
else:
return pattern
u = match(self.in_pattern, op.out, unify.Unification(), True)
if u:
try:
# note: only replaces the default 'out' port if it exists
p = self.out_pattern
new = 'unassigned'
new = build(p, u)
env.replace(op.out, new)
except Exception, e:
if self.failure_callback is not None:
self.failure_callback(op.out, new, e)
pass
def __str__(self):
def pattern_to_str(pattern):
if isinstance(pattern, (list, tuple)):
return "%s(%s)" % (pattern[0], ", ".join([pattern_to_str(p) for p in pattern[1:]]))
elif isinstance(pattern, dict):
return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern['constraint']))
else:
return str(pattern)
return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern))
class ConstantFinder(Optimizer):
"""
Sets as constant every orphan that is not destroyed
......
......@@ -5,6 +5,7 @@ import utils
__all__ = ['EquivTool',
'InstanceFinder',
'DescFinder',
'PrintListener',
]
......@@ -89,6 +90,36 @@ class InstanceFinder(Listener, Tool, dict):
class DescFinder(Listener, Tool, dict):
def __init__(self, env):
self.env = env
def on_import(self, op):
self.setdefault(op.desc(), set()).add(op)
def on_prune(self, op):
desc = op.desc()
self[desc].remove(op)
if not self[desc]:
del self[desc]
def __query__(self, desc):
all = [x for x in self.get(desc, [])]
shuffle(all) # this helps for debugging because the order of the replacements will vary
while all:
next = all.pop()
if next in self.env.ops():
yield next
def query(self, desc):
return self.__query__(desc)
def publish(self):
self.env.get_from_desc = self.query
class PrintListener(Listener):
def __init__(self, env, active = True):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论