提交 2aa7089e authored 作者: Olivier Breuleux's avatar Olivier Breuleux

adding gof package as a subdirectory

上级 430999a9
# from op import *
# from value import *
# from opt import *
# from env import *
# from prog import *
# from diff import *
# import dispatchers
from op import *
from ext import *
from lib import *
from link import *
from result import *
from env import *
from prog import *
from features import *
from opt import *
import graph
#import utils
import env
import tools
import utils
class Compiler:
def __init__(self, optimizer, features):
self.features = set(features)
self.features.update(optimizer.require())
self.optimizer = optimizer
def compile(self, inputs, outputs, features):
features = self.features.union(features)
e = env.Env(inputs, outputs, features, False)
self.optimizer.apply(e)
if not e.consistent():
raise env.InconsistencyError("The graph is inconsistent.")
return e
def __call__(self, inputs, outputs, features):
return self.compile(inputs, outputs, features)
# def __init__(self, inputs, outputs, preprocessors, features, optimizer):
# self.inputs = inputs
# self.outputs = outputs
# self.features = features
# self.optimizer = optimizer
# features = features + [tools.EquivTool] + optimizer.require()
# features = utils.uniq_features(features)
# self.env = env.Env(inputs,
# outputs,
# features,
# False)
# if not self.env.consistent():
# raise env.InconsistencyError("The graph is inconsistent.")
# self.__optimize__()
# self.thunks = [op.thunk() for op in self.order]
# def __optimize__(self):
# self.optimizer.apply(self.env)
# self.order = self.env.toposort()
# def equiv(self, r):
# return self.env.equiv(r)
# def __getitem__(self, r):
# return self.equiv(r)
# def __setitem__(self, r, value):
# if isinstance(r, tuple):
# for a, b in zip(r, value):
# self.__setitem__(a, b)
# else:
# self.equiv(r).set_value(value)
# def __call__(self, *args):
# if args:
# for input, arg in zip(self.inputs, args):
# if arg is not None:
# input.value = arg
# for thunk, op in zip(self.thunks, self.order):
# try:
# thunk()
# except Exception, e:
# raise e.__class__("Error in " + str(op) + ": " + str(e))
# return [output.value for output in self.outputs]
# import env
# import opt
# from value import AsValue
# class Prog:
# def __init__(self, inputs, outputs, optimizer):
# self.inputs = inputs
# self.outputs = outputs
# self.env = env.Env(inputs,
# outputs,
# False,
# op_db = env.OpDb,
# changed = env.ChangeListener,
# # pr = env.PrintListener,
# scope = env.ScopeListener)
# ## self.adjustments = adjustments
# self.optimizer = optimizer
# ## if self.adjustments:
# ## self.adjustments.apply(self.env)
# if not self.env.consistent():
# raise env.InconsistencyError("The graph is inconsistent.")
# self.optimizer.apply(self.env)
# self.order = self.env.toposort()
# print "==================="
# for op in self.order:
# print op
# print "==================="
# self.thunks = [op.thunk() for op in self.order]
# def equiv(self, v):
# v = AsValue(v)
# return self.env.equiv(v)
# def __getitem__(self, v):
# return self.equiv(v).storage
# def __setitem__(self, v, value):
# if isinstance(v, tuple):
# for a, b in zip(v, value):
# self.__setitem__(a, b)
# else:
# self.equiv(v).value = value
# def __call__(self, *args):
# if args:
# for input, arg in zip(self.inputs, args):
# if arg is not None:
# input.value = arg
# for thunk, op in zip(self.thunks, self.order):
# try:
# thunk()
# except Exception, e:
# raise e.__class__("Error in " + str(op) + ": " + str(e))
# return [output.value for output in self.outputs]
# def prog(i, o):
# if not isinstance(i, (list, tuple)):
# i = [i]
# if not isinstance(o, (list, tuple)):
# o = [o]
# i = [AsValue(input) for input in i]
# o = [AsValue(output) for output in o]
# return Prog(i,
# o,
# opt.TagFilterMultiOptimizer(opt.opt_registry, None, None))
from utils import OmegaError
class OmegaTypeError(OmegaError, TypeError):
pass
############################
# Dispatcher
############################
class Dispatcher(list):
all_dispatchers = {}
def __init__(self, name, description):
self.name = name
self.description = description
self.all_dispatchers[name] = self
def __call__(self, *inputs, **opts):
for candidate in self:
try:
return candidate(*inputs, **opts)
except OmegaTypeError:
continue
if opts:
s = " with options %s" % opts
else:
s = ""
raise OmegaTypeError("No candidate found for %s(%s) %s" \
% (self.name,
", ".join([input.__class__.__name__ for input in inputs]),
s))
def add_handler(self, x):
self.insert(0, x)
def fallback_handler(self, x):
self.append(x)
# Dispatchers for all python operators
Add = Dispatcher("Add", "x + y")
Subtract = Dispatcher("Subtract", "x - y")
Multiply = Dispatcher("Multiply", "x * y")
Divide = Dispatcher("Divide", "x / y")
FloorDivide = Dispatcher("FloorDivide", "x // y")
Modulo = Dispatcher("Modulo", "x % y")
Power = Dispatcher("Power", "x ** y")
Negate = Dispatcher("Negate", "-x")
Abs = Dispatcher("Abs", "abs(x)")
LeftShift = Dispatcher("LeftShift", "x << y")
RightShift = Dispatcher("RightShift", "x >> y")
Equals = Dispatcher("Equals", "x == y")
NotEquals = Dispatcher("NotEquals", "x != y")
Less = Dispatcher("Less", "x < y")
LessOrEqual = Dispatcher("LessOrEqual", "x <= y")
Greater = Dispatcher("Greater", "x > y")
GreaterOrEqual = Dispatcher("GreaterOrEqual", "x >= y")
Contains = Dispatcher("Contains", "x in y")
BinaryOr = Dispatcher("BinaryOr", "x | y")
BinaryAnd = Dispatcher("BinaryAnd", "x & y")
BinaryXor = Dispatcher("BinaryXor", "x ^ y")
BinaryInverse = Dispatcher("BinaryInverse", "~x")
# Dispatchers for special operations
Transpose = Dispatcher("Transpose", "x.T")
# Others
Log = Dispatcher("Log", 'log(x)')
Exp = Dispatcher("Exp", 'exp(x)')
Sin = Dispatcher("Sin", 'sin(x)')
Cos = Dispatcher("Cos", 'cos(x)')
Tan = Dispatcher("Tan", 'tan(x)')
############################
# PythonOperatorSupport
############################
class PythonOperatorSupport(object):
"""Support for built-in Python operators."""
# Common arithmetic operations
def __add__(self, x):
return Add(self, x)
def __radd__(self, x):
return Add(x, self)
def __sub__(self, x):
return Subtract(self, x)
def __rsub__(self, x):
return Subtract(x, self)
def __mul__(self, x):
return Multiply(self, x)
def __rmul__(self, x):
return Multiply(x, self)
def __div__(self, x):
return Divide(self, x)
def __rdiv__(self, x):
return Divide(x, self)
def __floordiv__(self, x):
return FloorDivide(self, x)
def __rfloordiv__(self, x):
return FloorDivide(x, self)
def __mod__(self, x):
return Modulo(self, x)
def __rmod__(self, x):
return Modulo(x, self)
def __pow__(self, x):
return Power(self, x)
def __rpow__(self, x):
return Power(x, self)
def __neg__(self):
return Negate(self)
def __abs__(self):
return Abs(self)
# Less common arithmetic operations
def __lshift__(self, x):
return LeftShift(self, x)
def __rlshift__(self, x):
return LeftShift(x, self)
def __rshift__(self, x):
return RightShift(self, x)
def __rrshift__(self, x):
return RightShift(x, self)
# Comparison operations
# def __eq__(self, x):
# return Equals(self, x)
# def __ne__(self, x):
# return NotEquals(self, x)
def __lt__(self, x):
return Less(self, x)
def __le__(self, x):
return LessOrEqual(self, x)
def __gt__(self, x):
return Greater(self, x)
def __ge__(self, x):
return GreaterOrEqual(self, x)
def __contains__(self, x):
return Contains(self, x)
# Binary operations
def __or__(self, x):
return BinaryOr(self, x)
def __ror__(self, x):
return BinaryOr(x, self)
def __and__(self, x):
return BinaryAnd(self, x)
def __rand__(self, x):
return BinaryAnd(x, self)
def __xor__(self, x):
return BinaryXor(self, x)
def __rxor__(self, x):
return BinaryXor(x, self)
def __invert__(self):
return BinaryInverse(self)
# Other operations
T = property(lambda self: Transpose(self))
norm = property(lambda self: Norm(self))
# Always nonzero
def __nonzero__(self):
return True
__all__ = globals().keys()
差异被折叠。
class GofError(Exception):
pass
class GofTypeError(GofError):
pass
class GofValueError(GofError):
pass
class PropagationError(GofError):
pass
差异被折叠。
from copy import copy
from op import Op
import result
import graph
import utils
from random import shuffle
__all__ = ['Feature',
'Listener',
'Constraint',
'Orderings',
'Tool',
# 'Preprocessor',
'EquivTool',
'InstanceFinder',
'PrintListener',
'ChangeListener',
# 'DestroyPreprocessor',
# 'DestroyHandler'
]
class Feature(object):
def __init__(self, env):
self.env = env
class Listener(Feature):
def on_import(self, op):
pass
def on_prune(self, op):
pass
def on_rewire(self, clients, r, new_r):
pass
class Constraint(Feature):
def validate(self):
return True
class Orderings(Feature):
def orderings(self):
return {}
class Tool(Feature):
def publish(self):
pass
# class Preprocessor(Feature):
# def transform(self, inputs, outputs):
# return inputs, outputs
# def __call__(self, inputs, outputs):
# return self.transform(inputs, outputs)
# class Optimization(object):
# def require(self):
# return []
# def apply(self, env):
# pass
# def __call__(self, env):
# return self.apply(env)
# Optimization
# * require <tool_class>*
# * apply
# Prog
# * __init__
# * inputs
# * outputs
# * listeners, constraints, orderings
# * dispatched by isinstance Listener, etc.
# * {tool_class: preferred_implementation, ...}
class EquivTool(Listener, Tool, dict):
def on_rewire(self, clients, r, new_r):
repl = self(new_r)
if repl is r:
self.ungroup(r, new_r)
elif repl is not new_r:
raise Exception("Improper use of EquivTool!")
else:
self.group(new_r, r)
def publish(self):
self.env.equiv = self
def group(self, main, *keys):
"Marks all the keys as having been replaced by the Result main."
keys = [key for key in keys if key is not main]
if self.has_key(main):
raise Exception("Only group results that have not been grouped before.")
for key in keys:
if self.has_key(key):
raise Exception("Only group results that have not been grouped before.")
if key is main:
continue
self.setdefault(key, main)
def ungroup(self, main, *keys):
"Undoes group(main, *keys)"
keys = [key for key in keys if key is not main]
for key in keys:
if self[key] is main:
del self[key]
def __call__(self, key):
"Returns the currently active replacement for the given key."
next = self.get(key, None)
while next:
key = next
next = self.get(next, None)
return key
class InstanceFinder(Listener, Tool, dict):
def __init__(self, env):
self.env = env
def all_bases(self, cls):
return utils.all_bases(cls, lambda cls: issubclass(cls, Op))
# return [cls for cls in utils.all_bases(cls) if issubclass(cls, Op)]
def on_import(self, op):
for base in self.all_bases(op.__class__):
self.setdefault(base, set()).add(op)
def on_prune(self, op):
for base in self.all_bases(op.__class__):
self[base].remove(op)
if not self[base]:
del self[base]
def __query__(self, cls):
all = [x for x in self.get(cls, [])]
shuffle(all) # this helps a lot for debugging because the order of the replacements will vary
while all:
next = all.pop()
if next in self.env.ops():
yield next
def query(self, cls):
return self.__query__(cls)
def publish(self):
self.env.get_instances_of = self.query
class PrintListener(Listener):
def __init__(self, env, active = True):
self.env = env
self.active = active
if active:
print "-- initializing"
def on_import(self, op):
if self.active:
print "-- importing: %s" % graph.as_string(self.env.inputs, op.outputs)
def on_prune(self, op):
if self.active:
print "-- pruning: %s" % graph.as_string(self.env.inputs, op.outputs)
def on_rewire(self, clients, r, new_r):
if self.active:
if r.owner is None:
rg = id(r) #r.name
else:
rg = graph.as_string(self.env.inputs, r.owner.outputs)
if new_r.owner is None:
new_rg = id(new_r) #new_r.name
else:
new_rg = graph.as_string(self.env.inputs, new_r.owner.outputs)
print "-- moving from %s to %s" % (rg, new_rg)
class ChangeListener(Listener):
def __init__(self, env):
self.change = False
def on_import(self, op):
self.change = True
def on_prune(self, op):
self.change = True
def on_rewire(self, clients, r, new_r):
self.change = True
def __call__(self, value = "get"):
if value == "get":
return self.change
else:
self.change = value
# class SuperFinder(Listener, Tool, dict):
# def __init__(self, env, props):
# self.env = env
# self.props = props
# def on_import(self, op):
# for prop, value in self.props(op).items():
# self.setdefault(prop, {}).setdefault(value, set()).add(op)
# def on_prune(self, op):
# for prop, value in self.props(op).items():
# self[prop][value].remove(op)
# if len(self[prop][value]) == 0:
# del self[prop][value]
# if len(self[prop]) == 0:
# del self[prop]
# def __query__(self, order, template):
# all = []
# for prop, value in template.items():
# all += [x for x in self.get(prop, {}).get(value, set())]
# # If not None, the order option requires the order listener to be included in the env under the name 'order'
# if order == 'o->i':
# all.sort(lambda op1, op2: self.env.order[op1].__cmp__(self.env.order[op2]))
# elif order == 'i->o':
# all.sort(lambda op1, op2: self.env.order[op2].__cmp__(self.env.order[op1]))
# while all:
# next = all.pop()
# if next in self.env.ops():
# yield next
# def query(self, **template):
# return self.__query__(None, template)
# def query_downstream(self, **template):
# return self.__query__('i->o', template)
# def query_upstream(self, **template):
# return self.__query__('o->i', template)
# def publish(self):
# self.env.query = self.query
from copy import copy
from result import Result, BrokenLink, BrokenLinkError
from op import Op
import utils
__all__ = ['inputs',
'results_and_orphans', 'results', 'orphans',
'ops',
'clone', 'clone_get_equiv',
'io_toposort',
'as_string',
'Graph']
def inputs(o, repair = False):
"""
o -> list of output Results
Returns the set of inputs necessary to compute the outputs in o
such that input.owner is None.
"""
results = set()
def seek(r):
if isinstance(r, BrokenLink):
raise BrokenLinkError
op = r.owner
if op is None:
results.add(r)
else:
for i in range(len(op.inputs)):
try:
seek(op.inputs[i])
except BrokenLinkError:
if repair:
op.refresh()
seek(op.inputs[i])
else:
raise
for output in o:
seek(output)
return results
def results_and_orphans(i, o):
"""
i -> list of input Results
o -> list of output Results
Returns the pair (results, orphans). The former is the set of
Results that are involved in the subgraph that lies between i and
o. This includes i, o, orphans(i, o) and all results of all
intermediary steps from i to o. The second element of the returned
pair is orphans(i, o).
"""
results = set(o)
results.update(i)
incomplete_paths = []
def helper(r, path):
if isinstance(r, BrokenLink):
raise BrokenLinkError
if r in i:
results.update(path)
elif r.owner is None:
incomplete_paths.append(path)
else:
op = r.owner
for r2 in op.inputs:
helper(r2, path + [r2])
for output in o:
helper(output, [])
orphans = set()
for path in incomplete_paths:
for r in path:
if r not in results:
orphans.add(r)
break
return results, orphans
def ops(i, o):
"""
i -> list of input Results
o -> list of output Results
Returns the set of ops that are contained within the subgraph
that lies between i and o, including the owners of the Results in
o and intermediary ops between i and o, but not the owners of the
Results in i.
"""
ops = set()
results, orphans = results_and_orphans(i, o)
for r in results:
if r not in i and r not in orphans:
ops.add(r.owner)
return ops
def results(i, o):
"""
i -> list of input Results
o -> list of output Results
Returns the set of Results that are involved in the subgraph
that lies between i and o. This includes i, o, orphans(i, o)
and all values of all intermediary steps from i to o.
"""
return results_and_orphans(i, o)[0]
def orphans(i, o):
"""
i -> list of input Results
o -> list of output Results
Returns the set of Results which one or more Results in o depend
on but are neither in i nor in the subgraph that lies between
i and o.
e.g. orphans([x], [(x+y).out]) => [y]
"""
return results_and_orphans(i, o)[1]
def clone(i, o):
"""
i -> list of input Results
o -> list of output Results
Copies the subgraph contained between i and o and returns the
outputs of that copy (corresponding to o). The input Results in
the list are _not_ copied and the new graph refers to the
originals.
"""
new_o, equiv = clone_get_equiv(i, o)
return new_o
def clone_get_equiv(i, o, copy_inputs = False):
"""
i -> list of input Results
o -> list of output Results
Returns (new_o, equiv) where new_o are the outputs of a copy of
the whole subgraph bounded by i and o and equiv is a dictionary
that maps the original ops and results found in the subgraph to
their copy (akin to deepcopy's memo). See clone for more details.
"""
d = {}
for op in ops(i, o):
d[op] = copy(op)
for old_op, op in d.items():
for old_output, output in zip(old_op.outputs, op.outputs):
d[old_output] = output
for i, input in enumerate(op.inputs):
owner = input.owner
if owner in d:
op._inputs[i] = d[owner].outputs[input._index]
return [[d[output] for output in o], d]
def io_toposort(i, o, orderings = {}):
"""
i -> list of input Results
o -> list of output Results
orderings -> {op: [requirements for op]} (defaults to {})
Returns an ordered list of Ops that belong in the subgraph between
i and o which respects the following constraints:
- all inputs in i are assumed to be already computed
- the Ops that compute an Op's inputs must be computed before it
- the orderings specified in the optional orderings parameter must be satisfied
Note that this function does not take into account ordering information
related to destructive operations or other special behavior.
"""
prereqs_d = copy(orderings)
all = ops(i, o)
for op in all:
prereqs_d.setdefault(op, set()).update(set([input.owner for input in op.inputs if input.owner and input.owner in all]))
# prereqs_d[op] = set([input.owner for input in op.inputs if input.owner and input.owner in all])
return utils.toposort(prereqs_d)
def as_string(i, o):
"""
i -> list of input Results
o -> list of output Results
Returns a string representation of the subgraph between i and o. If the same
Op is used by several other ops, the first occurrence will be marked as
'*n -> description' and all subsequent occurrences will be marked as '*n',
where n is an id number (ids are attributed in an unspecified order and only
exist for viewing convenience).
"""
multi = set()
seen = set()
for op in ops(i, o):
for input in op.inputs:
op2 = input.owner
if input in i or op2 is None:
continue
if op2 in seen:
multi.add(op2)
else:
seen.add(input.owner)
multi = [x for x in multi]
done = set()
def multi_index(x):
try:
return multi.index(x) + 1
except:
return 999
def describe(x, first = False):
if isinstance(x, Result):
done.add(x)
if x.owner is not None and x not in i:
op = x.owner
idx = op.outputs.index(x)
if idx:
s = describe(op, first) + "." + str(idx)
else:
s = describe(op, first)
return s
else:
return str(id(x))
elif isinstance(x, Op):
if x in done:
return "*%i" % multi_index(x)
else:
done.add(x)
if not first and hasattr(x, 'name') and x.name is not None:
return x.name
s = x.__class__.__name__ + "(" + ", ".join([describe(v) for v in x.inputs]) + ")"
if x in multi:
return "*%i -> %s" % (multi_index(x), s)
else:
return s
else:
raise TypeError("Cannot print type: %s" % x.__class__)
return "[" + ", ".join([describe(x, True) for x in o]) + "]"
# Op.__str__ = lambda self: as_string(inputs(self.outputs), self.outputs)[1:-1]
# Result.__str__ = lambda self: as_string(inputs([self]), [self])[1:-1]
class Graph:
def __init__(self, inputs, outputs):
self.inputs = inputs
self.outputs = outputs
def ops(self):
return ops(self.inputs, self.outputs)
def values(self):
return values(self.inputs, self.outputs)
def orphans(self):
return orphans(self.inputs, self.outputs)
def io_toposort(self):
return io_toposort(self.inputs, self.outputs)
def toposort(self):
return self.io_toposort()
def clone(self):
o = clone(self.inputs, self.outputs)
return Graph(self.inputs, o)
def __str__(self):
return as_string(self.inputs, self.outputs)
差异被折叠。
def perform_linker(env, target = None):
order = env.toposort()
thunks = [op._perform for op in order]
def ret():
for thunk in thunks:
thunk()
if not target:
return ret
else:
raise NotImplementedError("Cannot write thunk representation to a file.")
def cthunk_linker(env):
order = env.toposort()
thunks = []
cstreak = []
def append_cstreak():
if cstreak:
thunks.append(cutils.create_cthunk_loop(*cstreak))
cstreak = []
def ret():
for thunk in thunks:
thunk()
for op in order:
if hasattr(op, 'cthunk'):
cstreak.append(op.cthunk())
else:
append_cstreak()
thunks.append(op.perform)
if len(thunks) == 1:
return thunks[0]
else:
return ret
差异被折叠。
差异被折叠。
差异被折叠。
# import compile
import env
import link
from features import EquivTool
class Prog:
def __init__(self, inputs, outputs, optimizer, linker, features = []):
self.optimizer = optimizer
self.linker = linker
features = set(features)
features.add(EquivTool)
self.env = env.Env(inputs, outputs, features) #, False)
self.optimizer.optimize(self.env)
self.perform = self.linker(self.env)
self.outputs = outputs
# def __optimize__(self):
# self.optimizer.apply(self.env)
# self.order = self.env.toposort()
def equiv(self, r):
return self.env.equiv(r)
def __getitem__(self, r):
return self.equiv(r)
def __setitem__(self, r, value):
if isinstance(r, tuple):
for a, b in zip(r, value):
self.__setitem__(a, b)
else:
self.equiv(r).set_value(value)
def __call__(self, *args):
self.perform()
for output in self.outputs:
output.set_value(self[output])
return self.outputs
# return [output for output in self.env.outputs]
# if args:
# for input, arg in zip(self.inputs, args):
# if arg is not None:
# input.value = arg
# for thunk, op in zip(self.thunks, self.order):
# try:
# thunk()
# except Exception, e:
# raise e.__class__("Error in " + str(op) + ": " + str(e))
# return [output.value for output in self.outputs]
"""
Contains the Result class, which is the base interface for a
value that is the input or the output of an Op.
"""
from err import GofError
__all__ = ['Result', 'BrokenLink', 'BrokenLinkError']
class BrokenLink:
"""
This is placed as the owner of a Result that was replaced by
another Result.
"""
__slots__ = ['owner', 'index']
def __init__(self, owner, index):
self.owner = owner
self.index = index
def __nonzero__(self):
return False
class BrokenLinkError(GofError):
"""
"""
pass
############################
# Result
############################
class Result(object):
"""
The Result class represents a datum for use in a graph of Ops. It
has two slots:
- owner: represents the Op which computes this Result. It is
assumed to be an instance of Op. If owner raises an
AttributeError, the Result is assumed to be an input.
- index: the index this Result holds in its owner's outputs.
Result has no __init__ or __new__ routine. It is the Op's
responsibility to set the owner field of its results.
The Result class is abstract. It must be subclassed to support the
types of data needed for computation.
A Result instance should be immutable: indeed, if some aspect of a
Result is changed, operations that use it might suddenly become
invalid. Instead, a new Result instance should be instanciated
with the correct properties and the invalidate method should be
called on the Result which is replaced (this will make its owner a
BrokenLink instance, which behaves like False in conditional
expressions).
"""
__slots__ = ['_owner', '_index']
def get_owner(self):
if not hasattr(self, '_owner'):
self._owner = None
return self._owner
owner = property(get_owner, doc = "The Op of which this Result is an output or None if there is no such Op.")
def set_owner(self, owner, index):
if self.owner is not None:
if self.owner is not owner:
raise ValueError("Result %s already has an owner." % self)
elif self.index != index:
raise ValueError("Result %s was already mapped to a different index." % self)
self._owner = owner
self._index = index
def invalidate(self):
if self.owner is None:
raise Exception("Cannot invalidate a Result instance with no owner.")
elif not isinstance(self.owner, BrokenLink):
self._owner = BrokenLink(self._owner, self._index)
del self._index
def revalidate(self):
if isinstance(self.owner, BrokenLink):
owner, index = self._owner.owner, self._owner.index
self._owner = owner
self._index = index
def set_value(self, value):
"""
Copies the provided value in this Result. It is not required to
implement this method.
"""
raise NotImplementedError("This Result does not support set_value.")
# def extract(self):
# """
# Returns a representation of this datum for use in Op.impl.
# Successive calls to extract should always return the same object.
# """
# raise NotImplementedError
# def sync(self):
# """
# After calling Op.impl, synchronizes the Result instance with the
# new contents of the storage. This might usually not be necessary.
# """
# raise NotImplementedError
# def c_libs(self):
# """
# Returns a list of libraries that must be included to work with
# this Result.
# """
# raise NotImplementedError
# def c_imports(self):
# """
# Returns a list of strings representing headers to import when
# building a C interface that uses this Result.
# """
# raise NotImplementedError
# def c_declare(self):
# """
# Returns code which declares and initializes a C variable in
# which this Result can be held.
# """
# raise NotImplementedError
# def pyo_to_c(self):
# raise NotImplementedError
# def c_to_pyo(self):
# raise NotImplementedError
############################
# Utilities
############################
# class SelfContainedResult(Result):
# """
# This represents a Result which acts as its own data container. It
# is recommended to subclass this if you wish to be able to use the
# Result in normal computations as well as working with a graph
# representation.
# """
# # def extract(self):
# # """Returns self."""
# # return self
# # def sync(self):
# # """Does nothing."""
# # pass
# class HolderResult(Result):
# """
# HolderResult adds a 'data' slot which is meant to contain the
# object used by the Op implementation. It is recommended to subclass
# this if you want to be able to use the exact same object at
# different points in a computation.
# """
# __slots__ = ['data']
# # def extract(self):
# # """Returns self.data."""
# # return self.data
# # def sync(self):
# # """
# # Does nothing. Override if you have additional fields or
# # functionality in your subclass which need to be computed from
# # the data.
# # """
# # pass
差异被折叠。
# import op
# import result
class OmegaError(Exception):
pass
def all_bases(cls, accept):
rval = set([cls])
for base in cls.__bases__:
rval.update(all_bases(base, accept))
return [cls for cls in rval if accept(cls)]
def all_bases_collect(cls, raw_name):
rval = set()
name = "__%s__" % raw_name
if name in cls.__dict__: # don't use hasattr
rval.add(getattr(cls, name))
cut = "__%s_override__" % raw_name
if not cls.__dict__.get(cut, False):
for base in cls.__bases__:
rval.update(all_bases_collect(base, raw_name))
return rval
def uniq_features(_features, *_rest):
features = [x for x in _features]
for other in _rest:
features += [x for x in other]
res = []
while features:
feature = features.pop()
for feature2 in features:
if issubclass(feature2, feature):
break
else:
res.append(feature)
return res
def partial(func, *args, **keywords):
def newfunc(*fargs, **fkeywords):
newkeywords = keywords.copy()
newkeywords.update(fkeywords)
return func(*(args + fargs), **newkeywords)
newfunc.func = func
newfunc.args = args
newfunc.keywords = keywords
return newfunc
class ClsInit(type):
"""Class initializer for Op subclasses"""
def __init__(cls, name, bases, dct):
"""Validate and initialize the Op subclass 'cls'
This function:
- changes class attributes input_names and output_names to be lists if they are single strings.
"""
type.__init__(cls, name, bases, dct)
cls.__clsinit__(cls, name, bases, dct)
def toposort(prereqs_d):
"""
Sorts prereqs_d.keys() topologically. prereqs_d[x] contains all the elements
that must come before x in the ordering.
"""
# all1 = set(prereqs_d.keys())
# all2 = set()
# for x, y in prereqs_d.items():
# all2.update(y)
# print all1.difference(all2)
seq = []
done = set()
postreqs_d = {}
for x, prereqs in prereqs_d.items():
for prereq in prereqs:
postreqs_d.setdefault(prereq, set()).add(x)
next = set(k for k in prereqs_d if not prereqs_d[k])
while next:
bases = next
next = set()
for x in bases:
done.add(x)
seq.append(x)
for x in bases:
for postreq in postreqs_d.get(x, []):
if not prereqs_d[postreq].difference(done):
next.add(postreq)
if len(prereqs_d) != len(seq):
raise Exception("Cannot sort topologically: there might be cycles, " + \
"prereqs_d does not have a key for each element or " + \
"some orderings contain invalid elements.")
return seq
# def schedule(**kwargs):
# after = kwargs.get('after', [])
# if not isinstance(after, (list, tuple)):
# after = [after]
# before = kwargs.get('before', [])
# if not isinstance(before, (list, tuple)):
# before = [before]
# def decorate(fn):
# name = fn.__name__
# fn.prereqs_d = {}
# for postreq in after:
# prereqs_d[postreq] = name
# for prereq in before:
# prereqs_d[name] = prereq
# return fn
# return decorate
# def after(*others):
# return schedule(after = others)
# def before(*others):
# return schedule(before = others)
# class TopoList(list):
# def add(self, item, **kwargs):
# after = kwargs.get('after', [])
# if not isinstance(after, (list, tuple)):
# after = [after]
# before = kwargs.get('before', [])
# if not isinstance(before, (list, tuple)):
# before = [before]
class Keyword:
def __init__(self, name, nonzero=True):
self.name = name
self.nonzero = nonzero
def __nonzero__(self):
return self.nonzero
def __str__(self):
return "<%s>" % self.name
def __repr__(self):
return "<%s>" % self.name
ABORT = Keyword("ABORT", False)
RETRY = Keyword("RETRY", False)
FAILURE = Keyword("FAILURE", False)
simple_types = (int, float, str, bool, None.__class__, Keyword)
ANY_TYPE = Keyword("ANY_TYPE")
FALL_THROUGH = Keyword("FALL_THROUGH")
def comm_guard(type1, type2):
def wrap(f):
old_f = f.func_globals[f.__name__]
def new_f(arg1, arg2, *rest):
if (type1 is ANY_TYPE or isinstance(arg1, type1)) \
and (type2 is ANY_TYPE or isinstance(arg2, type2)):
pass
elif (type1 is ANY_TYPE or isinstance(arg2, type1)) \
and (type2 is ANY_TYPE or isinstance(arg1, type2)):
arg1, arg2 = arg2, arg1
else:
try:
return old_f(arg1, arg2, *rest)
except:
raise
try:
result = f(arg1, arg2, *rest)
except:
raise
if result is FALL_THROUGH:
try:
return old_f(arg1, arg2, *rest)
except:
raise
else:
return result
new_f.__name__ = f.__name__
def typename(type):
if isinstance(type, Keyword):
return str(type)
elif isinstance(type, (tuple, list)):
return "(" + ", ".join([x.__name__ for x in type]) + ")"
else:
return type.__name__
new_f.__doc__ = str(old_f.__doc__) + "\n" + ", ".join([typename(type) for type in (type1, type2)]) + "\n" + str(f.__doc__ or "")
return new_f
return wrap
def type_guard(type1):
def wrap(f):
old_f = f.func_globals[f.__name__]
def new_f(arg1, *rest):
if (type1 is ANY_TYPE or isinstance(arg1, type1)):
result = f(arg1, *rest)
if result is FALL_THROUGH:
return old_f(arg1, *rest)
else:
return result
else:
return old_f(arg1, *rest)
new_f.__name__ = f.__name__
def typename(type):
if isinstance(type, Keyword):
return str(type)
elif isinstance(type, (tuple, list)):
return "(" + ", ".join([x.__name__ for x in type]) + ")"
else:
return type.__name__
new_f.__doc__ = str(old_f.__doc__) + "\n" + ", ".join([typename(type) for type in (type1,)]) + "\n" + str(f.__doc__ or "")
return new_f
return wrap
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论