提交 14090d83 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

documentation and cleanup

上级 abd860d7
...@@ -4,11 +4,15 @@ import unittest ...@@ -4,11 +4,15 @@ import unittest
from link import PerformLinker, Profiler from link import PerformLinker, Profiler
from cc import * from cc import *
from type import Type from type import Type
from graph import Result, as_result, Apply, Constant from graph import Result, Apply, Constant
from op import Op from op import Op
import env import env
import toolbox import toolbox
def as_result(x):
assert isinstance(x, Result)
return x
class TDouble(Type): class TDouble(Type):
def filter(self, data): def filter(self, data):
return float(data) return float(data)
......
...@@ -3,18 +3,20 @@ import unittest ...@@ -3,18 +3,20 @@ import unittest
from type import Type from type import Type
import graph import graph
from graph import Result, as_result, Apply from graph import Result, Apply
from op import Op from op import Op
from opt import PatternOptimizer, OpSubOptimizer from opt import PatternOptimizer, OpSubOptimizer
from ext import * from ext import *
from env import Env, InconsistencyError from env import Env, InconsistencyError
#from toolbox import EquivTool
from toolbox import ReplaceValidate from toolbox import ReplaceValidate
from copy import copy from copy import copy
#from _test_result import MyResult
def as_result(x):
assert isinstance(x, Result)
return x
class MyType(Type): class MyType(Type):
......
...@@ -15,6 +15,10 @@ else: ...@@ -15,6 +15,10 @@ else:
realtestcase = unittest.TestCase realtestcase = unittest.TestCase
def as_result(x):
assert isinstance(x, Result)
return x
class MyType(Type): class MyType(Type):
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import unittest import unittest
import graph import graph
from graph import Result, as_result, Apply, Constant from graph import Result, Apply, Constant
from type import Type from type import Type
from op import Op from op import Op
import env import env
...@@ -14,6 +14,10 @@ from link import * ...@@ -14,6 +14,10 @@ from link import *
#from _test_result import Double #from _test_result import Double
def as_result(x):
assert isinstance(x, Result)
return x
class TDouble(Type): class TDouble(Type):
def filter(self, data): def filter(self, data):
return float(data) return float(data)
......
...@@ -3,9 +3,11 @@ import unittest ...@@ -3,9 +3,11 @@ import unittest
from copy import copy from copy import copy
from op import * from op import *
from type import Type, Generic from type import Type, Generic
from graph import Apply, as_result from graph import Apply, Result
#from result import Result def as_result(x):
assert isinstance(x, Result)
return x
class MyType(Type): class MyType(Type):
......
import unittest import unittest
from graph import Result, as_result, Apply, Constant from type import Type
from graph import Result, Apply, Constant
from op import Op from op import Op
from opt import * from opt import *
from env import Env from env import Env
from toolbox import * from toolbox import *
def as_result(x):
assert isinstance(x, Result)
return x
class MyType(Type): class MyType(Type):
......
import unittest import unittest
from graph import Result, as_result, Apply from graph import Result, Apply
from type import Type from type import Type
from op import Op from op import Op
#from opt import PatternOptimizer, OpSubOptimizer
from env import Env, InconsistencyError from env import Env, InconsistencyError
from toolbox import * from toolbox import *
def as_result(x):
assert isinstance(x, Result)
return x
class MyType(Type): class MyType(Type):
def __init__(self, name): def __init__(self, name):
......
"""
Defines Linkers that deal with C implementations.
"""
import graph
from graph import Constant, Value # Python imports
from link import Linker, LocalLinker, raise_with_op, Filter, map_storage, PerformLinker
from copy import copy from copy import copy
from utils import AbstractFunctionError
from env import Env
import md5 import md5
import sys import re
import os import os, sys, platform
import platform
# weave import
from scipy import weave from scipy import weave
# gof imports
import cutils import cutils
from env import Env
import graph
import link
import utils import utils
import traceback
import re
def compile_dir(): def compile_dir():
...@@ -192,7 +196,7 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -192,7 +196,7 @@ def struct_gen(args, struct_builders, blocks, sub):
return %(failure_var)s; return %(failure_var)s;
""" % sub """ % sub
sub = copy(sub) sub = dict(sub)
sub.update(locals()) sub.update(locals())
# TODO: add some error checking to make sure storage_<x> are 1-element lists # TODO: add some error checking to make sure storage_<x> are 1-element lists
...@@ -309,7 +313,7 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub): ...@@ -309,7 +313,7 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub):
name = "V%i" % id name = "V%i" % id
symbol_table[result] = name symbol_table[result] = name
sub = copy(sub) sub = dict(sub)
# sub['name'] = name # sub['name'] = name
sub['id'] = id sub['id'] = id
sub['fail'] = failure_code(sub) sub['fail'] = failure_code(sub)
...@@ -323,13 +327,16 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub): ...@@ -323,13 +327,16 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub):
return struct_builder, block return struct_builder, block
class CLinker(Linker): class CLinker(link.Linker):
""" """
Creates C code for an env or an Op instance, compiles it and returns
callables through make_thunk and make_function that make use of the Creates C code for an env, compiles it and returns callables
compiled code. through make_thunk and make_function that make use of the compiled
code.
It can take an env or an Op as input. no_recycling can contain a list of Results that belong to the env.
If a Result is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it).
""" """
def __init__(self, env, no_recycling = []): def __init__(self, env, no_recycling = []):
...@@ -346,7 +353,7 @@ class CLinker(Linker): ...@@ -346,7 +353,7 @@ class CLinker(Linker):
self.outputs = env.outputs self.outputs = env.outputs
self.results = graph.results(self.inputs, self.outputs) # list(env.results) self.results = graph.results(self.inputs, self.outputs) # list(env.results)
# The orphans field is listified to ensure a consistent order. # The orphans field is listified to ensure a consistent order.
self.orphans = list(r for r in self.results if isinstance(r, Value) and r not in self.inputs) #list(env.orphans.difference(self.outputs)) self.orphans = list(r for r in self.results if isinstance(r, graph.Value) and r not in self.inputs) #list(env.orphans.difference(self.outputs))
self.temps = list(set(self.results).difference(self.inputs).difference(self.outputs).difference(self.orphans)) self.temps = list(set(self.results).difference(self.inputs).difference(self.outputs).difference(self.orphans))
self.node_order = env.toposort() self.node_order = env.toposort()
...@@ -404,15 +411,15 @@ class CLinker(Linker): ...@@ -404,15 +411,15 @@ class CLinker(Linker):
policy = [[get_nothing, get_nothing, get_nothing], policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_extract, get_c_cleanup]] [get_c_declare, get_c_extract, get_c_cleanup]]
elif result in self.orphans: elif result in self.orphans:
if not isinstance(result, Value): if not isinstance(result, graph.Value):
raise TypeError("All orphans to CLinker must be Value instances.", result) raise TypeError("All orphans to CLinker must be Value instances.", result)
if isinstance(result, Constant): if isinstance(result, graph.Constant):
try: try:
symbol[result] = "(" + result.type.c_literal(result.data) + ")" symbol[result] = "(" + result.type.c_literal(result.data) + ")"
consts.append(result) consts.append(result)
self.orphans.remove(result) self.orphans.remove(result)
continue continue
except (AbstractFunctionError, NotImplementedError): except (utils.AbstractFunctionError, NotImplementedError):
pass pass
# orphans are not inputs so we'll just get fetch them when we initialize the struct and assume they stay the same # orphans are not inputs so we'll just get fetch them when we initialize the struct and assume they stay the same
policy = [[get_c_declare, get_c_extract, get_c_cleanup], policy = [[get_c_declare, get_c_extract, get_c_cleanup],
...@@ -475,11 +482,11 @@ class CLinker(Linker): ...@@ -475,11 +482,11 @@ class CLinker(Linker):
op = node.op op = node.op
try: behavior = op.c_code(node, name, isyms, osyms, sub) try: behavior = op.c_code(node, name, isyms, osyms, sub)
except AbstractFunctionError: except utils.AbstractFunctionError:
raise NotImplementedError("%s cannot produce C code" % op) raise NotImplementedError("%s cannot produce C code" % op)
try: cleanup = op.c_code_cleanup(node, name, isyms, osyms, sub) try: cleanup = op.c_code_cleanup(node, name, isyms, osyms, sub)
except AbstractFunctionError: except utils.AbstractFunctionError:
cleanup = "" cleanup = ""
blocks.append(CodeBlock("", behavior, cleanup, sub)) blocks.append(CodeBlock("", behavior, cleanup, sub))
...@@ -539,7 +546,7 @@ class CLinker(Linker): ...@@ -539,7 +546,7 @@ class CLinker(Linker):
ret = [] ret = []
for x in [y.type for y in self.results] + [y.op for y in self.node_order]: for x in [y.type for y in self.results] + [y.op for y in self.node_order]:
try: ret.append(x.c_support_code()) try: ret.append(x.c_support_code())
except AbstractFunctionError: pass except utils.AbstractFunctionError: pass
return ret return ret
def compile_args(self): def compile_args(self):
...@@ -552,7 +559,7 @@ class CLinker(Linker): ...@@ -552,7 +559,7 @@ class CLinker(Linker):
ret = [] ret = []
for x in [y.type for y in self.results] + [y.op for y in self.node_order]: for x in [y.type for y in self.results] + [y.op for y in self.node_order]:
try: ret += x.c_compile_args() try: ret += x.c_compile_args()
except AbstractFunctionError: pass except utils.AbstractFunctionError: pass
return ret return ret
def headers(self): def headers(self):
...@@ -565,7 +572,7 @@ class CLinker(Linker): ...@@ -565,7 +572,7 @@ class CLinker(Linker):
ret = [] ret = []
for x in [y.type for y in self.results] + [y.op for y in self.node_order]: for x in [y.type for y in self.results] + [y.op for y in self.node_order]:
try: ret += x.c_headers() try: ret += x.c_headers()
except AbstractFunctionError: pass except utils.AbstractFunctionError: pass
return ret return ret
def libraries(self): def libraries(self):
...@@ -578,25 +585,23 @@ class CLinker(Linker): ...@@ -578,25 +585,23 @@ class CLinker(Linker):
ret = [] ret = []
for x in [y.type for y in self.results] + [y.op for y in self.node_order]: for x in [y.type for y in self.results] + [y.op for y in self.node_order]:
try: ret += x.c_libraries() try: ret += x.c_libraries()
except AbstractFunctionError: pass except utils.AbstractFunctionError: pass
return ret return ret
def __compile__(self, input_storage = None, output_storage = None): def __compile__(self, input_storage = None, output_storage = None):
""" """
@todo update Compiles this linker's env.
Compiles this linker's env. If inplace is True, it will use the @type input_storage: list or None
Results contained in the env, if it is False it will copy the @param input_storage: list of lists of length 1. In order to use
input and output Results. the thunk returned by __compile__, the inputs must be put in
that storage. If None, storage will be allocated.
Returns: thunk, in_results, out_results, error_storage @param output_storage: list of lists of length 1. The thunk returned
by __compile__ will put the results of the computation in these
lists. If None, storage will be allocated.
Returns: thunk, input_storage, output_storage, error_storage
""" """
# if inplace:
# in_results = self.inputs
# out_results = self.outputs
# else:
# in_results = [copy(input) for input in self.inputs]
# out_results = [copy(output) for output in self.outputs]
error_storage = [None, None, None] error_storage = [None, None, None]
if input_storage is None: if input_storage is None:
input_storage = tuple([None] for result in self.inputs) input_storage = tuple([None] for result in self.inputs)
...@@ -612,13 +617,34 @@ class CLinker(Linker): ...@@ -612,13 +617,34 @@ class CLinker(Linker):
thunk = self.cthunk_factory(error_storage, thunk = self.cthunk_factory(error_storage,
input_storage, input_storage,
output_storage) output_storage)
return thunk, [Filter(input.type, storage) for input, storage in zip(self.env.inputs, input_storage)], \ return thunk, \
[Filter(output.type, storage, True) for output, storage in zip(self.env.outputs, output_storage)], \ [link.Filter(input.type, storage) for input, storage in zip(self.env.inputs, input_storage)], \
[link.Filter(output.type, storage, True) for output, storage in zip(self.env.outputs, output_storage)], \
error_storage error_storage
# return thunk, [Filter(x) for x in input_storage], [Filter(x) for x in output_storage], error_storage
def make_thunk(self, input_storage = None, output_storage = None): def make_thunk(self, input_storage = None, output_storage = None):
"""
Compiles this linker's env and returns a function to perform the
computations, as well as lists of storage cells for both the
inputs and outputs.
@type input_storage: list or None
@param input_storage: list of lists of length 1. In order to use
the thunk returned by __compile__, the inputs must be put in
that storage. If None, storage will be allocated.
@param output_storage: list of lists of length 1. The thunk returned
by __compile__ will put the results of the computation in these
lists. If None, storage will be allocated.
Returns: thunk, input_storage, output_storage
The return values can be used as follows:
f, istor, ostor = clinker.make_thunk()
istor[0].data = first_input
istor[1].data = second_input
f()
first_output = ostor[0].data
"""
cthunk, in_storage, out_storage, error_storage = self.__compile__(input_storage, output_storage) cthunk, in_storage, out_storage, error_storage = self.__compile__(input_storage, output_storage)
def execute(): def execute():
failure = cutils.run_cthunk(cthunk) failure = cutils.run_cthunk(cthunk)
...@@ -729,7 +755,7 @@ class CLinker(Linker): ...@@ -729,7 +755,7 @@ class CLinker(Linker):
class OpWiseCLinker(LocalLinker): class OpWiseCLinker(link.LocalLinker):
""" """
Uses CLinker on the individual Ops that comprise an env and loops Uses CLinker on the individual Ops that comprise an env and loops
over them in Python. The result is slower than a compiled version of over them in Python. The result is slower than a compiled version of
...@@ -739,6 +765,10 @@ class OpWiseCLinker(LocalLinker): ...@@ -739,6 +765,10 @@ class OpWiseCLinker(LocalLinker):
If fallback_on_perform is True, OpWiseCLinker will use an op's If fallback_on_perform is True, OpWiseCLinker will use an op's
perform method if no C version can be generated. perform method if no C version can be generated.
no_recycling can contain a list of Results that belong to the env.
If a Result is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it).
""" """
def __init__(self, env, fallback_on_perform = True, no_recycling = []): def __init__(self, env, fallback_on_perform = True, no_recycling = []):
...@@ -756,7 +786,7 @@ class OpWiseCLinker(LocalLinker): ...@@ -756,7 +786,7 @@ class OpWiseCLinker(LocalLinker):
order = env.toposort() order = env.toposort()
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = map_storage(env, order, input_storage, output_storage) input_storage, output_storage, storage_map = link.map_storage(env, order, input_storage, output_storage)
thunks = [] thunks = []
for node in order: for node in order:
...@@ -772,7 +802,7 @@ class OpWiseCLinker(LocalLinker): ...@@ -772,7 +802,7 @@ class OpWiseCLinker(LocalLinker):
thunk.inputs = node_input_storage thunk.inputs = node_input_storage
thunk.outputs = node_output_storage thunk.outputs = node_output_storage
thunks.append(thunk) thunks.append(thunk)
except (NotImplementedError, AbstractFunctionError): except (NotImplementedError, utils.AbstractFunctionError):
if self.fallback_on_perform: if self.fallback_on_perform:
p = node.op.perform p = node.op.perform
thunk = lambda p = p, i = node_input_storage, o = node_output_storage, n = node: p(n, [x[0] for x in i], o) thunk = lambda p = p, i = node_input_storage, o = node_output_storage, n = node: p(n, [x[0] for x in i], o)
...@@ -791,27 +821,8 @@ class OpWiseCLinker(LocalLinker): ...@@ -791,27 +821,8 @@ class OpWiseCLinker(LocalLinker):
f = self.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler) f = self.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler)
# if profiler is None: return f, [link.Filter(input.type, storage) for input, storage in zip(env.inputs, input_storage)], \
# def f(): [link.Filter(output.type, storage, True) for output, storage in zip(env.outputs, output_storage)], \
# for x in no_recycling:
# x[0] = None
# try:
# for thunk, node in zip(thunks, order):
# thunk()
# except:
# raise_with_op(node)
# else:
# def f():
# for x in no_recycling:
# x[0] = None
# def g():
# for thunk, node in zip(thunks, order):
# profiler.profile_op(thunk, node)
# profiler.profile_env(g, env)
# f.profiler = profiler
return f, [Filter(input.type, storage) for input, storage in zip(env.inputs, input_storage)], \
[Filter(output.type, storage, True) for output, storage in zip(env.outputs, output_storage)], \
thunks, order thunks, order
...@@ -825,7 +836,7 @@ def _default_checker(x, y): ...@@ -825,7 +836,7 @@ def _default_checker(x, y):
if x[0] != y[0]: if x[0] != y[0]:
raise Exception("Output mismatch.", {'performlinker': x[0], 'clinker': y[0]}) raise Exception("Output mismatch.", {'performlinker': x[0], 'clinker': y[0]})
class DualLinker(Linker): class DualLinker(link.Linker):
""" """
Runs the env in parallel using PerformLinker and CLinker. Runs the env in parallel using PerformLinker and CLinker.
...@@ -841,13 +852,13 @@ class DualLinker(Linker): ...@@ -841,13 +852,13 @@ class DualLinker(Linker):
""" """
Initialize a DualLinker. Initialize a DualLinker.
The checker argument must be a function that takes two Result The checker argument must be a function that takes two lists
instances. The first one passed will be the output computed by of length 1. The first one passed will contain the output
PerformLinker and the second one the output computed by computed by PerformLinker and the second one the output
OpWiseCLinker. The checker should compare the data fields of computed by OpWiseCLinker. The checker should compare the data
the two results to see if they match. By default, DualLinker fields of the two results to see if they match. By default,
uses ==. A custom checker can be provided to compare up to a DualLinker uses ==. A custom checker can be provided to
certain error tolerance. compare up to a certain error tolerance.
If a mismatch occurs, the checker should raise an exception to If a mismatch occurs, the checker should raise an exception to
halt the computation. If it does not, the computation will halt the computation. If it does not, the computation will
...@@ -855,35 +866,22 @@ class DualLinker(Linker): ...@@ -855,35 +866,22 @@ class DualLinker(Linker):
the problem by fiddling with the data, but it should be the problem by fiddling with the data, but it should be
careful not to share data between the two outputs (or inplace careful not to share data between the two outputs (or inplace
operations that use them will interfere). operations that use them will interfere).
no_recycling can contain a list of Results that belong to the env.
If a Result is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it).
""" """
self.env = env self.env = env
self.checker = checker self.checker = checker
self.no_recycling = no_recycling self.no_recycling = no_recycling
def make_thunk(self, **kwargs): def make_thunk(self, **kwargs):
# if inplace:
# env1 = self.env
# else:
# env1 = self.env.clone(True)
# env2, equiv = env1.clone_get_equiv(True)
# op_order_1 = env1.toposort()
# op_order_2 = [equiv[op.outputs[0]].owner for op in op_order_1] # we need to have the exact same order so we can compare each step
# def c_make_thunk(op):
# try:
# return CLinker(op).make_thunk(True)[0]
# except AbstractFunctionError:
# return op.perform
# thunks1 = [op.perform for op in op_order_1]
# thunks2 = [c_make_thunk(op) for op in op_order_2]
env = self.env env = self.env
no_recycling = self.no_recycling no_recycling = self.no_recycling
_f, i1, o1, thunks1, order1 = PerformLinker(env, no_recycling = no_recycling).make_all(**kwargs) _f, i1, o1, thunks1, order1 = link.PerformLinker(env, no_recycling = no_recycling).make_all(**kwargs)
_f, i2, o2, thunks2, order2 = OpWiseCLinker(env, no_recycling = no_recycling).make_all(**kwargs) _f, i2, o2, thunks2, order2 = OpWiseCLinker(env, no_recycling = no_recycling).make_all(**kwargs)
def f(): def f():
for input1, input2 in zip(i1, i2): for input1, input2 in zip(i1, i2):
...@@ -903,15 +901,7 @@ class DualLinker(Linker): ...@@ -903,15 +901,7 @@ class DualLinker(Linker):
for output1, output2 in zip(thunk1.outputs, thunk2.outputs): for output1, output2 in zip(thunk1.outputs, thunk2.outputs):
self.checker(output1, output2) self.checker(output1, output2)
except: except:
raise_with_op(node1) link.raise_with_op(node1)
# exc_type, exc_value, exc_trace = sys.exc_info()
# try:
# trace = op1.trace
# except AttributeError:
# trace = ()
# exc_value.__thunk_trace__ = trace
# exc_value.args = exc_value.args + (op1, )
# raise exc_type, exc_value, exc_trace
return f, i1, o1 return f, i1, o1
......
from copy import copy from copy import copy
import graph import graph
##from features import Listener, Orderings, Constraint, Tool, uniq_features
import utils import utils
from utils import AbstractFunctionError
class InconsistencyError(Exception): class InconsistencyError(Exception):
""" """
This exception is raised by Env whenever one of the listeners marks This exception should be thrown by listeners to Env when the
the graph as inconsistent. graph's state is invalid.
""" """
pass pass
class Env(object): #(graph.Graph): class Env(utils.object2):
""" """
An Env represents a subgraph bound by a set of input results and a An Env represents a subgraph bound by a set of input results and a
set of output results. An op is in the subgraph iff it depends on set of output results. The inputs list should contain all the inputs
the value of some of the Env's inputs _and_ some of the Env's on which the outputs depend. Results of type Value or Constant are
outputs depend on it. A result is in the subgraph iff it is an not counted as inputs.
input or an output of an op that is in the subgraph.
The Env supports the replace operation which allows to replace a The Env supports the replace operation which allows to replace a
result in the subgraph by another, e.g. replace (x + x).out by (2 result in the subgraph by another, e.g. replace (x + x).out by (2
* x).out. This is the basis for optimization in theano. * x).out. This is the basis for optimization in theano.
It can also be "extended" using env.extend(some_object). See the
toolbox and ext modules for common extensions.
""" """
### Special ### ### Special ###
...@@ -65,12 +64,14 @@ class Env(object): #(graph.Graph): ...@@ -65,12 +64,14 @@ class Env(object): #(graph.Graph):
### Setup a Result ### ### Setup a Result ###
def __setup_r__(self, r): def __setup_r__(self, r):
# sets up r so it belongs to this env
if hasattr(r, 'env') and r.env is not None and r.env is not self: if hasattr(r, 'env') and r.env is not None and r.env is not self:
raise Exception("%s is already owned by another env" % r) raise Exception("%s is already owned by another env" % r)
r.env = self r.env = self
r.clients = [] r.clients = []
def __setup_node__(self, node): def __setup_node__(self, node):
# sets up node so it belongs to this env
if hasattr(node, 'env') and node.env is not self: if hasattr(node, 'env') and node.env is not self:
raise Exception("%s is already owned by another env" % node) raise Exception("%s is already owned by another env" % node)
node.env = self node.env = self
...@@ -80,28 +81,27 @@ class Env(object): #(graph.Graph): ...@@ -80,28 +81,27 @@ class Env(object): #(graph.Graph):
### clients ### ### clients ###
def clients(self, r): def clients(self, r):
"Set of all the (op, i) pairs such that op.inputs[i] is r." "Set of all the (node, i) pairs such that node.inputs[i] is r."
return r.clients return r.clients
def __add_clients__(self, r, all): def __add_clients__(self, r, new_clients):
""" """
r -> result r -> result
all -> list of (op, i) pairs representing who r is an input of. new_clients -> list of (node, i) pairs such that node.inputs[i] is r.
Updates the list of clients of r with all. Updates the list of clients of r with new_clients.
""" """
r.clients += all r.clients += new_clients
def __remove_clients__(self, r, all, prune = True): def __remove_clients__(self, r, clients_to_remove, prune = True):
""" """
r -> result r -> result
all -> list of (op, i) pairs representing who r is an input of. clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
Removes all from the clients list of r. Removes all from the clients list of r.
""" """
for entry in all: for entry in clients_to_remove:
r.clients.remove(entry) r.clients.remove(entry)
# remove from orphans?
if not r.clients: if not r.clients:
if prune: if prune:
self.__prune_r__([r]) self.__prune_r__([r])
...@@ -188,6 +188,15 @@ class Env(object): #(graph.Graph): ...@@ -188,6 +188,15 @@ class Env(object): #(graph.Graph):
### change input ### ### change input ###
def change_input(self, node, i, new_r): def change_input(self, node, i, new_r):
"""
Changes node.inputs[i] to new_r.
new_r.type == old_r.type must be True, where old_r is the
current value of node.inputs[i] which we want to replace.
For each feature that has a 'on_change_input' method, calls:
feature.on_change_input(env, node, i, old_r, new_r)
"""
if node == 'output': if node == 'output':
r = self.outputs[i] r = self.outputs[i]
if not r.type == new_r.type: if not r.type == new_r.type:
...@@ -214,10 +223,7 @@ class Env(object): #(graph.Graph): ...@@ -214,10 +223,7 @@ class Env(object): #(graph.Graph):
def replace(self, r, new_r): def replace(self, r, new_r):
""" """
This is the main interface to manipulate the subgraph in Env. This is the main interface to manipulate the subgraph in Env.
For every op that uses r as input, makes it use new_r instead. For every node that uses r as input, makes it use new_r instead.
This may raise an error if the new result violates type
constraints for one of the target nodes. In that case, no
changes are made.
""" """
if r.env is not self: if r.env is not self:
raise Exception("Cannot replace %s because it does not belong to this Env" % r) raise Exception("Cannot replace %s because it does not belong to this Env" % r)
...@@ -238,11 +244,32 @@ class Env(object): #(graph.Graph): ...@@ -238,11 +244,32 @@ class Env(object): #(graph.Graph):
def extend(self, feature): def extend(self, feature):
""" """
@todo out of date Adds a feature to this env. The feature may define one
Adds an instance of the feature_class to this env's supported or more of the following methods:
features. If do_import is True and feature_class is a subclass
of Listener, its on_import method will be called on all the Nodes - feature.on_attach(env)
already in the env. Called by extend. The feature has great freedom in what
it can do with the env: it may, for example, add methods
to it dynicamically.
- feature.on_detach(env)
Called by remove_feature(feature).
- feature.on_import(env, node)*
Called whenever a node is imported into env, which is
just before the node is actually connected to the graph.
- feature.on_prune(env, node)*
Called whenever a node is pruned (removed) from the env,
after it is disconnected from the graph.
- feature.on_change_input(env, node, i, r, new_r)*
Called whenever node.inputs[i] is changed from r to new_r.
At the moment the callback is done, the change has already
taken place.
- feature.orderings(env)
Called by toposort. It should return a dictionary of
{node: predecessors} where predecessors is a list of
nodes that should be computed before the key node.
* If you raise an exception in the functions marked with an
asterisk, the state of the graph might be inconsistent.
""" """
if feature in self._features: if feature in self._features:
return # the feature is already present return # the feature is already present
...@@ -256,6 +283,11 @@ class Env(object): #(graph.Graph): ...@@ -256,6 +283,11 @@ class Env(object): #(graph.Graph):
raise raise
def remove_feature(self, feature): def remove_feature(self, feature):
"""
Removes the feature from the graph.
Calls feature.on_detach(env) if an on_detach method is defined.
"""
try: try:
self._features.remove(feature) self._features.remove(feature)
except: except:
...@@ -268,6 +300,11 @@ class Env(object): #(graph.Graph): ...@@ -268,6 +300,11 @@ class Env(object): #(graph.Graph):
### callback utils ### ### callback utils ###
def execute_callbacks(self, name, *args): def execute_callbacks(self, name, *args):
"""
Calls
getattr(feature, name)(*args)
for each feature which has a method called after name.
"""
for feature in self._features: for feature in self._features:
try: try:
fn = getattr(feature, name) fn = getattr(feature, name)
...@@ -276,6 +313,11 @@ class Env(object): #(graph.Graph): ...@@ -276,6 +313,11 @@ class Env(object): #(graph.Graph):
fn(self, *args) fn(self, *args)
def collect_callbacks(self, name, *args): def collect_callbacks(self, name, *args):
"""
Returns a dictionary d such that:
d[feature] == getattr(feature, name)(*args)
For each feature which has a method called after name.
"""
d = {} d = {}
for feature in self._features: for feature in self._features:
try: try:
...@@ -289,6 +331,17 @@ class Env(object): #(graph.Graph): ...@@ -289,6 +331,17 @@ class Env(object): #(graph.Graph):
### misc ### ### misc ###
def toposort(self): def toposort(self):
"""
Returns an ordering of the graph's Apply nodes such that:
- All the nodes of the inputs of a node are before that node.
- Satisfies the orderings provided by each feature that has
an 'orderings' method.
If a feature has an 'orderings' method, it will be called with
this env as sole argument. It should return a dictionary of
{node: predecessors} where predecessors is a list of nodes
that should be computed before the key node.
"""
env = self env = self
ords = {} ords = {}
for feature in env._features: for feature in env._features:
...@@ -314,10 +367,10 @@ class Env(object): #(graph.Graph): ...@@ -314,10 +367,10 @@ class Env(object): #(graph.Graph):
raise Exception("what the fuck") raise Exception("what the fuck")
return node.inputs return node.inputs
def has_node(self, node):
return node in self.nodes
def check_integrity(self): def check_integrity(self):
"""
Call this for a diagnosis if things go awry.
"""
nodes = graph.ops(self.inputs, self.outputs) nodes = graph.ops(self.inputs, self.outputs)
if self.nodes != nodes: if self.nodes != nodes:
missing = nodes.difference(self.nodes) missing = nodes.difference(self.nodes)
......
#from features import Listener, Constraint, Orderings, Tool from collections import defaultdict
import graph import graph
import utils import utils
from utils import AbstractFunctionError import toolbox
from copy import copy from utils import AbstractFunctionError
from env import InconsistencyError from env import InconsistencyError
from toolbox import Bookkeeper
from collections import defaultdict
class DestroyHandler(toolbox.Bookkeeper):
class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
""" """
This feature ensures that an env represents a consistent data flow This feature ensures that an env represents a consistent data flow
when some Ops overwrite their inputs and/or provide "views" over when some Ops overwrite their inputs and/or provide "views" over
...@@ -29,14 +21,16 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -29,14 +21,16 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
Examples: Examples:
- (x += 1) + (x += 1) -> fails because the first += makes the second - (x += 1) + (x += 1) -> fails because the first += makes the second
invalid invalid
- x += transpose_view(x) -> fails because the input that is destroyed
depends on an input that shares the same data
- (a += b) + (c += a) -> succeeds but we have to do c += a first - (a += b) + (c += a) -> succeeds but we have to do c += a first
- (a += b) + (b += c) + (c += a) -> fails because there's a cyclical - (a += b) + (b += c) + (c += a) -> fails because there's a cyclical
dependency (no possible ordering) dependency (no possible ordering)
This feature allows some optimizations (eg sub += for +) to be applied This feature allows some optimizations (eg sub += for +) to be applied
safely. safely.
@todo
- x += transpose_view(x) -> fails because the input that is destroyed
depends on an input that shares the same data
""" """
def __init__(self): def __init__(self):
...@@ -88,11 +82,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -88,11 +82,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
self.seen = set() self.seen = set()
Bookkeeper.on_attach(self, env) toolbox.Bookkeeper.on_attach(self, env)
# # Initialize the children if the inputs and orphans.
# for input in env.inputs: # env.orphans.union(env.inputs):
# self.children[input] = set()
def on_detach(self, env): def on_detach(self, env):
del self.parent del self.parent
...@@ -105,19 +95,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -105,19 +95,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
del self.seen del self.seen
self.env = None self.env = None
# def publish(self):
# """
# Publishes the following on the env:
# - destroyers(r) -> returns all L{Op}s that destroy the result r
# - destroy_handler -> self
# """
# def __destroyers(r):
# ret = self.destroyers.get(r, {})
# ret = ret.keys()
# return ret
# self.env.destroyers = __destroyers
# self.env.destroy_handler = self
def __path__(self, r): def __path__(self, r):
""" """
Returns a path from r to the result that it is ultimately Returns a path from r to the result that it is ultimately
...@@ -171,7 +148,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -171,7 +148,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
def __pre__(self, op): def __pre__(self, op):
""" """
Returns all results that must be computed prior to computing Returns all results that must be computed prior to computing
this op. this node.
""" """
rval = set() rval = set()
if op is None: if op is None:
...@@ -222,7 +199,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -222,7 +199,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
users = set(self.__users__(start)) users = set(self.__users__(start))
users.add(start) users.add(start)
for user in users: for user in users:
for cycle in copy(self.cycles): for cycle in set(self.cycles):
if user in cycle: if user in cycle:
self.cycles.remove(cycle) self.cycles.remove(cycle)
if just_remove: if just_remove:
...@@ -234,7 +211,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -234,7 +211,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
""" """
@return: (vmap, dmap) where: @return: (vmap, dmap) where:
- vmap -> {output : [inputs output is a view of]} - vmap -> {output : [inputs output is a view of]}
- dmap -> {output : [inputs that are destroyed by the Op - dmap -> {output : [inputs that are destroyed by the node
(and presumably returned as that output)]} (and presumably returned as that output)]}
""" """
try: _vmap = node.op.view_map try: _vmap = node.op.view_map
...@@ -260,7 +237,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -260,7 +237,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
def on_import(self, env, op): def on_import(self, env, op):
""" """
Recomputes the dependencies and search for inconsistencies given Recomputes the dependencies and search for inconsistencies given
that we just added an op to the env. that we just added an node to the env.
""" """
self.seen.add(op) self.seen.add(op)
...@@ -303,7 +280,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -303,7 +280,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
def on_prune(self, env, op): def on_prune(self, env, op):
""" """
Recomputes the dependencies and searches for inconsistencies to remove Recomputes the dependencies and searches for inconsistencies to remove
given that we just removed an op to the env. given that we just removed a node to the env.
""" """
view_map, destroy_map = self.get_maps(op) view_map, destroy_map = self.get_maps(op)
...@@ -400,7 +377,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -400,7 +377,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
""" """
Recomputes the dependencies and searches for inconsistencies to remove Recomputes the dependencies and searches for inconsistencies to remove
given that all the clients are moved from r_1 to r_2, clients being given that all the clients are moved from r_1 to r_2, clients being
a list of (op, i) pairs such that op.inputs[i] used to be r_1 and is a list of (node, i) pairs such that node.inputs[i] used to be r_1 and is
now r_2. now r_2.
""" """
path_1 = self.__path__(r_1) path_1 = self.__path__(r_1)
...@@ -485,53 +462,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -485,53 +462,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
# class Destroyer:
# """
# Base class for Ops that destroy one or more of their inputs in an
# inplace operation, use them as temporary storage, puts garbage in
# them or anything else that invalidates the contents for use by other
# Ops.
# Usage of this class in an env requires DestroyHandler.
# """
# def destroyed_inputs(self):
# raise AbstractFunctionError()
# def destroy_map(self):
# """
# Returns the map {output: [list of destroyed inputs]}
# While it typically means that the storage of the output is
# shared with each of the destroyed inputs, it does necessarily
# have to be the case.
# """
# # compatibility
# return {self.out: self.destroyed_inputs()}
# __env_require__ = DestroyHandler
# class Viewer:
# """
# Base class for Ops that return one or more views over one or more inputs,
# which means that the inputs and outputs share their storage. Unless it also
# extends Destroyer, this Op does not modify the storage in any way and thus
# the input is safe for use by other Ops even after executing this one.
# """
# def view_map(self):
# """
# Returns the map {output: [list of viewed inputs]}
# It means that the output shares storage with each of the inputs
# in the list.
# Note: support for more than one viewed input is minimal, but
# this might improve in the future.
# """
# raise AbstractFunctionError()
def view_roots(r): def view_roots(r):
""" """
Utility function that returns the leaves of a search through Utility function that returns the leaves of a search through
......
...@@ -3,20 +3,9 @@ from copy import copy ...@@ -3,20 +3,9 @@ from copy import copy
from collections import deque from collections import deque
import utils import utils
from utils import object2
def deprecated(f):
printme = [True]
def g(*args, **kwargs):
if printme[0]:
print 'gof.graph.%s deprecated: April 29' % f.__name__
printme[0] = False
return f(*args, **kwargs)
return g
class Apply(utils.object2):
class Apply(object2):
""" """
Note: it is illegal for an output element to have an owner != self Note: it is illegal for an output element to have an owner != self
""" """
...@@ -74,18 +63,13 @@ class Apply(object2): ...@@ -74,18 +63,13 @@ class Apply(object2):
raise TypeError("Cannot change the type of this input.", curr, new) raise TypeError("Cannot change the type of this input.", curr, new)
new_node = self.clone() new_node = self.clone()
new_node.inputs = inputs new_node.inputs = inputs
# new_node.outputs = []
# for output in self.outputs:
# new_output = copy(output)
# new_output.owner = new_node
# new_node.outputs.append(new_output)
return new_node return new_node
nin = property(lambda self: len(self.inputs)) nin = property(lambda self: len(self.inputs))
nout = property(lambda self: len(self.outputs)) nout = property(lambda self: len(self.outputs))
class Result(object2): class Result(utils.object2):
#__slots__ = ['type', 'owner', 'index', 'name'] #__slots__ = ['type', 'owner', 'index', 'name']
def __init__(self, type, owner = None, index = None, name = None): def __init__(self, type, owner = None, index = None, name = None):
self.type = type self.type = type
...@@ -111,9 +95,6 @@ class Result(object2): ...@@ -111,9 +95,6 @@ class Result(object2):
return "<?>::" + str(self.type) return "<?>::" + str(self.type)
def __repr__(self): def __repr__(self):
return str(self) return str(self)
@deprecated
def __asresult__(self):
return self
def clone(self): def clone(self):
return self.__class__(self.type, None, None, self.name) return self.__class__(self.type, None, None, self.name)
...@@ -137,7 +118,6 @@ class Constant(Value): ...@@ -137,7 +118,6 @@ class Constant(Value):
#__slots__ = ['data'] #__slots__ = ['data']
def __init__(self, type, data, name = None): def __init__(self, type, data, name = None):
Value.__init__(self, type, data, name) Value.__init__(self, type, data, name)
### self.indestructible = True
def equals(self, other): def equals(self, other):
# this does what __eq__ should do, but Result and Apply should always be hashable by id # this does what __eq__ should do, but Result and Apply should always be hashable by id
return type(other) == type(self) and self.signature() == other.signature() return type(other) == type(self) and self.signature() == other.signature()
...@@ -148,32 +128,6 @@ class Constant(Value): ...@@ -148,32 +128,6 @@ class Constant(Value):
return self.name return self.name
return str(self.data) #+ "::" + str(self.type) return str(self.data) #+ "::" + str(self.type)
@deprecated
def as_result(x):
if isinstance(x, Result):
return x
# elif isinstance(x, Type):
# return Result(x, None, None)
elif hasattr(x, '__asresult__'):
r = x.__asresult__()
if not isinstance(r, Result):
raise TypeError("%s.__asresult__ must return a Result instance" % x, (x, r))
return r
else:
raise TypeError("Cannot wrap %s in a Result" % x)
@deprecated
def as_apply(x):
if isinstance(x, Apply):
return x
elif hasattr(x, '__asapply__'):
node = x.__asapply__()
if not isinstance(node, Apply):
raise TypeError("%s.__asapply__ must return an Apply instance" % x, (x, node))
return node
else:
raise TypeError("Cannot map %s to Apply" % x)
def stack_search(start, expand, mode='bfs', build_inv = False): def stack_search(start, expand, mode='bfs', build_inv = False):
"""Search through L{Result}s, either breadth- or depth-first """Search through L{Result}s, either breadth- or depth-first
@type start: deque @type start: deque
...@@ -234,45 +188,6 @@ def inputs(result_list): ...@@ -234,45 +188,6 @@ def inputs(result_list):
return rval return rval
# def results_and_orphans(r_in, r_out, except_unreachable_input=False):
# r_in_set = set(r_in)
# class Dummy(object): pass
# dummy = Dummy()
# dummy.inputs = r_out
# def expand_inputs(io):
# if io in r_in_set:
# return None
# try:
# return [io.owner] if io.owner != None else None
# except AttributeError:
# return io.inputs
# ops_and_results, dfsinv = stack_search(
# deque([dummy]),
# expand_inputs, 'dfs', True)
# if except_unreachable_input:
# for r in r_in:
# if r not in dfsinv:
# raise Exception(results_and_orphans.E_unreached)
# clients = stack_search(
# deque(r_in),
# lambda io: dfsinv.get(io,None), 'dfs')
# ops_to_compute = [o for o in clients if is_op(o) and o is not dummy]
# results = []
# for o in ops_to_compute:
# results.extend(o.inputs)
# results.extend(r_out)
# op_set = set(ops_to_compute)
# assert len(ops_to_compute) == len(op_set)
# orphans = [r for r in results \
# if (r.owner not in op_set) and (r not in r_in_set)]
# return results, orphans
# results_and_orphans.E_unreached = 'there were unreachable inputs'
def results_and_orphans(i, o): def results_and_orphans(i, o):
""" """
""" """
...@@ -286,24 +201,6 @@ def results_and_orphans(i, o): ...@@ -286,24 +201,6 @@ def results_and_orphans(i, o):
return results, orphans return results, orphans
#def results_and_orphans(i, o):
# results = set()
# orphans = set()
# def helper(r):
# if r in results:
# return
# results.add(r)
# if r.owner is None:
# if r not in i:
# orphans.add(r)
# else:
# for r2 in r.owner.inputs:
# helper(r2)
# for output in o:
# helper(output)
# return results, orphans
def ops(i, o): def ops(i, o):
""" """
@type i: list @type i: list
...@@ -569,122 +466,3 @@ def as_string(i, o, ...@@ -569,122 +466,3 @@ def as_string(i, o,
return [describe(output) for output in o] return [describe(output) for output in o]
# class Graph:
# """
# Object-oriented wrapper for all the functions in this module.
# """
# def __init__(self, inputs, outputs):
# self.inputs = inputs
# self.outputs = outputs
# def ops(self):
# return ops(self.inputs, self.outputs)
# def values(self):
# return values(self.inputs, self.outputs)
# def orphans(self):
# return orphans(self.inputs, self.outputs)
# def io_toposort(self):
# return io_toposort(self.inputs, self.outputs)
# def toposort(self):
# return self.io_toposort()
# def clone(self):
# o = clone(self.inputs, self.outputs)
# return Graph(self.inputs, o)
# def __str__(self):
# return as_string(self.inputs, self.outputs)
if 0:
#these were the old implementations
# they were replaced out of a desire that graph search routines would not
# depend on the hash or id of any node, so that it would be deterministic
# and consistent between program executions.
@utils.deprecated('gof.graph', 'preserving only for review')
def _results_and_orphans(i, o, except_unreachable_input=False):
"""
@type i: list
@param i: input L{Result}s
@type o: list
@param o: output L{Result}s
Returns the pair (results, orphans). The former is the set of
L{Result}s that are involved in the subgraph that lies between i and
o. This includes i, o, orphans(i, o) and all results of all
intermediary steps from i to o. The second element of the returned
pair is orphans(i, o).
"""
results = set()
i = set(i)
results.update(i)
incomplete_paths = []
reached = set()
def helper(r, path):
if r in i:
reached.add(r)
results.update(path)
elif r.owner is None:
incomplete_paths.append(path)
else:
op = r.owner
for r2 in op.inputs:
helper(r2, path + [r2])
for output in o:
helper(output, [output])
orphans = set()
for path in incomplete_paths:
for r in path:
if r not in results:
orphans.add(r)
break
if except_unreachable_input and len(i) != len(reached):
raise Exception(results_and_orphans.E_unreached)
results.update(orphans)
return results, orphans
def _io_toposort(i, o, orderings = {}):
"""
@type i: list
@param i: input L{Result}s
@type o: list
@param o: output L{Result}s
@param orderings: {op: [requirements for op]} (defaults to {})
@rtype: ordered list
@return: L{Op}s that belong in the subgraph between i and o which
respects the following constraints:
- all inputs in i are assumed to be already computed
- the L{Op}s that compute an L{Op}'s inputs must be computed before it
- the orderings specified in the optional orderings parameter must be satisfied
Note that this function does not take into account ordering information
related to destructive operations or other special behavior.
"""
prereqs_d = copy(orderings)
all = ops(i, o)
for op in all:
asdf = set([input.owner for input in op.inputs if input.owner and input.owner in all])
prereqs_d.setdefault(op, set()).update(asdf)
return utils.toposort(prereqs_d)
from utils import AbstractFunctionError
import utils import utils
import graph
from graph import Value import sys, traceback
import sys
import traceback
__excepthook = sys.excepthook __excepthook = sys.excepthook
...@@ -67,7 +64,7 @@ class Linker: ...@@ -67,7 +64,7 @@ class Linker:
print new_e.data # 3.0 print new_e.data # 3.0
print e.data # 3.0 iff inplace == True (else unknown) print e.data # 3.0 iff inplace == True (else unknown)
""" """
raise AbstractFunctionError() raise utils.AbstractFunctionError()
def make_function(self, unpack_single = True, **kwargs): def make_function(self, unpack_single = True, **kwargs):
""" """
...@@ -151,7 +148,7 @@ def map_storage(env, order, input_storage, output_storage): ...@@ -151,7 +148,7 @@ def map_storage(env, order, input_storage, output_storage):
for node in order: for node in order:
for r in node.inputs: for r in node.inputs:
if r not in storage_map: if r not in storage_map:
assert isinstance(r, Value) assert isinstance(r, graph.Value)
storage_map[r] = [r.data] storage_map[r] = [r.data]
for r in node.outputs: for r in node.outputs:
storage_map.setdefault(r, [None]) storage_map.setdefault(r, [None])
......
""" """
Contains the L{Op} class, which is the base interface for all operations Contains the L{Op} class, which is the base interface for all operations
compatible with gof's graph manipulation routines. compatible with gof's graph manipulation routines.
""" """
import utils import utils
from utils import AbstractFunctionError, object2
from copy import copy
class Op(object2): class Op(utils.object2):
default_output = None default_output = None
"""@todo """@todo
...@@ -22,9 +17,19 @@ class Op(object2): ...@@ -22,9 +17,19 @@ class Op(object2):
############# #############
def make_node(self, *inputs): def make_node(self, *inputs):
raise AbstractFunctionError() """
This function should return an Apply instance representing the
application of this Op on the provided inputs.
"""
raise utils.AbstractFunctionError()
def __call__(self, *inputs): def __call__(self, *inputs):
"""
Shortcut for:
self.make_node(*inputs).outputs[self.default_output] (if default_output is defined)
self.make_node(*inputs).outputs[0] (if only one output)
self.make_node(*inputs).outputs (if more than one output)
"""
node = self.make_node(*inputs) node = self.make_node(*inputs)
if self.default_output is not None: if self.default_output is not None:
return node.outputs[self.default_output] return node.outputs[self.default_output]
...@@ -44,6 +49,7 @@ class Op(object2): ...@@ -44,6 +49,7 @@ class Op(object2):
Calculate the function on the inputs and put the results in the Calculate the function on the inputs and put the results in the
output storage. output storage.
- node: Apply instance that contains the symbolic inputs and outputs
- inputs: sequence of inputs (immutable) - inputs: sequence of inputs (immutable)
- output_storage: list of mutable 1-element lists (do not change - output_storage: list of mutable 1-element lists (do not change
the length of these lists) the length of these lists)
...@@ -53,7 +59,7 @@ class Op(object2): ...@@ -53,7 +59,7 @@ class Op(object2):
by a previous call to impl and impl is free to reuse it as it by a previous call to impl and impl is free to reuse it as it
sees fit. sees fit.
""" """
raise AbstractFunctionError() raise utils.AbstractFunctionError()
##################### #####################
# C code generation # # C code generation #
...@@ -62,9 +68,8 @@ class Op(object2): ...@@ -62,9 +68,8 @@ class Op(object2):
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
"""Return the C implementation of an Op. """Return the C implementation of an Op.
Returns templated C code that does the computation associated Returns C code that does the computation associated to this L{Op},
to this L{Op}. You may assume that input validation and output given names for the inputs and outputs.
allocation have already been done.
@param inputs: list of strings. There is a string for each input @param inputs: list of strings. There is a string for each input
of the function, and the string is the name of a C of the function, and the string is the name of a C
...@@ -80,7 +85,7 @@ class Op(object2): ...@@ -80,7 +85,7 @@ class Op(object2):
'fail'). 'fail').
""" """
raise AbstractFunctionError('%s.c_code' \ raise utils.AbstractFunctionError('%s.c_code is not defined' \
% self.__class__.__name__) % self.__class__.__name__)
def c_code_cleanup(self, node, name, inputs, outputs, sub): def c_code_cleanup(self, node, name, inputs, outputs, sub):
...@@ -89,44 +94,33 @@ class Op(object2): ...@@ -89,44 +94,33 @@ class Op(object2):
This is a convenient place to clean up things allocated by c_code(). This is a convenient place to clean up things allocated by c_code().
""" """
raise AbstractFunctionError() raise utils.AbstractFunctionError()
def c_compile_args(self): def c_compile_args(self):
""" """
Return a list of compile args recommended to manipulate this L{Op}. Return a list of compile args recommended to manipulate this L{Op}.
""" """
raise AbstractFunctionError() raise utils.AbstractFunctionError()
def c_headers(self): def c_headers(self):
""" """
Return a list of header files that must be included from C to manipulate Return a list of header files that must be included from C to manipulate
this L{Op}. this L{Op}.
""" """
raise AbstractFunctionError() raise utils.AbstractFunctionError()
def c_libraries(self): def c_libraries(self):
""" """
Return a list of libraries to link against to manipulate this L{Op}. Return a list of libraries to link against to manipulate this L{Op}.
""" """
raise AbstractFunctionError() raise utils.AbstractFunctionError()
def c_support_code(self): def c_support_code(self):
""" """
Return utility code for use by this L{Op}. It may refer to support code Return utility code for use by this L{Op}. It may refer to support code
defined for its input L{Result}s. defined for its input L{Result}s.
""" """
raise AbstractFunctionError() raise utils.AbstractFunctionError()
class PropertiedOp(Op):
def __eq__(self, other):
return type(self) == type(other) and self.__dict__ == other.__dict__
def __str__(self):
if hasattr(self, 'name') and self.name:
return self.name
else:
return "%s{%s}" % (self.__class__.__name__, ", ".join("%s=%s" % (k, v) for k, v in self.__dict__.items() if k != "name"))
"""
Defines the base class for optimizations as well as a certain
amount of useful generic optimization tools.
"""
from op import Op
from graph import Constant import graph
from type import Type
from env import InconsistencyError from env import InconsistencyError
import utils import utils
import unify import unify
import toolbox import toolbox
import ext
class Optimizer: class Optimizer:
...@@ -20,9 +22,8 @@ class Optimizer: ...@@ -20,9 +22,8 @@ class Optimizer:
""" """
Applies the optimization to the provided L{Env}. It may use all Applies the optimization to the provided L{Env}. It may use all
the methods defined by the L{Env}. If the L{Optimizer} needs the methods defined by the L{Env}. If the L{Optimizer} needs
to use a certain tool, such as an L{InstanceFinder}, it should to use a certain tool, such as an L{InstanceFinder}, it can do
set the L{__env_require__} field to a list of what needs to be so in its L{add_requirements} method.
registered with the L{Env}.
""" """
pass pass
...@@ -36,9 +37,19 @@ class Optimizer: ...@@ -36,9 +37,19 @@ class Optimizer:
self.apply(env) self.apply(env)
def __call__(self, env): def __call__(self, env):
"""
Same as self.optimize(env)
"""
return self.optimize(env) return self.optimize(env)
def add_requirements(self, env): def add_requirements(self, env):
"""
Add features to the env that are required to apply the optimization.
For example:
env.extend(History())
env.extend(MyFeature())
etc.
"""
pass pass
...@@ -79,7 +90,7 @@ class LocalOptimizer(Optimizer): ...@@ -79,7 +90,7 @@ class LocalOptimizer(Optimizer):
following two methods: following two methods:
- candidates(env) -> returns a set of ops that can be - candidates(env) -> returns a set of ops that can be
optimized optimized
- apply_on_op(env, op) -> for each op in candidates, - apply_on_node(env, node) -> for each node in candidates,
this function will be called to perform the actual this function will be called to perform the actual
optimization. optimization.
""" """
...@@ -102,7 +113,7 @@ class LocalOptimizer(Optimizer): ...@@ -102,7 +113,7 @@ class LocalOptimizer(Optimizer):
Calls self.apply_on_op(env, op) for each op in self.candidates(env). Calls self.apply_on_op(env, op) for each op in self.candidates(env).
""" """
for node in self.candidates(env): for node in self.candidates(env):
if env.has_node(node): if node in env.nodes:
self.apply_on_node(env, node) self.apply_on_node(env, node)
...@@ -122,7 +133,7 @@ class OpSpecificOptimizer(LocalOptimizer): ...@@ -122,7 +133,7 @@ class OpSpecificOptimizer(LocalOptimizer):
def candidates(self, env): def candidates(self, env):
""" """
Returns all instances of L{self.op}. Returns all nodes that have L{self.op} in their op field.
""" """
return env.get_nodes(self.op) return env.get_nodes(self.op)
...@@ -131,13 +142,22 @@ class OpSpecificOptimizer(LocalOptimizer): ...@@ -131,13 +142,22 @@ class OpSpecificOptimizer(LocalOptimizer):
class OpSubOptimizer(Optimizer): class OpSubOptimizer(Optimizer):
""" """
Replaces all L{Op}s of a certain type by L{Op}s of another type that Replaces all applications of a certain op by the application of
take the same inputs as what they are replacing. another op that take the same inputs as what they are replacing.
e.g. OpSubOptimizer(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x)) e.g. OpSubOptimizer(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
OpSubOptimizer requires the following features:
- NodeFinder
- ReplaceValidate
""" """
def add_requirements(self, env): def add_requirements(self, env):
"""
Requires the following features:
- NodeFinder
- ReplaceValidate
"""
try: try:
env.extend(toolbox.NodeFinder()) env.extend(toolbox.NodeFinder())
env.extend(toolbox.ReplaceValidate()) env.extend(toolbox.ReplaceValidate())
...@@ -145,9 +165,12 @@ class OpSubOptimizer(Optimizer): ...@@ -145,9 +165,12 @@ class OpSubOptimizer(Optimizer):
def __init__(self, op1, op2, failure_callback = None): def __init__(self, op1, op2, failure_callback = None):
""" """
op1 and op2 must both be Op subclasses, they must both take op1.make_node and op2.make_node must take the same number of
the same number of inputs and they must both have the same inputs and have the same number of outputs.
number of outputs.
If failure_callback is not None, it will be called whenever
the Optimizer fails to do a replacement in the graph. The
arguments to the callback are: (node, replacement, exception)
""" """
self.op1 = op1 self.op1 = op1
self.op2 = op2 self.op2 = op2
...@@ -155,12 +178,8 @@ class OpSubOptimizer(Optimizer): ...@@ -155,12 +178,8 @@ class OpSubOptimizer(Optimizer):
def apply(self, env): def apply(self, env):
""" """
Replaces all occurrences of self.op1 by instances of self.op2 Replaces all applications of self.op1 by applications of self.op2
with the same inputs. with the same inputs.
If failure_callback is not None, it will be called whenever
the Optimizer fails to do a replacement in the graph. The
arguments to the callback are: (op1_instance, replacement, exception)
""" """
candidates = env.get_nodes(self.op1) candidates = env.get_nodes(self.op1)
...@@ -173,7 +192,6 @@ class OpSubOptimizer(Optimizer): ...@@ -173,7 +192,6 @@ class OpSubOptimizer(Optimizer):
except Exception, e: except Exception, e:
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(node, repl, e) self.failure_callback(node, repl, e)
pass
def str(self): def str(self):
return "%s -> %s" % (self.op1, self.op2) return "%s -> %s" % (self.op1, self.op2)
...@@ -183,7 +201,7 @@ class OpSubOptimizer(Optimizer): ...@@ -183,7 +201,7 @@ class OpSubOptimizer(Optimizer):
class OpRemover(Optimizer): class OpRemover(Optimizer):
""" """
@todo untested @todo untested
Removes all ops of a certain type by transferring each of its Removes all applications of an op by transferring each of its
outputs to the corresponding input. outputs to the corresponding input.
""" """
...@@ -195,21 +213,19 @@ class OpRemover(Optimizer): ...@@ -195,21 +213,19 @@ class OpRemover(Optimizer):
def __init__(self, op, failure_callback = None): def __init__(self, op, failure_callback = None):
""" """
opclass is the class of the ops to remove. It must take as Applications of the op must have as many inputs as outputs.
many inputs as outputs.
If failure_callback is not None, it will be called whenever
the Optimizer fails to remove an operation in the graph. The
arguments to the callback are: (node, exception)
""" """
self.op = op self.op = op
self.failure_callback = failure_callback self.failure_callback = failure_callback
def apply(self, env): def apply(self, env):
""" """
Removes all occurrences of self.opclass. Removes all applications of self.op.
If self.failure_callback is not None, it will be called whenever
the Optimizer fails to remove an operation in the graph. The
arguments to the callback are: (opclass_instance, exception)
""" """
candidates = env.get_nodes(self.op) candidates = env.get_nodes(self.op)
for node in candidates: for node in candidates:
...@@ -231,17 +247,17 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -231,17 +247,17 @@ class PatternOptimizer(OpSpecificOptimizer):
""" """
@todo update @todo update
Replaces all occurrences of the input pattern by the output pattern:: Replaces all occurrences of the input pattern by the output pattern:
input_pattern ::= (OpClass, <sub_pattern1>, <sub_pattern2>, ...) input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= dict(pattern = <input_pattern>, input_pattern ::= dict(pattern = <input_pattern>,
constraint = <constraint>) constraint = <constraint>)
sub_pattern ::= input_pattern sub_pattern ::= input_pattern
sub_pattern ::= string sub_pattern ::= string
sub_pattern ::= a Result r such that r.constant is True sub_pattern ::= a Constant instance
constraint ::= lambda env, expr: additional matching condition constraint ::= lambda env, expr: additional matching condition
output_pattern ::= (OpClass, <output_pattern1>, <output_pattern2>, ...) output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...)
output_pattern ::= string output_pattern ::= string
Each string in the input pattern is a variable that will be set to Each string in the input pattern is a variable that will be set to
...@@ -253,8 +269,8 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -253,8 +269,8 @@ class PatternOptimizer(OpSpecificOptimizer):
pattern can. pattern can.
If you put a constant result in the input pattern, there will be a If you put a constant result in the input pattern, there will be a
match iff a constant result with the same value is found in its match iff a constant result with the same value and the same type
place. is found in its place.
You can add a constraint to the match by using the dict(...) form You can add a constraint to the match by using the dict(...) form
described above with a 'constraint' key. The constraint must be a described above with a 'constraint' key. The constraint must be a
...@@ -263,16 +279,27 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -263,16 +279,27 @@ class PatternOptimizer(OpSpecificOptimizer):
arbitrary criterion. arbitrary criterion.
Examples: Examples:
PatternOptimizer((Add, 'x', 'y'), (Add, 'y', 'x')) PatternOptimizer((add, 'x', 'y'), (add, 'y', 'x'))
PatternOptimizer((Multiply, 'x', 'x'), (Square, 'x')) PatternOptimizer((multiply, 'x', 'x'), (square, 'x'))
PatternOptimizer((Subtract, (Add, 'x', 'y'), 'y'), 'x') PatternOptimizer((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternOptimizer((Power, 'x', Double(2.0, constant = True)), (Square, 'x')) PatternOptimizer((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternOptimizer((Boggle, {'pattern': 'x', PatternOptimizer((boggle, {'pattern': 'x',
'constraint': lambda env, expr: expr.owner.scrabble == True}), 'constraint': lambda env, expr: expr.type == scrabble}),
(Scrabble, 'x')) (scrabble, 'x'))
""" """
def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False, failure_callback = None): def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False, failure_callback = None):
"""
Creates a PatternOptimizer that replaces occurrences of
in_pattern by occurrences of out_pattern.
If failure_callback is not None, if there is a match but a
replacement fails to occur, the callback will be called with
arguments (result_to_replace, replacement, exception).
If allow_multiple_clients is False, he pattern matching will
fail if one of the subpatterns has more than one client.
"""
self.in_pattern = in_pattern self.in_pattern = in_pattern
self.out_pattern = out_pattern self.out_pattern = out_pattern
if isinstance(in_pattern, (list, tuple)): if isinstance(in_pattern, (list, tuple)):
...@@ -287,15 +314,8 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -287,15 +314,8 @@ class PatternOptimizer(OpSpecificOptimizer):
def apply_on_node(self, env, node): def apply_on_node(self, env, node):
""" """
Checks if the graph from op corresponds to in_pattern. If it does, Checks if the graph from node corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement. constructs out_pattern and performs the replacement.
If self.failure_callback is not None, if there is a match but a
replacement fails to occur, the callback will be called with
arguments (results_to_replace, replacement, exception).
If self.allow_multiple_clients is False, he pattern matching will fail
if one of the subpatterns has more than one client.
""" """
def match(pattern, expr, u, first = False): def match(pattern, expr, u, first = False):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
...@@ -323,7 +343,7 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -323,7 +343,7 @@ class PatternOptimizer(OpSpecificOptimizer):
return False return False
else: else:
u = u.merge(expr, v) u = u.merge(expr, v)
elif isinstance(pattern, Constant) and isinstance(expr, Constant) and pattern.equals(expr): elif isinstance(pattern, graph.Constant) and isinstance(expr, graph.Constant) and pattern.equals(expr):
return u return u
else: else:
return False return False
...@@ -363,28 +383,6 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -363,28 +383,6 @@ class PatternOptimizer(OpSpecificOptimizer):
# class ConstantFinder(Optimizer):
# """
# Sets as constant every orphan that is not destroyed.
# """
# def apply(self, env):
# if env.has_feature(ext.DestroyHandler(env)):
# for r in env.orphans():
# if not env.destroyers(r):
# r.indestructible = True
# r.constant = True
# # for r in env.inputs:
# # if not env.destroyers(r):
# # r.indestructible = True
# else:
# for r in env.orphans():
# r.indestructible = True
# r.constant = True
# # for r in env.inputs:
# # r.indestructible = True
import graph
class _metadict: class _metadict:
# dict that accepts unhashable keys # dict that accepts unhashable keys
...@@ -438,15 +436,14 @@ class MergeOptimizer(Optimizer): ...@@ -438,15 +436,14 @@ class MergeOptimizer(Optimizer):
def apply(self, env): def apply(self, env):
cid = _metadict() #result -> result.desc() (for constants) cid = _metadict() #result -> result.desc() (for constants)
inv_cid = _metadict() #desc -> result (for constants) inv_cid = _metadict() #desc -> result (for constants)
for i, r in enumerate([r for r in env.results if isinstance(r, Constant)]): #env.orphans.union(env.inputs)): for i, r in enumerate([r for r in env.results if isinstance(r, graph.Constant)]):
#if isinstance(r, Constant): sig = r.signature()
sig = r.signature() other_r = inv_cid.get(sig, None)
other_r = inv_cid.get(sig, None) if other_r is not None:
if other_r is not None: env.replace(r, other_r)
env.replace(r, other_r) else:
else: cid[r] = sig
cid[r] = sig inv_cid[sig] = r
inv_cid[sig] = r
# we clear the dicts because the Constants signatures are not necessarily hashable # we clear the dicts because the Constants signatures are not necessarily hashable
# and it's more efficient to give them an integer cid like the other Results # and it's more efficient to give them an integer cid like the other Results
cid.clear() cid.clear()
...@@ -483,123 +480,3 @@ def MergeOptMerge(opt): ...@@ -483,123 +480,3 @@ def MergeOptMerge(opt):
merger = MergeOptimizer() merger = MergeOptimizer()
return SeqOptimizer([merger, opt, merger]) return SeqOptimizer([merger, opt, merger])
### THE FOLLOWING OPTIMIZERS ARE NEITHER USED NOR TESTED BUT PROBABLY WORK AND COULD BE USEFUL ###
# class MultiOptimizer(Optimizer):
# def __init__(self, **opts):
# self._opts = []
# self.ord = {}
# self.name_to_opt = {}
# self.up_to_date = True
# for name, opt in opts:
# self.register(name, opt, after = [], before = [])
# def register(self, name, opt, **relative):
# self.name_to_opt[name] = opt
# after = relative.get('after', [])
# if not isinstance(after, (list, tuple)):
# after = [after]
# before = relative.get('before', [])
# if not isinstance(before, (list, tuple)):
# before = [before]
# self.up_to_date = False
# if name in self.ord:
# raise Exception("Cannot redefine optimization: '%s'" % name)
# self.ord[name] = set(after)
# for postreq in before:
# self.ord.setdefault(postreq, set()).add(name)
# def get_opts(self):
# if not self.up_to_date:
# self.refresh()
# return self._opts
# def refresh(self):
# self._opts = [self.name_to_opt[name] for name in utils.toposort(self.ord)]
# self.up_to_date = True
# def apply(self, env):
# for opt in self.opts:
# opt.apply(env)
# opts = property(get_opts)
# class TaggedMultiOptimizer(MultiOptimizer):
# def __init__(self, **opts):
# self.tags = {}
# MultiOptimizer.__init__(self, **opts)
# def register(self, name, opt, tags = [], **relative):
# tags = set(tags)
# tags.add(name)
# self.tags[opt] = tags
# MultiOptimizer.register(self, name, opt, **relative)
# def filter(self, whitelist, blacklist):
# return [opt for opt in self.opts
# if self.tags[opt].intersection(whitelist)
# and not self.tags[opt].intersection(blacklist)]
# def whitelist(self, *tags):
# return [opt for opt in self.opts if self.tags[opt].intersection(tags)]
# def blacklist(self, *tags):
# return [opt for opt in self.opts if not self.tags[opt].intersection(tags)]
# class TagFilterMultiOptimizer(Optimizer):
# def __init__(self, all, whitelist = None, blacklist = None):
# self.all = all
# if whitelist is not None:
# self.whitelist = set(whitelist)
# else:
# self.whitelist = None
# if blacklist is not None:
# self.blacklist = set(blacklist)
# else:
# self.blacklist = set()
# def use_whitelist(self, use = True):
# if self.whitelist is None and use:
# self.whitelist = set()
# def allow(self, *tags):
# if self.whitelist is not None:
# self.whitelist.update(tags)
# self.blacklist.difference_update(tags)
# def deny(self, *tags):
# if self.whitelist is not None:
# self.whitelist.difference_update(tags)
# self.blacklist.update(tags)
# def dont_care(self, *tags):
# if self.whitelist is not None:
# self.whitelist.difference_update(tags)
# self.blacklist.difference_update(tags)
# def opts(self):
# if self.whitelist is not None:
# return self.all.filter(self.whitelist, self.blacklist)
# else:
# return self.all.blacklist(*[tag for tag in self.blacklist])
# def apply(self, env):
# for opt in self.opts():
# opt.apply(env)
from random import shuffle
import utils
from functools import partial from functools import partial
import graph import graph
...@@ -14,51 +12,6 @@ class Bookkeeper: ...@@ -14,51 +12,6 @@ class Bookkeeper:
def on_detach(self, env): def on_detach(self, env):
for node in graph.io_toposort(env.inputs, env.outputs): for node in graph.io_toposort(env.inputs, env.outputs):
self.on_prune(env, node) self.on_prune(env, node)
# class Toposorter:
# def on_attach(self, env):
# if hasattr(env, 'toposort'):
# raise Exception("Toposorter feature is already present or in conflict with another plugin.")
# env.toposort = partial(self.toposort, env)
# def on_detach(self, env):
# del env.toposort
# def toposort(self, env):
# ords = {}
# for feature in env._features:
# if hasattr(feature, 'orderings'):
# for op, prereqs in feature.orderings(env).items():
# ords.setdefault(op, set()).update(prereqs)
# order = graph.io_toposort(env.inputs, env.outputs, ords)
# return order
# def supplemental_orderings(self):
# """
# Returns a dictionary of {op: set(prerequisites)} that must
# be satisfied in addition to the order defined by the structure
# of the graph (returns orderings that not related to input/output
# relationships).
# """
# ords = {}
# for feature in self._features:
# if hasattr(feature, 'orderings'):
# for op, prereqs in feature.orderings().items():
# ords.setdefault(op, set()).update(prereqs)
# return ords
# def toposort(self):
# """
# Returns a list of nodes in the order that they must be executed
# in order to preserve the semantics of the graph and respect
# the constraints put forward by the listeners.
# """
# ords = self.supplemental_orderings()
# order = graph.io_toposort(self.inputs, self.outputs, ords)
# return order
class History: class History:
...@@ -223,151 +176,3 @@ class PrintListener(object): ...@@ -223,151 +176,3 @@ class PrintListener(object):
# class EquivTool(dict):
# def __init__(self, env):
# self.env = env
# def on_rewire(self, clients, r, new_r):
# repl = self(new_r)
# if repl is r:
# self.ungroup(r, new_r)
# elif repl is not new_r:
# raise Exception("Improper use of EquivTool!")
# else:
# self.group(new_r, r)
# def publish(self):
# self.env.equiv = self
# self.env.set_equiv = self.set_equiv
# def unpublish(self):
# del self.env.equiv
# del self.env.set_equiv
# def set_equiv(self, d):
# self.update(d)
# def group(self, main, *keys):
# "Marks all the keys as having been replaced by the Result main."
# keys = [key for key in keys if key is not main]
# if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.")
# for key in keys:
# if self.has_key(key):
# raise Exception("Only group results that have not been grouped before.")
# if key is main:
# continue
# self.setdefault(key, main)
# def ungroup(self, main, *keys):
# "Undoes group(main, *keys)"
# keys = [key for key in keys if key is not main]
# for key in keys:
# if self[key] is main:
# del self[key]
# def __call__(self, key):
# "Returns the currently active replacement for the given key."
# next = self.get(key, None)
# while next:
# key = next
# next = self.get(next, None)
# return key
# class InstanceFinder(Listener, Tool, dict):
# def __init__(self, env):
# self.env = env
# def all_bases(self, cls):
# return utils.all_bases(cls, lambda cls: cls is not object)
# def on_import(self, op):
# for base in self.all_bases(op.__class__):
# self.setdefault(base, set()).add(op)
# def on_prune(self, op):
# for base in self.all_bases(op.__class__):
# self[base].remove(op)
# if not self[base]:
# del self[base]
# def __query__(self, cls):
# all = [x for x in self.get(cls, [])]
# shuffle(all) # this helps a lot for debugging because the order of the replacements will vary
# while all:
# next = all.pop()
# if next in self.env.ops():
# yield next
# def query(self, cls):
# return self.__query__(cls)
# def publish(self):
# self.env.get_instances_of = self.query
# class DescFinder(Listener, Tool, dict):
# def __init__(self, env):
# self.env = env
# def on_import(self, op):
# self.setdefault(op.desc(), set()).add(op)
# def on_prune(self, op):
# desc = op.desc()
# self[desc].remove(op)
# if not self[desc]:
# del self[desc]
# def __query__(self, desc):
# all = [x for x in self.get(desc, [])]
# shuffle(all) # this helps for debugging because the order of the replacements will vary
# while all:
# next = all.pop()
# if next in self.env.ops():
# yield next
# def query(self, desc):
# return self.__query__(desc)
# def publish(self):
# self.env.get_from_desc = self.query
### UNUSED AND UNTESTED ###
# class ChangeListener(Listener):
# def __init__(self, env):
# self.change = False
# def on_import(self, op):
# self.change = True
# def on_prune(self, op):
# self.change = True
# def on_rewire(self, clients, r, new_r):
# self.change = True
# def __call__(self, value = "get"):
# if value == "get":
# return self.change
# else:
# self.change = value
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论