提交 4573a825 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

cleanup

上级 3d9074ac
...@@ -247,26 +247,6 @@ class _test_PatternOptimizer(unittest.TestCase): ...@@ -247,26 +247,6 @@ class _test_PatternOptimizer(unittest.TestCase):
# assert str(g) == "[Op3(x, y)]" # assert str(g) == "[Op3(x, y)]"
# class _test_PatternDescOptimizer(unittest.TestCase):
# def test_replace_output(self):
# # replacing the whole graph
# x, y, z = inputs()
# e = op1(op2(x, y), z)
# g = env([x, y, z], [e])
# PatternDescOptimizer((Op1, (Op2, '1', '2'), '3'),
# (Op4, '3', '2')).optimize(g)
# assert str(g) == "[Op4(z, y)]"
# def test_eq(self):
# x, y, z = inputs()
# e = op1(op_y(x, y, 37, 88), op2(op_y(y, z, 1, 7)))
# g = env([x, y, z], [e])
# PatternDescOptimizer((op_z, '1', '2'),
# (op3, '2', '1')).optimize(g)
# assert str(g) == "[Op1(Op3(y, x), Op2(OpZ(y, z)))]"
OpSubOptimizer = lambda op1, op2: TopoOptimizer(OpSub(op1, op2)) OpSubOptimizer = lambda op1, op2: TopoOptimizer(OpSub(op1, op2))
OpSubOptimizer = lambda op1, op2: OpKeyOptimizer(OpSub(op1, op2)) OpSubOptimizer = lambda op1, op2: OpKeyOptimizer(OpSub(op1, op2))
...@@ -384,42 +364,6 @@ class _test_MergeOptimizer(unittest.TestCase): ...@@ -384,42 +364,6 @@ class _test_MergeOptimizer(unittest.TestCase):
# class _test_ConstantFinder(unittest.TestCase):
# def test_straightforward(self):
# x, y, z = inputs()
# y.data = 2
# z.data = 2
# e = op1(x, y, z)
# g = env([x], [e])
# ConstantFinder().optimize(g)
# assert y.constant and z.constant
# MergeOptimizer().optimize(g)
# assert str(g) == "[Op1(x, y, y)]" \
# or str(g) == "[Op1(x, z, z)]"
# def test_deep(self):
# x, y, z = inputs()
# y.data = 2
# z.data = 2
# e = op1(op2(x, y), op2(x, y), op2(x, z))
# g = env([x], [e])
# ConstantFinder().optimize(g)
# assert y.constant and z.constant
# MergeOptimizer().optimize(g)
# assert str(g) == "[Op1(*1 -> Op2(x, y), *1, *1)]" \
# or str(g) == "[Op1(*1 -> Op2(x, z), *1, *1)]"
# def test_destroyed_orphan_not_constant(self):
# x, y, z = inputs()
# y.data = 2
# z.data = 2
# e = op_d(x, op2(y, z)) # here x is destroyed by op_d
# g = env([y], [e])
# ConstantFinder().optimize(g)
# assert not getattr(x, 'constant', False) and z.constant
# MergeOptimizer().optimize(g)
reenter = Exception("Re-Entered") reenter = Exception("Re-Entered")
class LoopyMacro(Macro): class LoopyMacro(Macro):
def __init__(self): def __init__(self):
...@@ -448,7 +392,7 @@ class _test_ExpandMacro(unittest.TestCase): ...@@ -448,7 +392,7 @@ class _test_ExpandMacro(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e = Macro1()(x, y) e = Macro1()(x, y)
g = Env([x, y], [e]) g = Env([x, y], [e])
expand_macros.optimize(g) ExpandMacros().optimize(g)
assert str(g) == "[Op1(y, x)]" assert str(g) == "[Op1(y, x)]"
def test_loopy_1(self): def test_loopy_1(self):
......
...@@ -100,14 +100,15 @@ class _test_NodeFinder(unittest.TestCase): ...@@ -100,14 +100,15 @@ class _test_NodeFinder(unittest.TestCase):
if not len([x for x in g.get_nodes(type)]) == num: if not len([x for x in g.get_nodes(type)]) == num:
self.fail((type, num)) self.fail((type, num))
def test_robustness(self): # def test_robustness(self):
x, y, z = inputs() # # this test used to make sense to have, but it doesn't work like that anymore
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), dot(y, z))) # x, y, z = inputs()
g = Env([x, y, z], [e]) # e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), dot(y, z)))
g.extend(NodeFinder()) # g = Env([x, y, z], [e])
gen = g.get_nodes(sigmoid) # I want to get Sigmoid instances # g.extend(NodeFinder())
g.replace(e, add(x, y)) # but here I prune them all # gen = g.get_nodes(sigmoid) # I want to get Sigmoid instances
assert len([x for x in gen]) == 0 # the generator should not yield them # g.replace(e, add(x, y)) # but here I prune them all
# assert len([x for x in gen]) == 0 # the generator should not yield them
......
...@@ -128,9 +128,7 @@ class MergeOptimizer(Optimizer): ...@@ -128,9 +128,7 @@ class MergeOptimizer(Optimizer):
""" """
def add_requirements(self, env): def add_requirements(self, env):
try:
env.extend(toolbox.ReplaceValidate()) env.extend(toolbox.ReplaceValidate())
except: pass
def apply(self, env): def apply(self, env):
cid = _metadict() #result -> result.desc() (for constants) cid = _metadict() #result -> result.desc() (for constants)
...@@ -139,7 +137,7 @@ class MergeOptimizer(Optimizer): ...@@ -139,7 +137,7 @@ class MergeOptimizer(Optimizer):
sig = r.signature() sig = r.signature()
other_r = inv_cid.get(sig, None) other_r = inv_cid.get(sig, None)
if other_r is not None: if other_r is not None:
env.replace(r, other_r) env.replace_validate(r, other_r)
else: else:
cid[r] = sig cid[r] = sig
inv_cid[sig] = r inv_cid[sig] = r
...@@ -559,6 +557,10 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -559,6 +557,10 @@ class OpKeyOptimizer(NavigatorOptimizer):
env.extend(toolbox.NodeFinder()) env.extend(toolbox.NodeFinder())
def keep_going(exc, nav, repl_pairs):
pass
############################## ##############################
### Pre-defined optimizers ### ### Pre-defined optimizers ###
############################## ##############################
...@@ -567,431 +569,3 @@ def ExpandMacros(filter = None): ...@@ -567,431 +569,3 @@ def ExpandMacros(filter = None):
return TopoOptimizer(ExpandMacro(filter = filter), return TopoOptimizer(ExpandMacro(filter = filter),
order = 'in_to_out', order = 'in_to_out',
ignore_newtrees = False) ignore_newtrees = False)
# class TopoOptimizer(Optimizer):
# def __init__(self, local_opt, order = 'out_to_in', ignore_newtrees = False, failure_callback = None):
# self.local_opt = local_opt
# if order not in ['out_to_in', 'in_to_out']:
# raise ValueError("order must be 'out_to_in' or 'in_to_out'")
# self.order = order
# self.ignore_newtrees = ignore_newtrees
# self.failure_callback = failure_callback
# def apply(self, env):
# ignore_newtrees = self.ignore_newtrees
# q = deque()
# class Updater:
# def on_attach(self, env):
# for node in graph.io_toposort(env.inputs, env.outputs):
# q.append(node)
# if not ignore_newtrees:
# def on_import(self, env, node):
# q.append(node)
# def on_prune(self, env, node):
# if node is not current_node:
# q.remove(node)
# u = Updater()
# env.extend(u)
# while q:
# if self.order == 'out_to_in':
# node = q.pop()
# else:
# node = q.popleft()
# current_node = node
# if not self.local_opt.applies(node):
# continue
# replacements = self.local_opt.transform(node)
# repl_pairs = zip(node.outputs, replacements)
# try:
# env.replace_all_validate(repl_pairs)
# except Exception, e:
# if self.failure_callback is not None:
# self.failure_callback(e, self, repl_pairs)
# else:
# raise
# env.remove_feature(u)
# def add_requirements(self, env):
# try:
# env.extend(toolbox.ReplaceValidate())
# except: pass
# class OpKeyOptimizer(Optimizer):
# def __init__(self, local_opt, ignore_newtrees = False, failure_callback = None):
# self.local_opt = local_opt
# if not hasattr(local_opt, 'op_key'):
# raise TypeError("LocalOptimizer for OpKeyOptimizer must have an 'op_key' method.")
# self.ignore_newtrees = ignore_newtrees
# self.failure_callback = failure_callback
# def apply(self, env):
# ignore_newtrees = self.ignore_newtrees
# op = self.local_opt.op_key()
# q = []
# class Updater:
# def on_attach(self, env):
# for node in graph.io_toposort(env.inputs, env.outputs):
# if node.op == op: q.append(node)
# if not ignore_newtrees:
# def on_import(self, env, node):
# if node.op == op: q.append(node)
# def on_prune(self, env, node):
# if node is not current_node:
# q.remove(node)
# u = Updater()
# env.extend(u)
# q = list(env.get_nodes(op))
# while q:
# node = q.pop()
# current_node = node
# if not self.local_opt.applies(node):
# continue
# replacements = self.local_opt.transform(node)
# repl_pairs = zip(node.outputs, replacements)
# try:
# env.replace_all_validate(repl_pairs)
# except Exception, e:
# if self.failure_callback is not None:
# self.failure_callback(e, self, repl_pairs)
# else:
# raise
# env.remove_feature(u)
# def add_requirements(self):
# """
# Requires the following features:
# - NodeFinder
# - ReplaceValidate
# """
# try:
# env.extend(toolbox.NodeFinder())
# env.extend(toolbox.ReplaceValidate())
# except: pass
# class OpSpecificOptimizer(LocalOptimizer):
# """
# Generic L{Optimizer} that applies only to ops of a certain
# type. The type in question is accessed through L{self.op}.
# op can also be a class variable of the subclass.
# """
# def add_requirements(self, env):
# try:
# env.extend(toolbox.NodeFinder())
# env.extend(toolbox.ReplaceValidate())
# except: pass
# def candidates(self, env):
# """
# Returns all nodes that have L{self.op} in their op field.
# """
# return env.get_nodes(self.op)
# class OpSubOptimizer(Optimizer):
# """
# Replaces all applications of a certain op by the application of
# another op that take the same inputs as what they are replacing.
# e.g. OpSubOptimizer(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
# OpSubOptimizer requires the following features:
# - NodeFinder
# - ReplaceValidate
# """
# def add_requirements(self, env):
# """
# Requires the following features:
# - NodeFinder
# - ReplaceValidate
# """
# try:
# env.extend(toolbox.NodeFinder())
# env.extend(toolbox.ReplaceValidate())
# except: pass
# def __init__(self, op1, op2, failure_callback = None):
# """
# op1.make_node and op2.make_node must take the same number of
# inputs and have the same number of outputs.
# If failure_callback is not None, it will be called whenever
# the Optimizer fails to do a replacement in the graph. The
# arguments to the callback are: (node, replacement, exception)
# """
# self.op1 = op1
# self.op2 = op2
# self.failure_callback = failure_callback
# def apply(self, env):
# """
# Replaces all applications of self.op1 by applications of self.op2
# with the same inputs.
# """
# candidates = env.get_nodes(self.op1)
# for node in candidates:
# try:
# repl = self.op2.make_node(*node.inputs)
# assert len(node.outputs) == len(repl.outputs)
# for old, new in zip(node.outputs, repl.outputs):
# env.replace_validate(old, new)
# except Exception, e:
# if self.failure_callback is not None:
# self.failure_callback(node, repl, e)
# def str(self):
# return "%s -> %s" % (self.op1, self.op2)
# class OpRemover(Optimizer):
# """
# @todo untested
# Removes all applications of an op by transferring each of its
# outputs to the corresponding input.
# """
# def add_requirements(self, env):
# try:
# env.extend(toolbox.NodeFinder())
# env.extend(toolbox.ReplaceValidate())
# except: pass
# def __init__(self, op, failure_callback = None):
# """
# Applications of the op must have as many inputs as outputs.
# If failure_callback is not None, it will be called whenever
# the Optimizer fails to remove an operation in the graph. The
# arguments to the callback are: (node, exception)
# """
# self.op = op
# self.failure_callback = failure_callback
# def apply(self, env):
# """
# Removes all applications of self.op.
# """
# candidates = env.get_nodes(self.op)
# for node in candidates:
# try:
# assert len(node.inputs) == len(node.outputs)
# for input, output in zip(node.inputs, node.outputs):
# env.replace(output, input)
# except Exception, e:
# if self.failure_callback is not None:
# self.failure_callback(node, e)
# pass
# def str(self):
# return "f(%s(x)) -> f(x)" % self.op
# class PatternOptimizer(OpSpecificOptimizer):
# """
# @todo update
# Replaces all occurrences of the input pattern by the output pattern:
# input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
# input_pattern ::= dict(pattern = <input_pattern>,
# constraint = <constraint>)
# sub_pattern ::= input_pattern
# sub_pattern ::= string
# sub_pattern ::= a Constant instance
# constraint ::= lambda env, expr: additional matching condition
# output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...)
# output_pattern ::= string
# Each string in the input pattern is a variable that will be set to
# whatever expression is found in its place. If the same string is
# used more than once, the same expression must be found in those
# places. If a string used in the input pattern is used in the
# output pattern, the matching expression will be inserted in its
# place. The input pattern cannot just be a string but the output
# pattern can.
# If you put a constant result in the input pattern, there will be a
# match iff a constant result with the same value and the same type
# is found in its place.
# You can add a constraint to the match by using the dict(...) form
# described above with a 'constraint' key. The constraint must be a
# function that takes the env and the current Result that we are
# trying to match and returns True or False according to an
# arbitrary criterion.
# Examples:
# PatternOptimizer((add, 'x', 'y'), (add, 'y', 'x'))
# PatternOptimizer((multiply, 'x', 'x'), (square, 'x'))
# PatternOptimizer((subtract, (add, 'x', 'y'), 'y'), 'x')
# PatternOptimizer((power, 'x', Constant(double, 2.0)), (square, 'x'))
# PatternOptimizer((boggle, {'pattern': 'x',
# 'constraint': lambda env, expr: expr.type == scrabble}),
# (scrabble, 'x'))
# """
# def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False, failure_callback = None):
# """
# Creates a PatternOptimizer that replaces occurrences of
# in_pattern by occurrences of out_pattern.
# If failure_callback is not None, if there is a match but a
# replacement fails to occur, the callback will be called with
# arguments (result_to_replace, replacement, exception).
# If allow_multiple_clients is False, he pattern matching will
# fail if one of the subpatterns has more than one client.
# """
# self.in_pattern = in_pattern
# self.out_pattern = out_pattern
# if isinstance(in_pattern, (list, tuple)):
# self.op = self.in_pattern[0]
# elif isinstance(in_pattern, dict):
# self.op = self.in_pattern['pattern'][0]
# else:
# raise TypeError("The pattern to search for must start with a specific Op instance.")
# self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n"
# self.failure_callback = failure_callback
# self.allow_multiple_clients = allow_multiple_clients
# def apply_on_node(self, env, node):
# """
# Checks if the graph from node corresponds to in_pattern. If it does,
# constructs out_pattern and performs the replacement.
# """
# def match(pattern, expr, u, first = False):
# if isinstance(pattern, (list, tuple)):
# if expr.owner is None:
# return False
# if not (expr.owner.op == pattern[0]) or (not self.allow_multiple_clients and not first and env.nclients(expr) > 1):
# return False
# if len(pattern) - 1 != len(expr.owner.inputs):
# return False
# for p, v in zip(pattern[1:], expr.owner.inputs):
# u = match(p, v, u)
# if not u:
# return False
# elif isinstance(pattern, dict):
# try:
# real_pattern = pattern['pattern']
# constraint = pattern['constraint']
# except KeyError:
# raise KeyError("Malformed pattern: %s (expected keys pattern and constraint)" % pattern)
# if constraint(env, expr):
# return match(real_pattern, expr, u, False)
# elif isinstance(pattern, str):
# v = unify.Var(pattern)
# if u[v] is not v and u[v] is not expr:
# return False
# else:
# u = u.merge(expr, v)
# elif isinstance(pattern, graph.Constant) and isinstance(expr, graph.Constant) and pattern.equals(expr):
# return u
# else:
# return False
# return u
# def build(pattern, u):
# if isinstance(pattern, (list, tuple)):
# args = [build(p, u) for p in pattern[1:]]
# return pattern[0](*args)
# elif isinstance(pattern, str):
# return u[unify.Var(pattern)]
# else:
# return pattern
# u = match(self.in_pattern, node.out, unify.Unification(), True)
# if u:
# try:
# # note: only replaces the default 'out' port if it exists
# p = self.out_pattern
# new = 'unassigned' # this is for the callback if build fails
# new = build(p, u)
# env.replace(node.out, new)
# except Exception, e:
# if self.failure_callback is not None:
# self.failure_callback(node.out, new, e)
# pass
# def __str__(self):
# def pattern_to_str(pattern):
# if isinstance(pattern, (list, tuple)):
# return "%s(%s)" % (str(pattern[0]), ", ".join([pattern_to_str(p) for p in pattern[1:]]))
# elif isinstance(pattern, dict):
# return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern['constraint']))
# else:
# return str(pattern)
# return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern))
# class LocalOptimizer(Optimizer):
# """
# Generic L{Optimizer} class that considers local parts of
# the L{Env}. It must be subclassed and should override the
# following two methods:
# - candidates(env) -> returns a set of ops that can be
# optimized
# - apply_on_node(env, node) -> for each node in candidates,
# this function will be called to perform the actual
# optimization.
# """
# def candidates(self, env):
# """
# Must return a set of nodes that can be optimized.
# """
# raise utils.AbstractFunctionError()
# def apply_on_node(self, env, node):
# """
# For each node in candidates, this function will be called to
# perform the actual optimization.
# """
# raise utils.AbstractFunctionError()
# def apply(self, env):
# """
# Calls self.apply_on_op(env, op) for each op in self.candidates(env).
# """
# for node in self.candidates(env):
# if node in env.nodes:
# self.apply_on_node(env, node)
...@@ -34,8 +34,9 @@ class scratchpad: ...@@ -34,8 +34,9 @@ class scratchpad:
self.__dict__.clear() self.__dict__.clear()
def __update__(self, other): def __update__(self, other):
self.__dict__.update(other.__dict__) self.__dict__.update(other.__dict__)
return self
def __str__(self): def __str__(self):
print "scratch" + str(self.__dict__) return "scratch" + str(self.__dict__)
def memoize(f): def memoize(f):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论