提交 14090d83 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

documentation and cleanup

上级 abd860d7
......@@ -4,11 +4,15 @@ import unittest
from link import PerformLinker, Profiler
from cc import *
from type import Type
from graph import Result, as_result, Apply, Constant
from graph import Result, Apply, Constant
from op import Op
import env
import toolbox
def as_result(x):
assert isinstance(x, Result)
return x
class TDouble(Type):
def filter(self, data):
return float(data)
......
......@@ -3,18 +3,20 @@ import unittest
from type import Type
import graph
from graph import Result, as_result, Apply
from graph import Result, Apply
from op import Op
from opt import PatternOptimizer, OpSubOptimizer
from ext import *
from env import Env, InconsistencyError
#from toolbox import EquivTool
from toolbox import ReplaceValidate
from copy import copy
#from _test_result import MyResult
def as_result(x):
assert isinstance(x, Result)
return x
class MyType(Type):
......
......@@ -15,6 +15,10 @@ else:
realtestcase = unittest.TestCase
def as_result(x):
assert isinstance(x, Result)
return x
class MyType(Type):
......
......@@ -3,7 +3,7 @@
import unittest
import graph
from graph import Result, as_result, Apply, Constant
from graph import Result, Apply, Constant
from type import Type
from op import Op
import env
......@@ -14,6 +14,10 @@ from link import *
#from _test_result import Double
def as_result(x):
assert isinstance(x, Result)
return x
class TDouble(Type):
def filter(self, data):
return float(data)
......
......@@ -3,9 +3,11 @@ import unittest
from copy import copy
from op import *
from type import Type, Generic
from graph import Apply, as_result
from graph import Apply, Result
#from result import Result
def as_result(x):
assert isinstance(x, Result)
return x
class MyType(Type):
......
import unittest
from graph import Result, as_result, Apply, Constant
from type import Type
from graph import Result, Apply, Constant
from op import Op
from opt import *
from env import Env
from toolbox import *
def as_result(x):
assert isinstance(x, Result)
return x
class MyType(Type):
......
import unittest
from graph import Result, as_result, Apply
from graph import Result, Apply
from type import Type
from op import Op
#from opt import PatternOptimizer, OpSubOptimizer
from env import Env, InconsistencyError
from toolbox import *
def as_result(x):
assert isinstance(x, Result)
return x
class MyType(Type):
def __init__(self, name):
......
"""
Defines Linkers that deal with C implementations.
"""
import graph
from graph import Constant, Value
from link import Linker, LocalLinker, raise_with_op, Filter, map_storage, PerformLinker
# Python imports
from copy import copy
from utils import AbstractFunctionError
from env import Env
import md5
import sys
import os
import platform
import re
import os, sys, platform
# weave import
from scipy import weave
# gof imports
import cutils
from env import Env
import graph
import link
import utils
import traceback
import re
def compile_dir():
......@@ -192,7 +196,7 @@ def struct_gen(args, struct_builders, blocks, sub):
return %(failure_var)s;
""" % sub
sub = copy(sub)
sub = dict(sub)
sub.update(locals())
# TODO: add some error checking to make sure storage_<x> are 1-element lists
......@@ -309,7 +313,7 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub):
name = "V%i" % id
symbol_table[result] = name
sub = copy(sub)
sub = dict(sub)
# sub['name'] = name
sub['id'] = id
sub['fail'] = failure_code(sub)
......@@ -323,13 +327,16 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub):
return struct_builder, block
class CLinker(Linker):
class CLinker(link.Linker):
"""
Creates C code for an env or an Op instance, compiles it and returns
callables through make_thunk and make_function that make use of the
compiled code.
Creates C code for an env, compiles it and returns callables
through make_thunk and make_function that make use of the compiled
code.
It can take an env or an Op as input.
no_recycling can contain a list of Results that belong to the env.
If a Result is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it).
"""
def __init__(self, env, no_recycling = []):
......@@ -346,7 +353,7 @@ class CLinker(Linker):
self.outputs = env.outputs
self.results = graph.results(self.inputs, self.outputs) # list(env.results)
# The orphans field is listified to ensure a consistent order.
self.orphans = list(r for r in self.results if isinstance(r, Value) and r not in self.inputs) #list(env.orphans.difference(self.outputs))
self.orphans = list(r for r in self.results if isinstance(r, graph.Value) and r not in self.inputs) #list(env.orphans.difference(self.outputs))
self.temps = list(set(self.results).difference(self.inputs).difference(self.outputs).difference(self.orphans))
self.node_order = env.toposort()
......@@ -404,15 +411,15 @@ class CLinker(Linker):
policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_extract, get_c_cleanup]]
elif result in self.orphans:
if not isinstance(result, Value):
if not isinstance(result, graph.Value):
raise TypeError("All orphans to CLinker must be Value instances.", result)
if isinstance(result, Constant):
if isinstance(result, graph.Constant):
try:
symbol[result] = "(" + result.type.c_literal(result.data) + ")"
consts.append(result)
self.orphans.remove(result)
continue
except (AbstractFunctionError, NotImplementedError):
except (utils.AbstractFunctionError, NotImplementedError):
pass
# orphans are not inputs so we'll just get fetch them when we initialize the struct and assume they stay the same
policy = [[get_c_declare, get_c_extract, get_c_cleanup],
......@@ -475,11 +482,11 @@ class CLinker(Linker):
op = node.op
try: behavior = op.c_code(node, name, isyms, osyms, sub)
except AbstractFunctionError:
except utils.AbstractFunctionError:
raise NotImplementedError("%s cannot produce C code" % op)
try: cleanup = op.c_code_cleanup(node, name, isyms, osyms, sub)
except AbstractFunctionError:
except utils.AbstractFunctionError:
cleanup = ""
blocks.append(CodeBlock("", behavior, cleanup, sub))
......@@ -539,7 +546,7 @@ class CLinker(Linker):
ret = []
for x in [y.type for y in self.results] + [y.op for y in self.node_order]:
try: ret.append(x.c_support_code())
except AbstractFunctionError: pass
except utils.AbstractFunctionError: pass
return ret
def compile_args(self):
......@@ -552,7 +559,7 @@ class CLinker(Linker):
ret = []
for x in [y.type for y in self.results] + [y.op for y in self.node_order]:
try: ret += x.c_compile_args()
except AbstractFunctionError: pass
except utils.AbstractFunctionError: pass
return ret
def headers(self):
......@@ -565,7 +572,7 @@ class CLinker(Linker):
ret = []
for x in [y.type for y in self.results] + [y.op for y in self.node_order]:
try: ret += x.c_headers()
except AbstractFunctionError: pass
except utils.AbstractFunctionError: pass
return ret
def libraries(self):
......@@ -578,25 +585,23 @@ class CLinker(Linker):
ret = []
for x in [y.type for y in self.results] + [y.op for y in self.node_order]:
try: ret += x.c_libraries()
except AbstractFunctionError: pass
except utils.AbstractFunctionError: pass
return ret
def __compile__(self, input_storage = None, output_storage = None):
"""
@todo update
Compiles this linker's env. If inplace is True, it will use the
Results contained in the env, if it is False it will copy the
input and output Results.
Returns: thunk, in_results, out_results, error_storage
Compiles this linker's env.
@type input_storage: list or None
@param input_storage: list of lists of length 1. In order to use
the thunk returned by __compile__, the inputs must be put in
that storage. If None, storage will be allocated.
@param output_storage: list of lists of length 1. The thunk returned
by __compile__ will put the results of the computation in these
lists. If None, storage will be allocated.
Returns: thunk, input_storage, output_storage, error_storage
"""
# if inplace:
# in_results = self.inputs
# out_results = self.outputs
# else:
# in_results = [copy(input) for input in self.inputs]
# out_results = [copy(output) for output in self.outputs]
error_storage = [None, None, None]
if input_storage is None:
input_storage = tuple([None] for result in self.inputs)
......@@ -612,13 +617,34 @@ class CLinker(Linker):
thunk = self.cthunk_factory(error_storage,
input_storage,
output_storage)
return thunk, [Filter(input.type, storage) for input, storage in zip(self.env.inputs, input_storage)], \
[Filter(output.type, storage, True) for output, storage in zip(self.env.outputs, output_storage)], \
return thunk, \
[link.Filter(input.type, storage) for input, storage in zip(self.env.inputs, input_storage)], \
[link.Filter(output.type, storage, True) for output, storage in zip(self.env.outputs, output_storage)], \
error_storage
# return thunk, [Filter(x) for x in input_storage], [Filter(x) for x in output_storage], error_storage
def make_thunk(self, input_storage = None, output_storage = None):
"""
Compiles this linker's env and returns a function to perform the
computations, as well as lists of storage cells for both the
inputs and outputs.
@type input_storage: list or None
@param input_storage: list of lists of length 1. In order to use
the thunk returned by __compile__, the inputs must be put in
that storage. If None, storage will be allocated.
@param output_storage: list of lists of length 1. The thunk returned
by __compile__ will put the results of the computation in these
lists. If None, storage will be allocated.
Returns: thunk, input_storage, output_storage
The return values can be used as follows:
f, istor, ostor = clinker.make_thunk()
istor[0].data = first_input
istor[1].data = second_input
f()
first_output = ostor[0].data
"""
cthunk, in_storage, out_storage, error_storage = self.__compile__(input_storage, output_storage)
def execute():
failure = cutils.run_cthunk(cthunk)
......@@ -729,7 +755,7 @@ class CLinker(Linker):
class OpWiseCLinker(LocalLinker):
class OpWiseCLinker(link.LocalLinker):
"""
Uses CLinker on the individual Ops that comprise an env and loops
over them in Python. The result is slower than a compiled version of
......@@ -739,6 +765,10 @@ class OpWiseCLinker(LocalLinker):
If fallback_on_perform is True, OpWiseCLinker will use an op's
perform method if no C version can be generated.
no_recycling can contain a list of Results that belong to the env.
If a Result is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it).
"""
def __init__(self, env, fallback_on_perform = True, no_recycling = []):
......@@ -756,7 +786,7 @@ class OpWiseCLinker(LocalLinker):
order = env.toposort()
no_recycling = self.no_recycling
input_storage, output_storage, storage_map = map_storage(env, order, input_storage, output_storage)
input_storage, output_storage, storage_map = link.map_storage(env, order, input_storage, output_storage)
thunks = []
for node in order:
......@@ -772,7 +802,7 @@ class OpWiseCLinker(LocalLinker):
thunk.inputs = node_input_storage
thunk.outputs = node_output_storage
thunks.append(thunk)
except (NotImplementedError, AbstractFunctionError):
except (NotImplementedError, utils.AbstractFunctionError):
if self.fallback_on_perform:
p = node.op.perform
thunk = lambda p = p, i = node_input_storage, o = node_output_storage, n = node: p(n, [x[0] for x in i], o)
......@@ -791,27 +821,8 @@ class OpWiseCLinker(LocalLinker):
f = self.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler)
# if profiler is None:
# def f():
# for x in no_recycling:
# x[0] = None
# try:
# for thunk, node in zip(thunks, order):
# thunk()
# except:
# raise_with_op(node)
# else:
# def f():
# for x in no_recycling:
# x[0] = None
# def g():
# for thunk, node in zip(thunks, order):
# profiler.profile_op(thunk, node)
# profiler.profile_env(g, env)
# f.profiler = profiler
return f, [Filter(input.type, storage) for input, storage in zip(env.inputs, input_storage)], \
[Filter(output.type, storage, True) for output, storage in zip(env.outputs, output_storage)], \
return f, [link.Filter(input.type, storage) for input, storage in zip(env.inputs, input_storage)], \
[link.Filter(output.type, storage, True) for output, storage in zip(env.outputs, output_storage)], \
thunks, order
......@@ -825,7 +836,7 @@ def _default_checker(x, y):
if x[0] != y[0]:
raise Exception("Output mismatch.", {'performlinker': x[0], 'clinker': y[0]})
class DualLinker(Linker):
class DualLinker(link.Linker):
"""
Runs the env in parallel using PerformLinker and CLinker.
......@@ -841,13 +852,13 @@ class DualLinker(Linker):
"""
Initialize a DualLinker.
The checker argument must be a function that takes two Result
instances. The first one passed will be the output computed by
PerformLinker and the second one the output computed by
OpWiseCLinker. The checker should compare the data fields of
the two results to see if they match. By default, DualLinker
uses ==. A custom checker can be provided to compare up to a
certain error tolerance.
The checker argument must be a function that takes two lists
of length 1. The first one passed will contain the output
computed by PerformLinker and the second one the output
computed by OpWiseCLinker. The checker should compare the data
fields of the two results to see if they match. By default,
DualLinker uses ==. A custom checker can be provided to
compare up to a certain error tolerance.
If a mismatch occurs, the checker should raise an exception to
halt the computation. If it does not, the computation will
......@@ -855,35 +866,22 @@ class DualLinker(Linker):
the problem by fiddling with the data, but it should be
careful not to share data between the two outputs (or inplace
operations that use them will interfere).
no_recycling can contain a list of Results that belong to the env.
If a Result is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it).
"""
self.env = env
self.checker = checker
self.no_recycling = no_recycling
def make_thunk(self, **kwargs):
# if inplace:
# env1 = self.env
# else:
# env1 = self.env.clone(True)
# env2, equiv = env1.clone_get_equiv(True)
# op_order_1 = env1.toposort()
# op_order_2 = [equiv[op.outputs[0]].owner for op in op_order_1] # we need to have the exact same order so we can compare each step
# def c_make_thunk(op):
# try:
# return CLinker(op).make_thunk(True)[0]
# except AbstractFunctionError:
# return op.perform
# thunks1 = [op.perform for op in op_order_1]
# thunks2 = [c_make_thunk(op) for op in op_order_2]
env = self.env
no_recycling = self.no_recycling
_f, i1, o1, thunks1, order1 = PerformLinker(env, no_recycling = no_recycling).make_all(**kwargs)
_f, i2, o2, thunks2, order2 = OpWiseCLinker(env, no_recycling = no_recycling).make_all(**kwargs)
_f, i1, o1, thunks1, order1 = link.PerformLinker(env, no_recycling = no_recycling).make_all(**kwargs)
_f, i2, o2, thunks2, order2 = OpWiseCLinker(env, no_recycling = no_recycling).make_all(**kwargs)
def f():
for input1, input2 in zip(i1, i2):
......@@ -903,15 +901,7 @@ class DualLinker(Linker):
for output1, output2 in zip(thunk1.outputs, thunk2.outputs):
self.checker(output1, output2)
except:
raise_with_op(node1)
# exc_type, exc_value, exc_trace = sys.exc_info()
# try:
# trace = op1.trace
# except AttributeError:
# trace = ()
# exc_value.__thunk_trace__ = trace
# exc_value.args = exc_value.args + (op1, )
# raise exc_type, exc_value, exc_trace
link.raise_with_op(node1)
return f, i1, o1
......
from copy import copy
import graph
##from features import Listener, Orderings, Constraint, Tool, uniq_features
import utils
from utils import AbstractFunctionError
class InconsistencyError(Exception):
"""
This exception is raised by Env whenever one of the listeners marks
the graph as inconsistent.
This exception should be thrown by listeners to Env when the
graph's state is invalid.
"""
pass
class Env(object): #(graph.Graph):
class Env(utils.object2):
"""
An Env represents a subgraph bound by a set of input results and a
set of output results. An op is in the subgraph iff it depends on
the value of some of the Env's inputs _and_ some of the Env's
outputs depend on it. A result is in the subgraph iff it is an
input or an output of an op that is in the subgraph.
set of output results. The inputs list should contain all the inputs
on which the outputs depend. Results of type Value or Constant are
not counted as inputs.
The Env supports the replace operation which allows to replace a
result in the subgraph by another, e.g. replace (x + x).out by (2
* x).out. This is the basis for optimization in theano.
It can also be "extended" using env.extend(some_object). See the
toolbox and ext modules for common extensions.
"""
### Special ###
......@@ -65,12 +64,14 @@ class Env(object): #(graph.Graph):
### Setup a Result ###
def __setup_r__(self, r):
# sets up r so it belongs to this env
if hasattr(r, 'env') and r.env is not None and r.env is not self:
raise Exception("%s is already owned by another env" % r)
r.env = self
r.clients = []
def __setup_node__(self, node):
# sets up node so it belongs to this env
if hasattr(node, 'env') and node.env is not self:
raise Exception("%s is already owned by another env" % node)
node.env = self
......@@ -80,28 +81,27 @@ class Env(object): #(graph.Graph):
### clients ###
def clients(self, r):
"Set of all the (op, i) pairs such that op.inputs[i] is r."
"Set of all the (node, i) pairs such that node.inputs[i] is r."
return r.clients
def __add_clients__(self, r, all):
def __add_clients__(self, r, new_clients):
"""
r -> result
all -> list of (op, i) pairs representing who r is an input of.
new_clients -> list of (node, i) pairs such that node.inputs[i] is r.
Updates the list of clients of r with all.
Updates the list of clients of r with new_clients.
"""
r.clients += all
r.clients += new_clients
def __remove_clients__(self, r, all, prune = True):
def __remove_clients__(self, r, clients_to_remove, prune = True):
"""
r -> result
all -> list of (op, i) pairs representing who r is an input of.
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
Removes all from the clients list of r.
"""
for entry in all:
for entry in clients_to_remove:
r.clients.remove(entry)
# remove from orphans?
if not r.clients:
if prune:
self.__prune_r__([r])
......@@ -188,6 +188,15 @@ class Env(object): #(graph.Graph):
### change input ###
def change_input(self, node, i, new_r):
"""
Changes node.inputs[i] to new_r.
new_r.type == old_r.type must be True, where old_r is the
current value of node.inputs[i] which we want to replace.
For each feature that has a 'on_change_input' method, calls:
feature.on_change_input(env, node, i, old_r, new_r)
"""
if node == 'output':
r = self.outputs[i]
if not r.type == new_r.type:
......@@ -214,10 +223,7 @@ class Env(object): #(graph.Graph):
def replace(self, r, new_r):
"""
This is the main interface to manipulate the subgraph in Env.
For every op that uses r as input, makes it use new_r instead.
This may raise an error if the new result violates type
constraints for one of the target nodes. In that case, no
changes are made.
For every node that uses r as input, makes it use new_r instead.
"""
if r.env is not self:
raise Exception("Cannot replace %s because it does not belong to this Env" % r)
......@@ -238,11 +244,32 @@ class Env(object): #(graph.Graph):
def extend(self, feature):
"""
@todo out of date
Adds an instance of the feature_class to this env's supported
features. If do_import is True and feature_class is a subclass
of Listener, its on_import method will be called on all the Nodes
already in the env.
Adds a feature to this env. The feature may define one
or more of the following methods:
- feature.on_attach(env)
Called by extend. The feature has great freedom in what
it can do with the env: it may, for example, add methods
to it dynicamically.
- feature.on_detach(env)
Called by remove_feature(feature).
- feature.on_import(env, node)*
Called whenever a node is imported into env, which is
just before the node is actually connected to the graph.
- feature.on_prune(env, node)*
Called whenever a node is pruned (removed) from the env,
after it is disconnected from the graph.
- feature.on_change_input(env, node, i, r, new_r)*
Called whenever node.inputs[i] is changed from r to new_r.
At the moment the callback is done, the change has already
taken place.
- feature.orderings(env)
Called by toposort. It should return a dictionary of
{node: predecessors} where predecessors is a list of
nodes that should be computed before the key node.
* If you raise an exception in the functions marked with an
asterisk, the state of the graph might be inconsistent.
"""
if feature in self._features:
return # the feature is already present
......@@ -256,6 +283,11 @@ class Env(object): #(graph.Graph):
raise
def remove_feature(self, feature):
"""
Removes the feature from the graph.
Calls feature.on_detach(env) if an on_detach method is defined.
"""
try:
self._features.remove(feature)
except:
......@@ -268,6 +300,11 @@ class Env(object): #(graph.Graph):
### callback utils ###
def execute_callbacks(self, name, *args):
"""
Calls
getattr(feature, name)(*args)
for each feature which has a method called after name.
"""
for feature in self._features:
try:
fn = getattr(feature, name)
......@@ -276,6 +313,11 @@ class Env(object): #(graph.Graph):
fn(self, *args)
def collect_callbacks(self, name, *args):
"""
Returns a dictionary d such that:
d[feature] == getattr(feature, name)(*args)
For each feature which has a method called after name.
"""
d = {}
for feature in self._features:
try:
......@@ -289,6 +331,17 @@ class Env(object): #(graph.Graph):
### misc ###
def toposort(self):
"""
Returns an ordering of the graph's Apply nodes such that:
- All the nodes of the inputs of a node are before that node.
- Satisfies the orderings provided by each feature that has
an 'orderings' method.
If a feature has an 'orderings' method, it will be called with
this env as sole argument. It should return a dictionary of
{node: predecessors} where predecessors is a list of nodes
that should be computed before the key node.
"""
env = self
ords = {}
for feature in env._features:
......@@ -314,10 +367,10 @@ class Env(object): #(graph.Graph):
raise Exception("what the fuck")
return node.inputs
def has_node(self, node):
return node in self.nodes
def check_integrity(self):
"""
Call this for a diagnosis if things go awry.
"""
nodes = graph.ops(self.inputs, self.outputs)
if self.nodes != nodes:
missing = nodes.difference(self.nodes)
......
#from features import Listener, Constraint, Orderings, Tool
from collections import defaultdict
import graph
import utils
from utils import AbstractFunctionError
import toolbox
from copy import copy
from utils import AbstractFunctionError
from env import InconsistencyError
from toolbox import Bookkeeper
from collections import defaultdict
class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
class DestroyHandler(toolbox.Bookkeeper):
"""
This feature ensures that an env represents a consistent data flow
when some Ops overwrite their inputs and/or provide "views" over
......@@ -29,14 +21,16 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
Examples:
- (x += 1) + (x += 1) -> fails because the first += makes the second
invalid
- x += transpose_view(x) -> fails because the input that is destroyed
depends on an input that shares the same data
- (a += b) + (c += a) -> succeeds but we have to do c += a first
- (a += b) + (b += c) + (c += a) -> fails because there's a cyclical
dependency (no possible ordering)
This feature allows some optimizations (eg sub += for +) to be applied
safely.
@todo
- x += transpose_view(x) -> fails because the input that is destroyed
depends on an input that shares the same data
"""
def __init__(self):
......@@ -88,11 +82,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
self.seen = set()
Bookkeeper.on_attach(self, env)
# # Initialize the children if the inputs and orphans.
# for input in env.inputs: # env.orphans.union(env.inputs):
# self.children[input] = set()
toolbox.Bookkeeper.on_attach(self, env)
def on_detach(self, env):
del self.parent
......@@ -105,19 +95,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
del self.seen
self.env = None
# def publish(self):
# """
# Publishes the following on the env:
# - destroyers(r) -> returns all L{Op}s that destroy the result r
# - destroy_handler -> self
# """
# def __destroyers(r):
# ret = self.destroyers.get(r, {})
# ret = ret.keys()
# return ret
# self.env.destroyers = __destroyers
# self.env.destroy_handler = self
def __path__(self, r):
"""
Returns a path from r to the result that it is ultimately
......@@ -171,7 +148,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
def __pre__(self, op):
"""
Returns all results that must be computed prior to computing
this op.
this node.
"""
rval = set()
if op is None:
......@@ -222,7 +199,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
users = set(self.__users__(start))
users.add(start)
for user in users:
for cycle in copy(self.cycles):
for cycle in set(self.cycles):
if user in cycle:
self.cycles.remove(cycle)
if just_remove:
......@@ -234,7 +211,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
"""
@return: (vmap, dmap) where:
- vmap -> {output : [inputs output is a view of]}
- dmap -> {output : [inputs that are destroyed by the Op
- dmap -> {output : [inputs that are destroyed by the node
(and presumably returned as that output)]}
"""
try: _vmap = node.op.view_map
......@@ -260,7 +237,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
def on_import(self, env, op):
"""
Recomputes the dependencies and search for inconsistencies given
that we just added an op to the env.
that we just added an node to the env.
"""
self.seen.add(op)
......@@ -303,7 +280,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
def on_prune(self, env, op):
"""
Recomputes the dependencies and searches for inconsistencies to remove
given that we just removed an op to the env.
given that we just removed a node to the env.
"""
view_map, destroy_map = self.get_maps(op)
......@@ -400,7 +377,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
"""
Recomputes the dependencies and searches for inconsistencies to remove
given that all the clients are moved from r_1 to r_2, clients being
a list of (op, i) pairs such that op.inputs[i] used to be r_1 and is
a list of (node, i) pairs such that node.inputs[i] used to be r_1 and is
now r_2.
"""
path_1 = self.__path__(r_1)
......@@ -485,53 +462,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
# class Destroyer:
# """
# Base class for Ops that destroy one or more of their inputs in an
# inplace operation, use them as temporary storage, puts garbage in
# them or anything else that invalidates the contents for use by other
# Ops.
# Usage of this class in an env requires DestroyHandler.
# """
# def destroyed_inputs(self):
# raise AbstractFunctionError()
# def destroy_map(self):
# """
# Returns the map {output: [list of destroyed inputs]}
# While it typically means that the storage of the output is
# shared with each of the destroyed inputs, it does necessarily
# have to be the case.
# """
# # compatibility
# return {self.out: self.destroyed_inputs()}
# __env_require__ = DestroyHandler
# class Viewer:
# """
# Base class for Ops that return one or more views over one or more inputs,
# which means that the inputs and outputs share their storage. Unless it also
# extends Destroyer, this Op does not modify the storage in any way and thus
# the input is safe for use by other Ops even after executing this one.
# """
# def view_map(self):
# """
# Returns the map {output: [list of viewed inputs]}
# It means that the output shares storage with each of the inputs
# in the list.
# Note: support for more than one viewed input is minimal, but
# this might improve in the future.
# """
# raise AbstractFunctionError()
def view_roots(r):
"""
Utility function that returns the leaves of a search through
......
......@@ -3,20 +3,9 @@ from copy import copy
from collections import deque
import utils
from utils import object2
def deprecated(f):
printme = [True]
def g(*args, **kwargs):
if printme[0]:
print 'gof.graph.%s deprecated: April 29' % f.__name__
printme[0] = False
return f(*args, **kwargs)
return g
class Apply(object2):
class Apply(utils.object2):
"""
Note: it is illegal for an output element to have an owner != self
"""
......@@ -74,18 +63,13 @@ class Apply(object2):
raise TypeError("Cannot change the type of this input.", curr, new)
new_node = self.clone()
new_node.inputs = inputs
# new_node.outputs = []
# for output in self.outputs:
# new_output = copy(output)
# new_output.owner = new_node
# new_node.outputs.append(new_output)
return new_node
nin = property(lambda self: len(self.inputs))
nout = property(lambda self: len(self.outputs))
class Result(object2):
class Result(utils.object2):
#__slots__ = ['type', 'owner', 'index', 'name']
def __init__(self, type, owner = None, index = None, name = None):
self.type = type
......@@ -111,9 +95,6 @@ class Result(object2):
return "<?>::" + str(self.type)
def __repr__(self):
return str(self)
@deprecated
def __asresult__(self):
return self
def clone(self):
return self.__class__(self.type, None, None, self.name)
......@@ -137,7 +118,6 @@ class Constant(Value):
#__slots__ = ['data']
def __init__(self, type, data, name = None):
Value.__init__(self, type, data, name)
### self.indestructible = True
def equals(self, other):
# this does what __eq__ should do, but Result and Apply should always be hashable by id
return type(other) == type(self) and self.signature() == other.signature()
......@@ -148,32 +128,6 @@ class Constant(Value):
return self.name
return str(self.data) #+ "::" + str(self.type)
@deprecated
def as_result(x):
if isinstance(x, Result):
return x
# elif isinstance(x, Type):
# return Result(x, None, None)
elif hasattr(x, '__asresult__'):
r = x.__asresult__()
if not isinstance(r, Result):
raise TypeError("%s.__asresult__ must return a Result instance" % x, (x, r))
return r
else:
raise TypeError("Cannot wrap %s in a Result" % x)
@deprecated
def as_apply(x):
if isinstance(x, Apply):
return x
elif hasattr(x, '__asapply__'):
node = x.__asapply__()
if not isinstance(node, Apply):
raise TypeError("%s.__asapply__ must return an Apply instance" % x, (x, node))
return node
else:
raise TypeError("Cannot map %s to Apply" % x)
def stack_search(start, expand, mode='bfs', build_inv = False):
"""Search through L{Result}s, either breadth- or depth-first
@type start: deque
......@@ -234,45 +188,6 @@ def inputs(result_list):
return rval
# def results_and_orphans(r_in, r_out, except_unreachable_input=False):
# r_in_set = set(r_in)
# class Dummy(object): pass
# dummy = Dummy()
# dummy.inputs = r_out
# def expand_inputs(io):
# if io in r_in_set:
# return None
# try:
# return [io.owner] if io.owner != None else None
# except AttributeError:
# return io.inputs
# ops_and_results, dfsinv = stack_search(
# deque([dummy]),
# expand_inputs, 'dfs', True)
# if except_unreachable_input:
# for r in r_in:
# if r not in dfsinv:
# raise Exception(results_and_orphans.E_unreached)
# clients = stack_search(
# deque(r_in),
# lambda io: dfsinv.get(io,None), 'dfs')
# ops_to_compute = [o for o in clients if is_op(o) and o is not dummy]
# results = []
# for o in ops_to_compute:
# results.extend(o.inputs)
# results.extend(r_out)
# op_set = set(ops_to_compute)
# assert len(ops_to_compute) == len(op_set)
# orphans = [r for r in results \
# if (r.owner not in op_set) and (r not in r_in_set)]
# return results, orphans
# results_and_orphans.E_unreached = 'there were unreachable inputs'
def results_and_orphans(i, o):
"""
"""
......@@ -286,24 +201,6 @@ def results_and_orphans(i, o):
return results, orphans
#def results_and_orphans(i, o):
# results = set()
# orphans = set()
# def helper(r):
# if r in results:
# return
# results.add(r)
# if r.owner is None:
# if r not in i:
# orphans.add(r)
# else:
# for r2 in r.owner.inputs:
# helper(r2)
# for output in o:
# helper(output)
# return results, orphans
def ops(i, o):
"""
@type i: list
......@@ -569,122 +466,3 @@ def as_string(i, o,
return [describe(output) for output in o]
# class Graph:
# """
# Object-oriented wrapper for all the functions in this module.
# """
# def __init__(self, inputs, outputs):
# self.inputs = inputs
# self.outputs = outputs
# def ops(self):
# return ops(self.inputs, self.outputs)
# def values(self):
# return values(self.inputs, self.outputs)
# def orphans(self):
# return orphans(self.inputs, self.outputs)
# def io_toposort(self):
# return io_toposort(self.inputs, self.outputs)
# def toposort(self):
# return self.io_toposort()
# def clone(self):
# o = clone(self.inputs, self.outputs)
# return Graph(self.inputs, o)
# def __str__(self):
# return as_string(self.inputs, self.outputs)
if 0:
#these were the old implementations
# they were replaced out of a desire that graph search routines would not
# depend on the hash or id of any node, so that it would be deterministic
# and consistent between program executions.
@utils.deprecated('gof.graph', 'preserving only for review')
def _results_and_orphans(i, o, except_unreachable_input=False):
    """
    @type i: list
    @param i: input L{Result}s
    @type o: list
    @param o: output L{Result}s

    Returns the pair (results, orphans). The former is the set of
    L{Result}s that are involved in the subgraph that lies between i and
    o. This includes i, o, orphans(i, o) and all results of all
    intermediary steps from i to o. The second element of the returned
    pair is orphans(i, o).
    """
    results = set()
    i = set(i)
    results.update(i)
    incomplete_paths = []
    reached = set()

    def helper(r, path):
        # Depth-first walk from an output towards the inputs, carrying
        # the path of results traversed so far.
        if r in i:
            # Reached a declared input: everything on this path belongs
            # to the subgraph.
            reached.add(r)
            results.update(path)
        elif r.owner is None:
            # Dead end that is not a declared input: the path may
            # contain an orphan; decide after the full walk.
            incomplete_paths.append(path)
        else:
            op = r.owner
            for r2 in op.inputs:
                helper(r2, path + [r2])

    for output in o:
        helper(output, [output])

    orphans = set()
    for path in incomplete_paths:
        # The first result on the path not already accounted for is the
        # orphan that cut the path short.
        for r in path:
            if r not in results:
                orphans.add(r)
                break

    if except_unreachable_input and len(i) != len(reached):
        # NOTE(review): refers to the attribute on the *public*
        # results_and_orphans function — verify E_unreached still exists
        # there after the cleanup.
        raise Exception(results_and_orphans.E_unreached)

    results.update(orphans)
    return results, orphans
def _io_toposort(i, o, orderings=None):
    """
    @type i: list
    @param i: input L{Result}s
    @type o: list
    @param o: output L{Result}s
    @param orderings: {op: [requirements for op]} (defaults to None,
        meaning no additional orderings)
    @rtype: ordered list
    @return: L{Op}s that belong in the subgraph between i and o which
    respects the following constraints:
     - all inputs in i are assumed to be already computed
     - the L{Op}s that compute an L{Op}'s inputs must be computed before it
     - the orderings specified in the optional orderings parameter must be satisfied

    Note that this function does not take into account ordering information
    related to destructive operations or other special behavior.
    """
    # Use None instead of a mutable default argument; behaviour is
    # unchanged because the original copied the dict immediately anyway.
    prereqs_d = copy(orderings) if orderings is not None else {}
    # Renamed locals: the original shadowed the builtins 'all' and
    # 'input', and used the placeholder name 'asdf'.
    all_nodes = ops(i, o)
    for node in all_nodes:
        owners = set(inp.owner for inp in node.inputs
                     if inp.owner and inp.owner in all_nodes)
        prereqs_d.setdefault(node, set()).update(owners)
    return utils.toposort(prereqs_d)
from utils import AbstractFunctionError
import utils
import graph
from graph import Value
import sys
import traceback
import sys, traceback
__excepthook = sys.excepthook
......@@ -67,7 +64,7 @@ class Linker:
print new_e.data # 3.0
print e.data # 3.0 iff inplace == True (else unknown)
"""
raise AbstractFunctionError()
raise utils.AbstractFunctionError()
def make_function(self, unpack_single = True, **kwargs):
"""
......@@ -151,7 +148,7 @@ def map_storage(env, order, input_storage, output_storage):
for node in order:
for r in node.inputs:
if r not in storage_map:
assert isinstance(r, Value)
assert isinstance(r, graph.Value)
storage_map[r] = [r.data]
for r in node.outputs:
storage_map.setdefault(r, [None])
......
"""
Contains the L{Op} class, which is the base interface for all operations
compatible with gof's graph manipulation routines.
"""
import utils
from utils import AbstractFunctionError, object2
from copy import copy
class Op(object2):
class Op(utils.object2):
default_output = None
"""@todo
......@@ -22,9 +17,19 @@ class Op(object2):
#############
def make_node(self, *inputs):
raise AbstractFunctionError()
"""
This function should return an Apply instance representing the
application of this Op on the provided inputs.
"""
raise utils.AbstractFunctionError()
def __call__(self, *inputs):
"""
Shortcut for:
self.make_node(*inputs).outputs[self.default_output] (if default_output is defined)
self.make_node(*inputs).outputs[0] (if only one output)
self.make_node(*inputs).outputs (if more than one output)
"""
node = self.make_node(*inputs)
if self.default_output is not None:
return node.outputs[self.default_output]
......@@ -44,6 +49,7 @@ class Op(object2):
Calculate the function on the inputs and put the results in the
output storage.
- node: Apply instance that contains the symbolic inputs and outputs
- inputs: sequence of inputs (immutable)
- output_storage: list of mutable 1-element lists (do not change
the length of these lists)
......@@ -53,7 +59,7 @@ class Op(object2):
by a previous call to impl and impl is free to reuse it as it
sees fit.
"""
raise AbstractFunctionError()
raise utils.AbstractFunctionError()
#####################
# C code generation #
......@@ -62,9 +68,8 @@ class Op(object2):
def c_code(self, node, name, inputs, outputs, sub):
"""Return the C implementation of an Op.
Returns templated C code that does the computation associated
to this L{Op}. You may assume that input validation and output
allocation have already been done.
Returns C code that does the computation associated to this L{Op},
given names for the inputs and outputs.
@param inputs: list of strings. There is a string for each input
of the function, and the string is the name of a C
......@@ -80,7 +85,7 @@ class Op(object2):
'fail').
"""
raise AbstractFunctionError('%s.c_code' \
raise utils.AbstractFunctionError('%s.c_code is not defined' \
% self.__class__.__name__)
def c_code_cleanup(self, node, name, inputs, outputs, sub):
......@@ -89,44 +94,33 @@ class Op(object2):
This is a convenient place to clean up things allocated by c_code().
"""
raise AbstractFunctionError()
raise utils.AbstractFunctionError()
def c_compile_args(self):
    """
    Return a list of compile args recommended to manipulate this L{Op}.

    Abstract by default; subclasses that generate C code may override.
    """
    # The original contained two consecutive raise statements (an
    # unqualified and a utils-qualified AbstractFunctionError); the
    # second was unreachable.  Keep only the qualified form, consistent
    # with the rest of the module.
    raise utils.AbstractFunctionError()
def c_headers(self):
    """
    Return a list of header files that must be included from C to manipulate
    this L{Op}.

    Abstract by default; subclasses that generate C code may override.
    """
    # Removed the duplicated, unreachable unqualified raise; keep the
    # utils-qualified form used throughout the module.
    raise utils.AbstractFunctionError()
def c_libraries(self):
    """
    Return a list of libraries to link against to manipulate this L{Op}.

    Abstract by default; subclasses that generate C code may override.
    """
    # Removed the duplicated, unreachable unqualified raise; keep the
    # utils-qualified form used throughout the module.
    raise utils.AbstractFunctionError()
def c_support_code(self):
"""
Return utility code for use by this L{Op}. It may refer to support code
defined for its input L{Result}s.
"""
raise AbstractFunctionError()
class PropertiedOp(Op):
raise utils.AbstractFunctionError()
def __eq__(self, other):
    # Two Ops compare equal when they are instances of exactly the same
    # class with identical configuration (instance attributes).
    # NOTE(review): no matching __hash__ definition is visible nearby —
    # confirm hashing behaviour is as intended for Ops used as dict keys.
    return type(self) == type(other) and self.__dict__ == other.__dict__
def __str__(self):
    # Prefer an explicit, truthy 'name' attribute when present;
    # otherwise render the class name followed by the op's configuration
    # (all instance attributes except 'name').
    if hasattr(self, 'name') and self.name:
        return self.name
    else:
        return "%s{%s}" % (self.__class__.__name__, ", ".join("%s=%s" % (k, v) for k, v in self.__dict__.items() if k != "name"))
"""
Defines the base class for optimizations as well as a certain
amount of useful generic optimization tools.
"""
from op import Op
from graph import Constant
from type import Type
import graph
from env import InconsistencyError
import utils
import unify
import toolbox
import ext
class Optimizer:
......@@ -20,9 +22,8 @@ class Optimizer:
"""
Applies the optimization to the provided L{Env}. It may use all
the methods defined by the L{Env}. If the L{Optimizer} needs
to use a certain tool, such as an L{InstanceFinder}, it should
set the L{__env_require__} field to a list of what needs to be
registered with the L{Env}.
to use a certain tool, such as an L{InstanceFinder}, it can do
so in its L{add_requirements} method.
"""
pass
......@@ -36,9 +37,19 @@ class Optimizer:
self.apply(env)
def __call__(self, env):
    """
    Shorthand for ``self.optimize(env)``; lets an Optimizer be used as a
    plain callable.
    """
    return self.optimize(env)
def add_requirements(self, env):
    """
    Add features to the env that are required to apply the optimization.
    Default implementation requires nothing; subclasses override.

    For example::

        env.extend(History())
        env.extend(MyFeature())
        etc.
    """
    pass
......@@ -79,7 +90,7 @@ class LocalOptimizer(Optimizer):
following two methods:
- candidates(env) -> returns a set of ops that can be
optimized
- apply_on_op(env, op) -> for each op in candidates,
- apply_on_node(env, node) -> for each node in candidates,
this function will be called to perform the actual
optimization.
"""
......@@ -102,7 +113,7 @@ class LocalOptimizer(Optimizer):
Calls self.apply_on_op(env, op) for each op in self.candidates(env).
"""
for node in self.candidates(env):
if env.has_node(node):
if node in env.nodes:
self.apply_on_node(env, node)
......@@ -122,7 +133,7 @@ class OpSpecificOptimizer(LocalOptimizer):
def candidates(self, env):
    """
    Returns all nodes that have L{self.op} in their op field.
    """
    return env.get_nodes(self.op)
......@@ -131,13 +142,22 @@ class OpSpecificOptimizer(LocalOptimizer):
class OpSubOptimizer(Optimizer):
"""
Replaces all L{Op}s of a certain type by L{Op}s of another type that
take the same inputs as what they are replacing.
Replaces all applications of a certain op by the application of
another op that take the same inputs as what they are replacing.
e.g. OpSubOptimizer(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
OpSubOptimizer requires the following features:
- NodeFinder
- ReplaceValidate
"""
def add_requirements(self, env):
"""
Requires the following features:
- NodeFinder
- ReplaceValidate
"""
try:
env.extend(toolbox.NodeFinder())
env.extend(toolbox.ReplaceValidate())
......@@ -145,9 +165,12 @@ class OpSubOptimizer(Optimizer):
def __init__(self, op1, op2, failure_callback = None):
"""
op1 and op2 must both be Op subclasses, they must both take
the same number of inputs and they must both have the same
number of outputs.
op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs.
If failure_callback is not None, it will be called whenever
the Optimizer fails to do a replacement in the graph. The
arguments to the callback are: (node, replacement, exception)
"""
self.op1 = op1
self.op2 = op2
......@@ -155,12 +178,8 @@ class OpSubOptimizer(Optimizer):
def apply(self, env):
"""
Replaces all occurrences of self.op1 by instances of self.op2
Replaces all applications of self.op1 by applications of self.op2
with the same inputs.
If failure_callback is not None, it will be called whenever
the Optimizer fails to do a replacement in the graph. The
arguments to the callback are: (op1_instance, replacement, exception)
"""
candidates = env.get_nodes(self.op1)
......@@ -173,7 +192,6 @@ class OpSubOptimizer(Optimizer):
except Exception, e:
if self.failure_callback is not None:
self.failure_callback(node, repl, e)
pass
def str(self):
    # Human-readable summary of the substitution this optimizer performs.
    # NOTE(review): named 'str' (shadows the builtin when bound) —
    # presumably a display hook used by the optimizer framework; confirm.
    return "%s -> %s" % (self.op1, self.op2)
......@@ -183,7 +201,7 @@ class OpSubOptimizer(Optimizer):
class OpRemover(Optimizer):
"""
@todo untested
Removes all ops of a certain type by transferring each of its
Removes all applications of an op by transferring each of its
outputs to the corresponding input.
"""
......@@ -195,21 +213,19 @@ class OpRemover(Optimizer):
def __init__(self, op, failure_callback = None):
    """
    Applications of the op must have as many inputs as outputs.

    If failure_callback is not None, it will be called whenever
    the Optimizer fails to remove an operation in the graph. The
    arguments to the callback are: (node, exception)
    """
    self.op = op
    self.failure_callback = failure_callback
def apply(self, env):
"""
Removes all occurrences of self.opclass.
If self.failure_callback is not None, it will be called whenever
the Optimizer fails to remove an operation in the graph. The
arguments to the callback are: (opclass_instance, exception)
Removes all applications of self.op.
"""
candidates = env.get_nodes(self.op)
for node in candidates:
......@@ -231,17 +247,17 @@ class PatternOptimizer(OpSpecificOptimizer):
"""
@todo update
Replaces all occurrences of the input pattern by the output pattern::
Replaces all occurrences of the input pattern by the output pattern:
input_pattern ::= (OpClass, <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= dict(pattern = <input_pattern>,
constraint = <constraint>)
sub_pattern ::= input_pattern
sub_pattern ::= string
sub_pattern ::= a Result r such that r.constant is True
sub_pattern ::= a Constant instance
constraint ::= lambda env, expr: additional matching condition
output_pattern ::= (OpClass, <output_pattern1>, <output_pattern2>, ...)
output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...)
output_pattern ::= string
Each string in the input pattern is a variable that will be set to
......@@ -253,8 +269,8 @@ class PatternOptimizer(OpSpecificOptimizer):
pattern can.
If you put a constant result in the input pattern, there will be a
match iff a constant result with the same value is found in its
place.
match iff a constant result with the same value and the same type
is found in its place.
You can add a constraint to the match by using the dict(...) form
described above with a 'constraint' key. The constraint must be a
......@@ -263,16 +279,27 @@ class PatternOptimizer(OpSpecificOptimizer):
arbitrary criterion.
Examples:
PatternOptimizer((Add, 'x', 'y'), (Add, 'y', 'x'))
PatternOptimizer((Multiply, 'x', 'x'), (Square, 'x'))
PatternOptimizer((Subtract, (Add, 'x', 'y'), 'y'), 'x')
PatternOptimizer((Power, 'x', Double(2.0, constant = True)), (Square, 'x'))
PatternOptimizer((Boggle, {'pattern': 'x',
'constraint': lambda env, expr: expr.owner.scrabble == True}),
(Scrabble, 'x'))
PatternOptimizer((add, 'x', 'y'), (add, 'y', 'x'))
PatternOptimizer((multiply, 'x', 'x'), (square, 'x'))
PatternOptimizer((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternOptimizer((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternOptimizer((boggle, {'pattern': 'x',
'constraint': lambda env, expr: expr.type == scrabble}),
(scrabble, 'x'))
"""
def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False, failure_callback = None):
"""
Creates a PatternOptimizer that replaces occurrences of
in_pattern by occurrences of out_pattern.
If failure_callback is not None, if there is a match but a
replacement fails to occur, the callback will be called with
arguments (result_to_replace, replacement, exception).
If allow_multiple_clients is False, he pattern matching will
fail if one of the subpatterns has more than one client.
"""
self.in_pattern = in_pattern
self.out_pattern = out_pattern
if isinstance(in_pattern, (list, tuple)):
......@@ -287,15 +314,8 @@ class PatternOptimizer(OpSpecificOptimizer):
def apply_on_node(self, env, node):
"""
Checks if the graph from op corresponds to in_pattern. If it does,
Checks if the graph from node corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement.
If self.failure_callback is not None, if there is a match but a
replacement fails to occur, the callback will be called with
arguments (results_to_replace, replacement, exception).
If self.allow_multiple_clients is False, he pattern matching will fail
if one of the subpatterns has more than one client.
"""
def match(pattern, expr, u, first = False):
if isinstance(pattern, (list, tuple)):
......@@ -323,7 +343,7 @@ class PatternOptimizer(OpSpecificOptimizer):
return False
else:
u = u.merge(expr, v)
elif isinstance(pattern, Constant) and isinstance(expr, Constant) and pattern.equals(expr):
elif isinstance(pattern, graph.Constant) and isinstance(expr, graph.Constant) and pattern.equals(expr):
return u
else:
return False
......@@ -363,28 +383,6 @@ class PatternOptimizer(OpSpecificOptimizer):
# class ConstantFinder(Optimizer):
# """
# Sets as constant every orphan that is not destroyed.
# """
# def apply(self, env):
# if env.has_feature(ext.DestroyHandler(env)):
# for r in env.orphans():
# if not env.destroyers(r):
# r.indestructible = True
# r.constant = True
# # for r in env.inputs:
# # if not env.destroyers(r):
# # r.indestructible = True
# else:
# for r in env.orphans():
# r.indestructible = True
# r.constant = True
# # for r in env.inputs:
# # r.indestructible = True
import graph
class _metadict:
# dict that accepts unhashable keys
......@@ -438,15 +436,14 @@ class MergeOptimizer(Optimizer):
def apply(self, env):
cid = _metadict() #result -> result.desc() (for constants)
inv_cid = _metadict() #desc -> result (for constants)
for i, r in enumerate([r for r in env.results if isinstance(r, Constant)]): #env.orphans.union(env.inputs)):
#if isinstance(r, Constant):
sig = r.signature()
other_r = inv_cid.get(sig, None)
if other_r is not None:
env.replace(r, other_r)
else:
cid[r] = sig
inv_cid[sig] = r
for i, r in enumerate([r for r in env.results if isinstance(r, graph.Constant)]):
sig = r.signature()
other_r = inv_cid.get(sig, None)
if other_r is not None:
env.replace(r, other_r)
else:
cid[r] = sig
inv_cid[sig] = r
# we clear the dicts because the Constants signatures are not necessarily hashable
# and it's more efficient to give them an integer cid like the other Results
cid.clear()
......@@ -483,123 +480,3 @@ def MergeOptMerge(opt):
merger = MergeOptimizer()
return SeqOptimizer([merger, opt, merger])
### THE FOLLOWING OPTIMIZERS ARE NEITHER USED NOR TESTED BUT PROBABLY WORK AND COULD BE USEFUL ###
# class MultiOptimizer(Optimizer):
# def __init__(self, **opts):
# self._opts = []
# self.ord = {}
# self.name_to_opt = {}
# self.up_to_date = True
# for name, opt in opts:
# self.register(name, opt, after = [], before = [])
# def register(self, name, opt, **relative):
# self.name_to_opt[name] = opt
# after = relative.get('after', [])
# if not isinstance(after, (list, tuple)):
# after = [after]
# before = relative.get('before', [])
# if not isinstance(before, (list, tuple)):
# before = [before]
# self.up_to_date = False
# if name in self.ord:
# raise Exception("Cannot redefine optimization: '%s'" % name)
# self.ord[name] = set(after)
# for postreq in before:
# self.ord.setdefault(postreq, set()).add(name)
# def get_opts(self):
# if not self.up_to_date:
# self.refresh()
# return self._opts
# def refresh(self):
# self._opts = [self.name_to_opt[name] for name in utils.toposort(self.ord)]
# self.up_to_date = True
# def apply(self, env):
# for opt in self.opts:
# opt.apply(env)
# opts = property(get_opts)
# class TaggedMultiOptimizer(MultiOptimizer):
# def __init__(self, **opts):
# self.tags = {}
# MultiOptimizer.__init__(self, **opts)
# def register(self, name, opt, tags = [], **relative):
# tags = set(tags)
# tags.add(name)
# self.tags[opt] = tags
# MultiOptimizer.register(self, name, opt, **relative)
# def filter(self, whitelist, blacklist):
# return [opt for opt in self.opts
# if self.tags[opt].intersection(whitelist)
# and not self.tags[opt].intersection(blacklist)]
# def whitelist(self, *tags):
# return [opt for opt in self.opts if self.tags[opt].intersection(tags)]
# def blacklist(self, *tags):
# return [opt for opt in self.opts if not self.tags[opt].intersection(tags)]
# class TagFilterMultiOptimizer(Optimizer):
# def __init__(self, all, whitelist = None, blacklist = None):
# self.all = all
# if whitelist is not None:
# self.whitelist = set(whitelist)
# else:
# self.whitelist = None
# if blacklist is not None:
# self.blacklist = set(blacklist)
# else:
# self.blacklist = set()
# def use_whitelist(self, use = True):
# if self.whitelist is None and use:
# self.whitelist = set()
# def allow(self, *tags):
# if self.whitelist is not None:
# self.whitelist.update(tags)
# self.blacklist.difference_update(tags)
# def deny(self, *tags):
# if self.whitelist is not None:
# self.whitelist.difference_update(tags)
# self.blacklist.update(tags)
# def dont_care(self, *tags):
# if self.whitelist is not None:
# self.whitelist.difference_update(tags)
# self.blacklist.difference_update(tags)
# def opts(self):
# if self.whitelist is not None:
# return self.all.filter(self.whitelist, self.blacklist)
# else:
# return self.all.blacklist(*[tag for tag in self.blacklist])
# def apply(self, env):
# for opt in self.opts():
# opt.apply(env)
from random import shuffle
import utils
from functools import partial
import graph
......@@ -14,51 +12,6 @@ class Bookkeeper:
def on_detach(self, env):
    # When the feature is detached, notify listeners as if every node of
    # the env were pruned, visiting them in topological order.
    for node in graph.io_toposort(env.inputs, env.outputs):
        self.on_prune(env, node)
# class Toposorter:
# def on_attach(self, env):
# if hasattr(env, 'toposort'):
# raise Exception("Toposorter feature is already present or in conflict with another plugin.")
# env.toposort = partial(self.toposort, env)
# def on_detach(self, env):
# del env.toposort
# def toposort(self, env):
# ords = {}
# for feature in env._features:
# if hasattr(feature, 'orderings'):
# for op, prereqs in feature.orderings(env).items():
# ords.setdefault(op, set()).update(prereqs)
# order = graph.io_toposort(env.inputs, env.outputs, ords)
# return order
# def supplemental_orderings(self):
# """
# Returns a dictionary of {op: set(prerequisites)} that must
# be satisfied in addition to the order defined by the structure
# of the graph (returns orderings that not related to input/output
# relationships).
# """
# ords = {}
# for feature in self._features:
# if hasattr(feature, 'orderings'):
# for op, prereqs in feature.orderings().items():
# ords.setdefault(op, set()).update(prereqs)
# return ords
# def toposort(self):
# """
# Returns a list of nodes in the order that they must be executed
# in order to preserve the semantics of the graph and respect
# the constraints put forward by the listeners.
# """
# ords = self.supplemental_orderings()
# order = graph.io_toposort(self.inputs, self.outputs, ords)
# return order
class History:
......@@ -223,151 +176,3 @@ class PrintListener(object):
# class EquivTool(dict):
# def __init__(self, env):
# self.env = env
# def on_rewire(self, clients, r, new_r):
# repl = self(new_r)
# if repl is r:
# self.ungroup(r, new_r)
# elif repl is not new_r:
# raise Exception("Improper use of EquivTool!")
# else:
# self.group(new_r, r)
# def publish(self):
# self.env.equiv = self
# self.env.set_equiv = self.set_equiv
# def unpublish(self):
# del self.env.equiv
# del self.env.set_equiv
# def set_equiv(self, d):
# self.update(d)
# def group(self, main, *keys):
# "Marks all the keys as having been replaced by the Result main."
# keys = [key for key in keys if key is not main]
# if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.")
# for key in keys:
# if self.has_key(key):
# raise Exception("Only group results that have not been grouped before.")
# if key is main:
# continue
# self.setdefault(key, main)
# def ungroup(self, main, *keys):
# "Undoes group(main, *keys)"
# keys = [key for key in keys if key is not main]
# for key in keys:
# if self[key] is main:
# del self[key]
# def __call__(self, key):
# "Returns the currently active replacement for the given key."
# next = self.get(key, None)
# while next:
# key = next
# next = self.get(next, None)
# return key
# class InstanceFinder(Listener, Tool, dict):
# def __init__(self, env):
# self.env = env
# def all_bases(self, cls):
# return utils.all_bases(cls, lambda cls: cls is not object)
# def on_import(self, op):
# for base in self.all_bases(op.__class__):
# self.setdefault(base, set()).add(op)
# def on_prune(self, op):
# for base in self.all_bases(op.__class__):
# self[base].remove(op)
# if not self[base]:
# del self[base]
# def __query__(self, cls):
# all = [x for x in self.get(cls, [])]
# shuffle(all) # this helps a lot for debugging because the order of the replacements will vary
# while all:
# next = all.pop()
# if next in self.env.ops():
# yield next
# def query(self, cls):
# return self.__query__(cls)
# def publish(self):
# self.env.get_instances_of = self.query
# class DescFinder(Listener, Tool, dict):
# def __init__(self, env):
# self.env = env
# def on_import(self, op):
# self.setdefault(op.desc(), set()).add(op)
# def on_prune(self, op):
# desc = op.desc()
# self[desc].remove(op)
# if not self[desc]:
# del self[desc]
# def __query__(self, desc):
# all = [x for x in self.get(desc, [])]
# shuffle(all) # this helps for debugging because the order of the replacements will vary
# while all:
# next = all.pop()
# if next in self.env.ops():
# yield next
# def query(self, desc):
# return self.__query__(desc)
# def publish(self):
# self.env.get_from_desc = self.query
### UNUSED AND UNTESTED ###
# class ChangeListener(Listener):
# def __init__(self, env):
# self.change = False
# def on_import(self, op):
# self.change = True
# def on_prune(self, op):
# self.change = True
# def on_rewire(self, clients, r, new_r):
# self.change = True
# def __call__(self, value = "get"):
# if value == "get":
# return self.change
# else:
# self.change = value
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论