Commit d0c30ce2, authored by Olivier Breuleux

blah

......@@ -2,8 +2,14 @@
import gof
from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode, pop_mode, UNCOMPUTED, UNDEFINED, PythonR
import type_spec
import cutils
import numpy
import weakref
import inspect
import md5
from scipy import weave
from copy import copy as pycopy
......@@ -16,37 +22,12 @@ def build(f, *args, **kwargs):
pop_mode()
return r
class Proxy(object):
    """Transparent attribute delegate.

    Every attribute read, write, or delete is forwarded to the wrapped
    object stored in ``_obj``; only ``_obj`` itself and ``__class__``
    resolve on the proxy.
    """
    __slots__ = ['_obj']

    def __init__(self, obj = None):
        self._obj = obj

    def __getattribute__(self, attr):
        # '_obj' and '__class__' belong to the proxy; everything else is
        # looked up on the wrapped object.
        if attr == '__class__' or attr == '_obj':
            return object.__getattribute__(self, attr)
        target = object.__getattribute__(self, '_obj')
        return getattr(target, attr)

    def __setattr__(self, attr, value):
        if attr == '_obj':
            object.__setattr__(self, attr, value)
        else:
            setattr(self._obj, attr, value)

    def __delattr__(self, attr):
        delattr(self._obj, attr)
# Render the graph(s) ending at the given results as a human-readable string.
def as_string(*rs):
# Delegates to the project's gof.graph utilities; output format comes from
# gof.graph.as_string over the graph's inputs and the requested results.
s = gof.graph.as_string(gof.graph.inputs(rs), rs)
if len(rs) == 1:
# single result: strip the surrounding list brackets from the rendering
return s[1:-1]
else:
return s
# return str(gof.Env(gof.graph.inputs([r]), [r]))[1:-1]
# Convenience wrapper: print the string rendering of the given results
# (Python 2 print statement).
def print_graph(*rs):
print as_string(*rs)
......@@ -72,8 +53,6 @@ def wrap(x):
return x
elif isinstance(x, omega_op):
return x.out
elif isinstance(x, Proxy):
return wrap(x._obj)
else:
return literal(x)
......@@ -85,13 +64,6 @@ def _hashable(x):
return False
def _literal_hashable(x):
# try:
# present = x in literals_db
# hashable = True
# except TypeError: # x is unhashable
# present = False
# hashable = False
if x in literals_db:
return literals_db[x]
else:
......@@ -99,20 +71,6 @@ def _literal_hashable(x):
r.constant = True
literals_db[x] = r
return r
# elif isinstance(x, numpy.ndarray):
# ret = NumpyR(x, constant = True)
# elif isinstance(x, (int, float)):
# ret = NumpyR(numpy.array(x), constant = True)
# elif isinstance(x, gof.Result):
# raise TypeError("%s is already a result." % x)
# else:
# return PythonR(x, constant = True)
# if hashable:
# literals_db[x] = ret
# return ret
def _literal_unhashable(x):
idx = id(x)
......@@ -124,7 +82,6 @@ def _literal_unhashable(x):
literals_id_db[idx] = r
return r
def literal(x):
if _hashable(x):
return _literal_hashable(x)
......@@ -135,21 +92,91 @@ def literal(x):
inplace = gof.Destroyer
view = gof.Viewer
# Map C-level names to the raw .data payloads of their Result wrappers and
# let scipy.weave infer a C type spec for each from its value.
# Returns (name -> data dict, list of weave type specs).
def cgetspecs(names, vals, converters):
d = {}
for name, value in zip(names, vals):
d[name] = value.data
specs = weave.ext_tools.assign_variable_types(names, d, type_converters = converters) #, auto_downcast = 0)
return d, specs
# Generate the C glue for an op: a struct whose execute() runs `behavior`,
# plus weave.inline code that instantiates it and returns a PyCObject thunk
# (consumed by cutils.run_cthunk).  Returns (locals dict, inline code,
# support code, converters).
def cgen(name, behavior, names, vals, converters = None):
if not converters:
converters = type_spec.default
for converter in converters:
assert isinstance(converter, type_spec.omega_type_converter_extension)
# d: C name -> raw data; specs: weave type specs inferred from the data
d, specs = cgetspecs(names, vals, converters)
template = {}
template['name'] = name
template['code'] = behavior
template['members'] = "".join([spec.struct_members_code() for spec in specs])
template['support'] = "".join([spec.struct_support_code() for spec in specs])
template['typedefs'] = "".join([spec.struct_typedefs() for spec in specs])
# init/cleanup hold Python references for the lifetime of the struct
template['incref'] = "".join(["Py_INCREF(py_%s);\n" % spec.name for spec in specs if spec.use_ref_count])
template['decref'] = "".join(["Py_DECREF(py_%s);\n" % spec.name for spec in specs if spec.use_ref_count])
template['struct_contents'] = """
%(typedefs)s
%(members)s
%(support)s
void init(void) {
%(incref)s
}
void cleanup(void) {
%(decref)s
}
int execute(void) {
%(code)s
return 0;
}
""" % template
# Struct name is content-hashed so distinct behaviors never collide.
# (md5 is the Python 2 module; hashlib superseded it.)
template['md5'] = md5.md5(template['struct_contents']).hexdigest()
template['struct_name'] = "_omega_%(name)s_%(md5)s" % template
struct = "struct %(struct_name)s { %(struct_contents)s\n};" % template
static = """
int %(struct_name)s_executor(%(struct_name)s* self) {
return self->execute();
}
void %(struct_name)s_destructor(void* executor, void* self) {
((%(struct_name)s*)self)->cleanup();
free(self);
}
""" % template
# NOTE(review): the struct is allocated with C++ `new` below but released
# with free() in the destructor above -- mismatched allocator; verify.
code = "%(struct_name)s* __STRUCT_P = new %(struct_name)s();\n" % template
code += "".join([spec.struct_import_code() for spec in specs])
code += "__STRUCT_P->init();\n"
code += "return_val = PyCObject_FromVoidPtrAndDesc((void*)(&%(struct_name)s_executor), __STRUCT_P, %(struct_name)s_destructor);\n" % template
return d, code, struct + static, converters
def make_static(cls, fname):
    """Rebind ``cls.fname`` as a staticmethod.

    If the attribute is a (Python 2) unbound method, its underlying
    function is unwrapped via ``im_func`` first.
    """
    func = getattr(cls, fname)
    # Plain functions have no im_func; fall back to the object itself.
    func = getattr(func, 'im_func', func)
    setattr(cls, fname, staticmethod(func))
class omega_op(gof.PythonOp):
# Base op class.  At class-creation time each subclass gets its grad /
# c_impl / c_alloc rebound as static methods (they are written without
# `self`, Python-2 style, taking inputs/outputs directly).
forbid_broadcast = False
@staticmethod
def __clsinit__(cls, name, bases, dct):
# make grad a static method
grad = cls.grad
if hasattr(grad, 'im_func'):
grad = grad.im_func
cls.grad = staticmethod(grad)
# # adjust impl
# if cls.forbid_broadcast:
#     cls.impl = assert_same_shapes(cls.impl)
# NOTE(review): 'grad' appears in this list too, repeating the manual
# rebinding done just above -- harmless but redundant.
for fname in ['grad', 'c_impl', 'c_alloc']:
make_static(cls, fname)
# make impl a static method
gof.PythonOp.__clsinit__(cls, name, bases, dct)
......@@ -171,6 +198,295 @@ class omega_op(gof.PythonOp):
# Default gradient: undefined unless the subclass overrides it.
def grad(*args):
return UNDEFINED
# Build the weave C code for this op from its c_impl fragment.
def c_code(self, converters = None):
(inames, onames), behavior = self._c_impl()
return cgen(self.__class__.__name__, behavior, inames + onames, self.inputs + self.outputs, converters)
def _c_alloc(self):
self.c_alloc(self.inputs, self.outputs)
# Subclass hook (static after __clsinit__): allocate output storage.
def c_alloc(inputs, outputs):
raise NotImplementedError()
def _c_impl(self):
# Python 2 tuple parameters: getargspec's args field is the nested
# ((inames), (onames)) signature of c_impl.
(inames, onames), _1, _2, _3 = inspect.getargspec(self.c_impl)
return (inames, onames), self.c_impl(self.inputs, self.outputs)
# Subclass hook: return the C fragment implementing the op.
def c_impl(inputs, outputs):
raise NotImplementedError()
# Allocate outputs, generate C code, compile it with weave and return the
# resulting C thunk (a PyCObject wrapping an executor + struct pointer).
def c_thunk(self):
self._c_alloc()
d, code, struct, converters = self.c_code()
thunk = weave.inline(code, d.keys(), local_dict = d, global_dict = {}, support_code = struct, type_converters = converters)
return thunk
# Run the op once through the compiled C path.
def c_perform(self):
thunk = self.c_thunk()
cutils.run_cthunk(thunk)
def elemwise_wrap_old(beforeloop, inloop, afterloop, loop_vars, writable_loop_vars):
    """Old 2-D wrapper: emit C source placing `inloop` inside an i/j double
    loop sized by the first loop variable's two dimensions.

    Read-only vars bind by value, writable vars bind by reference so that
    assignments write through to the array.
    """
    every_var = loop_vars + writable_loop_vars
    read_defs = "\n".join("_%s_dtype %s = _%s2(i, j);" % (v, v, v.upper())
                          for v in loop_vars)
    write_defs = "\n".join("_%s_dtype& %s = _%s2(i, j);" % (v, v, v.upper())
                           for v in writable_loop_vars)
    fields = dict(v1 = every_var[0],
                  idefs = read_defs,
                  odefs = write_defs,
                  beforeloop = beforeloop,
                  inloop = inloop,
                  afterloop = afterloop)
    return """
%(beforeloop)s
for (int i = 0; i < N_%(v1)s[0]; i++) {
for (int j = 0; j < N_%(v1)s[1]; j++) {
%(idefs)s
%(odefs)s
%(inloop)s
}
}
%(afterloop)s
""" % fields
def elemwise_loopcode(loopcode, init_template, next_template, acquire_template, cleanup_template, loop_vars, writable_loop_vars, aliases):
    """Assemble the C while-loop that applies `loopcode` once per element.

    Each *_template is %%-expanded with ``loop_var`` for every variable.
    Read-only vars bind by value, writable vars by reference, and each
    alias binds an input name directly to an output name (in-place ops).
    """
    every_var = loop_vars + writable_loop_vars

    def expand(tmpl, names):
        # Concatenate the template expanded once per variable name.
        return "".join(tmpl % dict(loop_var = name) for name in names)

    read_decl = "_%(loop_var)s_dtype %(loop_var)s = " + acquire_template + ";\n"
    write_decl = "_%(loop_var)s_dtype& %(loop_var)s = " + acquire_template + ";\n"
    alias_decls = "".join("_%(v1)s_dtype %(v1)s = %(v2)s;\n" % dict(v1 = a, v2 = b)
                          for a, b in aliases.items())
    fields = dict(init = expand(init_template, every_var),
                  next = expand(next_template, every_var),
                  cleanup = expand(cleanup_template, every_var),
                  idefs = expand(read_decl, loop_vars),
                  odefs = expand(write_decl, writable_loop_vars),
                  aliasdefs = alias_decls,
                  loopcode = loopcode)
    return """
%(init)s
while (__elemwise_size--) {
%(idefs)s
%(odefs)s
%(aliasdefs)s
%(loopcode)s
%(next)s
}
%(cleanup)s
""" % fields
def elemwise_wrap(beforeloop, inloop, afterloop, loop_vars, writable_loop_vars, aliases):
    """Generate C source that runs `inloop` once per array element.

    Two loop bodies are emitted -- a fast raw-pointer walk used when every
    operand is C-contiguous (or every operand is Fortran-contiguous), and
    a generic PyArrayIterObject-based loop for everything else -- plus a
    runtime contiguity check that selects between them.

    beforeloop/inloop/afterloop -- C fragments supplied by the op
    loop_vars          -- names read element-wise
    writable_loop_vars -- names written element-wise (bound by reference)
    aliases            -- input-name -> output-name map for in-place operands
    Returns the assembled C source string.
    """
    # Generic path: one numpy iterator per operand.
    general_init = "PyArrayIterObject* _%(loop_var)s_iter = (PyArrayIterObject*)PyArray_IterNew((PyObject*)_%(loop_var)s_array);\n"
    general_next = "PyArray_ITER_NEXT(_%(loop_var)s_iter);\n"
    general_acquire = "*((_%(loop_var)s_dtype*)_%(loop_var)s_iter->dataptr)"
    general_cleanup = "if (_%(loop_var)s_iter) Py_DECREF(_%(loop_var)s_iter);\n"
    # Contiguous fast path: plain pointer increments over the raw data.
    contiguous_init = "_%(loop_var)s_dtype* _%(loop_var)s_iter = (_%(loop_var)s_dtype*)PyArray_DATA(_%(loop_var)s_array);\n"
    contiguous_next = "_%(loop_var)s_iter++;\n"
    contiguous_acquire = "*_%(loop_var)s_iter"
    contiguous_cleanup = ""
    all_loop_vars = loop_vars + writable_loop_vars
    template = dict(
        v1 = all_loop_vars[0],
        beforeloop = beforeloop,
        general_loop = elemwise_loopcode(
            inloop,
            general_init, general_next, general_acquire, general_cleanup,
            loop_vars, writable_loop_vars, aliases),
        contiguous_loop = elemwise_loopcode(
            inloop,
            contiguous_init, contiguous_next, contiguous_acquire, contiguous_cleanup,
            loop_vars, writable_loop_vars, aliases),
        contiguity_check = "".join(["all_c_contiguous &= PyArray_ISCARRAY(_%(loop_var)s_array);\n" \
                                    "all_f_contiguous &= PyArray_ISFARRAY(_%(loop_var)s_array);\n" \
                                    % dict(loop_var = loop_var)
                                    for loop_var in all_loop_vars]),
        afterloop = afterloop)
    code = """
npy_intp __elemwise_size = PyArray_SIZE(_%(v1)s_array);
%(beforeloop)s
bool all_c_contiguous = 1;
bool all_f_contiguous = 1;
%(contiguity_check)s
if (all_c_contiguous || all_f_contiguous) {
%(contiguous_loop)s
}
else {
%(general_loop)s
}
%(afterloop)s
""" % template
    # Fix: removed the leftover `print code` debug statement that dumped the
    # generated C source to stdout on every code generation.
    return code
def upcast(dtype, *dtypes):
    """Return the dtype numpy yields when values of all given dtypes are added.

    Promotion is delegated to numpy itself by summing 0-d arrays of each
    dtype in turn.
    """
    acc = numpy.zeros((), dtype = dtype)
    for other in dtypes:
        acc = acc + numpy.zeros((), dtype = other)
    return acc.dtype
class elemwise(omega_op):
# Base class for ops applying C code element-wise over numpy arrays.
# Subclasses supply c_init / c_foreach / c_finalize fragments; parameter
# names starting with "_" are excluded from the per-element loop.
@staticmethod
def __clsinit__(cls, name, bases, dct):
for fname in ['c_init', 'c_foreach', 'c_finalize']:
make_static(cls, fname)
# make impl, grad, etc. static methods
omega_op.__clsinit__(cls, name, bases, dct)
# Allocate output storage, inferring shape/dtype from the inputs when the
# subclass does not override c_alloc.
def _c_alloc(self):
if isinstance(self, inplace):
dmap = self.destroy_map()
else:
dmap = {}
try:
return self.c_alloc(self.inputs, self.outputs)
except NotImplementedError:
# Python 2 tuple parameters: getargspec's args field is the nested
# ((inames), (onames)) signature of c_foreach.
(inames, onames), _1, _2, _3 = inspect.getargspec(self.c_foreach)
for oname in onames:
if oname.startswith("_"):
raise Exception("cannot infer an allocation policy automatically for variable " \
"%s because it is not part of the elementwise loop - "\
"please override the c_alloc method" % oname[1:])
shape, dtype = None, None
for iname, input in zip(inames, self.inputs):
if not iname.startswith("_"):
# NOTE(review): this stores the whole array rather than
# input.data.shape, yet it is passed below as the `shape`
# argument of numpy.ndarray -- looks like a bug; verify.
shape = input.data
if shape is None:
raise Exception("cannot infer an allocation policy automatically for output variables " \
"because there is no input variable in the loop from which to get the shape")
# output dtype follows numpy promotion over all ndarray inputs
dtype = upcast(*[input.data.dtype
for iname, input in zip(inames, self.inputs)
if isinstance(input.data, numpy.ndarray)])
for output in self.outputs:
inplace_inputs = dmap.get(output, [])
if inplace_inputs:
# an in-place output shares storage with its destroyed input
assert len(inplace_inputs) == 1
output.data = inplace_inputs[0].data
else:
output.data = numpy.ndarray(shape, dtype)
# Each _c_* wrapper pairs the fragment with the (inames, onames) signature
# of the corresponding static method.
def _c_init(self):
(inames, onames), _1, _2, _3 = inspect.getargspec(self.c_init)
return (inames, onames), self.c_init(self.inputs, self.outputs)
def c_init(inputs, outputs):
return ""
def _c_foreach(self):
(inames, onames), _1, _2, _3 = inspect.getargspec(self.c_foreach)
return (inames, onames), self.c_foreach(self.inputs, self.outputs)
def c_foreach(inputs, outputs):
return ""
def _c_finalize(self):
(inames, onames), _1, _2, _3 = inspect.getargspec(self.c_finalize)
return (inames, onames), self.c_finalize(self.inputs, self.outputs)
def c_finalize(inputs, outputs):
return ""
# Assemble the full C behavior for this op.  `elemwise_wrap` is captured at
# class-creation time through the default argument.
def c_code(self, converters = None, elemwise_wrap = elemwise_wrap):
def mangle(name):
if name.startswith("_"):
return name
else:
return "_" + name
try:
self._c_impl()
raise Exception("c_impl is not used by elemwise ops - define behavior in c_foreach instead")
except NotImplementedError:
pass
spec_b, before = self._c_init()
spec_d, during = self._c_foreach()
spec_a, after = self._c_finalize()
# Sanity check - apart from loop vars, variables are shared in the before/during/after parts
if before and spec_b != spec_d:
raise Exception("The input signature of c_init differs from the input signature of c_foreach.")
if after and spec_a != spec_d:
raise Exception("The input signature of c_finalize differs from the input signature of c_foreach.")
(inames, onames) = spec_d
aliases = {}
if isinstance(self, inplace):
# map each destroyed input's name onto the output name it aliases
dmap = self.destroy_map()
for oname, output in zip(onames, self.outputs):
if not oname.startswith("_"):
for input in dmap.get(output, []):
aliases[inames[self.inputs.index(input)]] = oname
behavior = elemwise_wrap(before, during, after,
[iname for iname in inames if not iname.startswith("_") and not iname in aliases],
[oname for oname in onames if not oname.startswith("_")],
aliases)
inames = [mangle(name) for name in inames]
onames = [mangle(name) for name in onames]
return cgen(self.__class__.__name__, behavior, inames + onames, self.inputs + self.outputs, converters)
# Build a subclass whose output dmap-key overwrites input dmap[key] in
# place.  NOTE(review): mutable default {0: 0} is shared across calls; it
# is never mutated here, but it is still a mutable-default smell.
@classmethod
def inplace_version(cls, dmap = {0: 0}):
(inames, onames), _1, _2, _3 = inspect.getargspec(cls.c_foreach)
for i, oname in enumerate(onames):
if i in dmap:
assert not oname.startswith("_")
class C(cls, inplace):
def destroy_map(self):
ret = cls.destroy_map()
# NOTE(review): `self.dmap` is not set anywhere on C; the closure
# variable `dmap` is probably what was intended -- verify.
for output, input in self.dmap.items():
ret[self.outputs.index(output)] = [self.inputs.index(input)]
return ret
def _impl(self):
if self.impl is not cls.impl:
# If the user sets his own inplace operation, we use it
return cls._impl(self)
else:
res = cls._impl(self)
if isinstance(res, gof.Result):
res = [res]
else:
# `copy` comes from the module-bottom `from copy import copy`
res = copy(res)
for output, input in dmap.items():
# The default implementation returned a copy, so we just
# overwrite the original input with the contents of that copy
# This is not meant to be efficient, only correct.
a = self.inputs[input].data
a[:] = res[output]
res[output] = a
if len(res) == 1:
return res[0]
else:
return res
if dmap == {0:0}:
# NOTE(review): "_inplace" % dmap is a no-op format (the string has
# no placeholders); plain "_inplace" was presumably intended.
C.__name__ = cls.__name__ + "_inplace" % dmap
else:
C.__name__ = cls.__name__ + "_inplace%s" % dmap
return C
def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None):
def f(x, y):
......@@ -252,7 +568,7 @@ def assert_same_shapes(impl):
return ret
# Wrapper to ensure that the last input to impl is a scalar
def tensor_scalar_op(impl):
def tensor_scalar_impl(impl):
def ret(x, a):
if a.shape:
raise ValueError("The second argument to %s must be a scalar." % impl)
......@@ -260,70 +576,123 @@ def tensor_scalar_op(impl):
return ret
# @omega_op
# def add((x, y), (z, )):
# def grad(gz):
# return gz
# def c_alloc():
# return numpy.ndarray(x.shape, dtype = x.dtype)
# c_impl = """
# for (int i = 0; i < z.ncols; i++) {
# for (int j = 0; j < z.nrows; j++) {
# z(i, j) = x(i, j) + y(i, j);
# }
# }
# """
## Addition ##
class proto_add_elemwise(omega_op):
# Element-wise addition of two same-shaped arrays.
class add_elemwise(elemwise):
impl = assert_same_shapes(numpy.ndarray.__add__)
# NOTE(review): returns a single gz for two inputs -- presumably the
# framework broadcasts it to both operands; verify.
def grad(x, y, gz):
return gz
# Python 2 tuple parameters name the codegen loop variables.
def c_foreach((x, y), (z, )):
return "z = x + y;"
class add_elemwise(proto_add_elemwise):
impl = assert_same_shapes(numpy.ndarray.__add__)
# In-place variant generated from add_elemwise; overwrites its first input.
iadd_elemwise = add_elemwise.inplace_version()
iadd_elemwise.impl = assert_same_shapes(numpy.ndarray.__iadd__)
# class proto_add_elemwise(omega_op):
# def grad(x, y, gz):
# return gz
# class add_elemwise(proto_add_elemwise):
# impl = assert_same_shapes(numpy.ndarray.__add__)
# class iadd_elemwise(proto_add_elemwise, inplace):
# impl = assert_same_shapes(numpy.ndarray.__iadd__)
class iadd_elemwise(proto_add_elemwise, inplace):
impl = assert_same_shapes(numpy.ndarray.__iadd__)
# Base for ops combining a tensor with a scalar: the scalar `_a` is read
# once in c_init (leading "_" keeps it out of the element-wise loop).
class tensor_scalar_op(elemwise):
def c_init((x, _a), (z, )):
return "_a_dtype a = _a[0];"
def _c_foreach(self):
# NOTE(review): reads self.c_operation, but subclasses in this file
# (add_scalar, scale, ...) define `c_expr` -- verify the attribute name.
return (('x', '_a'), ('z', )), "z = %s;" % self.c_operation
class proto_add_scalar(omega_op):
class add_scalar(tensor_scalar_op):
impl = tensor_scalar_impl(numpy.ndarray.__add__)
def grad(x, a, gz):
return gz, sum(gz)
c_expr = "x + a"
class add_scalar(proto_add_scalar):
impl = tensor_scalar_op(numpy.ndarray.__add__)
# def c_impl(x, s, z):
# """
# if (*__z == NULL) {
# *__z = new ndarray
# }
# ndarray& z = **__z
# """
# return """
# z.resize_like(x);
# for (int i = 0; i < z.size(); i++) {
# z[i] = x[i] * s;
# }
# return z;
# """
class iadd_scalar(proto_add_scalar, inplace):
impl = tensor_scalar_op(numpy.ndarray.__iadd__)
class proto_twice(omega_op):
iadd_scalar = add_scalar.inplace_version()
iadd_scalar.impl = tensor_scalar_impl(numpy.ndarray.__iadd__)
# class proto_add_scalar(omega_op):
# def grad(x, a, gz):
# return gz, sum(gz)
# class add_scalar(proto_add_scalar):
# impl = tensor_scalar_impl(numpy.ndarray.__add__)
# class iadd_scalar(proto_add_scalar, inplace):
# impl = tensor_scalar_impl(numpy.ndarray.__iadd__)
class twice(elemwise):
def grad(x, gz):
return scale(gz, 2.0)
class twice(proto_twice):
def impl(x):
return x + x
def c_foreach((x, ), (z, )):
"z = x + x;"
class itwice(proto_twice, inplace):
def impl(x):
x += x
return x
itwice = twice.inplace_version()
# class proto_twice(omega_op):
# def grad(x, gz):
# return scale(gz, 2.0)
# class twice(proto_twice):
# def impl(x):
# return x + x
# class itwice(proto_twice, inplace):
# def impl(x):
# x += x
# return x
## Subtraction ##
class proto_sub_elemwise(omega_op):
class sub_elemwise(elemwise):
impl = assert_same_shapes(numpy.ndarray.__sub__)
def grad(x, y, gz):
return gz, -gz
def c_foreach((x, y), (z, )):
return "z = x - y;"
class sub_elemwise(proto_sub_elemwise):
impl = assert_same_shapes(numpy.ndarray.__sub__)
isub_elemwise = sub_elemwise.inplace_version()
isub_elemwise.impl = assert_same_shapes(numpy.ndarray.__isub__)
# class proto_sub_elemwise(omega_op):
# def grad(x, y, gz):
# return gz, -gz
class isub_elemwise(proto_sub_elemwise, inplace):
impl = assert_same_shapes(numpy.ndarray.__isub__)
# class sub_elemwise(proto_sub_elemwise):
# impl = assert_same_shapes(numpy.ndarray.__sub__)
# class isub_elemwise(proto_sub_elemwise, inplace):
# impl = assert_same_shapes(numpy.ndarray.__isub__)
def sub_scalar_r(x, a):
return add_scalar(x, -a)
......@@ -337,67 +706,127 @@ def isub_scalar_r(x, a):
## Element-wise multiplication ##
class proto_mul_elemwise(omega_op):
class mul_elemwise(elemwise):
impl = assert_same_shapes(numpy.ndarray.__mul__)
def grad(x, y, gz):
return mul(y, gz), mul(x, gz)
def c_foreach((x, y), (z, )):
return "z = x * y;"
class mul_elemwise(proto_mul_elemwise):
impl = assert_same_shapes(numpy.ndarray.__mul__)
imul_elemwise = mul_elemwise.inplace_version()
imul_elemwise.impl = assert_same_shapes(numpy.ndarray.__imul__)
# class proto_mul_elemwise(omega_op):
# def grad(x, y, gz):
# return mul(y, gz), mul(x, gz)
class imul_elemwise(proto_mul_elemwise, inplace):
impl = assert_same_shapes(numpy.ndarray.__imul__)
# class mul_elemwise(proto_mul_elemwise):
# impl = assert_same_shapes(numpy.ndarray.__mul__)
# class imul_elemwise(proto_mul_elemwise, inplace):
# impl = assert_same_shapes(numpy.ndarray.__imul__)
class proto_scale(omega_op):
class scale(tensor_scalar_op):
impl = tensor_scalar_impl(numpy.ndarray.__mul__)
def grad(x, a, gz):
return scale(a, gz), sum(mul_elemwise(x, gz))
c_expr = "x * a"
iscale = scale.inplace_version()
iscale.impl = tensor_scalar_impl(numpy.ndarray.__imul__)
class scale(proto_scale):
impl = tensor_scalar_op(numpy.ndarray.__mul__)
# class proto_scale(omega_op):
# def grad(x, a, gz):
# return scale(a, gz), sum(mul_elemwise(x, gz))
class iscale(proto_scale, inplace):
impl = tensor_scalar_op(numpy.ndarray.__imul__)
# class scale(proto_scale):
# impl = tensor_scalar_impl(numpy.ndarray.__mul__)
# class iscale(proto_scale, inplace):
# impl = tensor_scalar_impl(numpy.ndarray.__imul__)
class proto_sqr(omega_op):
class sqr(elemwise):
def impl(x):
return x * x
def grad(x, gz):
return scale(mul_elemwise(x, gz), 2.0)
def c_foreach((x, ), (z, )):
"z = x * x;"
class sqr(proto_sqr):
impl = lambda x: numpy.multiply(x, x)
isqr = sqr.inplace_version()
isqr.impl = lambda x: x.__imul__(x)
class isqr(proto_sqr, inplace):
impl = lambda x: x.__imul__(x)
# class proto_sqr(omega_op):
# def grad(x, gz):
# return scale(mul_elemwise(x, gz), 2.0)
class proto_sqrt(omega_op):
# class sqr(proto_sqr):
# impl = lambda x: numpy.multiply(x, x)
# class isqr(proto_sqr, inplace):
# impl = lambda x: x.__imul__(x)
class sqrt(elemwise):
impl = numpy.sqrt
def grad(x, gz):
return scale(div(gz, sqrt(x)), 0.5)
def c_foreach((x, ), (z, )):
"z = pow(x, 0.5);"
class sqrt(proto_sqrt):
impl = numpy.sqrt
isqrt = sqrt.inplace_version()
isqrt.impl = lambda x: x.__ipow__(0.5)
# class proto_sqrt(omega_op):
# def grad(x, gz):
# return scale(div(gz, sqrt(x)), 0.5)
class isqrt(proto_sqrt, inplace):
impl = lambda x: x.__ipow__(0.5)
# class sqrt(proto_sqrt):
# impl = numpy.sqrt
# class isqrt(proto_sqrt, inplace):
# impl = lambda x: x.__ipow__(0.5)
## Exponentiation ##
class exp(omega_op):
class exp(elemwise):
impl = numpy.exp
def c_foreach((x, ), (z, )):
return "z = exp(x);"
# class exp(omega_op):
# impl = numpy.exp
## Element-wise division ##
class proto_div_elemwise(omega_op):
class div_elemwise(elemwise):
impl = assert_same_shapes(numpy.ndarray.__div__)
def grad(x, y, gz):
return div(gz, y), -div(mul(x, gz), sqr(y))
def c_foreach((x, y), (z, )):
return "z = x / y;"
class div_elemwise(proto_div_elemwise):
impl = assert_same_shapes(numpy.ndarray.__div__)
idiv_elemwise = div_elemwise.inplace_version()
idiv_elemwise.impl = assert_same_shapes(numpy.ndarray.__idiv__)
# class proto_div_elemwise(omega_op):
# def grad(x, y, gz):
# return div(gz, y), -div(mul(x, gz), sqr(y))
# class div_elemwise(proto_div_elemwise):
# impl = assert_same_shapes(numpy.ndarray.__div__)
class idiv_elemwise(proto_div_elemwise, inplace):
impl = assert_same_shapes(numpy.ndarray.__idiv__)
# class idiv_elemwise(proto_div_elemwise, inplace):
# impl = assert_same_shapes(numpy.ndarray.__idiv__)
def div_scalar_r(x, a):
return scale(x, inv_elemwise(a))
......@@ -412,28 +841,48 @@ def idiv_scalar_r(x, a):
## Scaling ##
class proto_neg(omega_op):
class neg(elemwise):
impl = numpy.ndarray.__neg__
def grad(x, gz):
return -gz
def c_foreach((x, ), (z, )):
return "z = -x;"
class neg(proto_neg):
impl = numpy.ndarray.__neg__
ineg = neg.inplace_version()
ineg.impl = lambda x: x.__imul__(-1)
class ineg(proto_neg, inplace):
impl = lambda x: x.__imul__(-1)
# class proto_neg(omega_op):
# def grad(x, gz):
# return -gz
# class neg(proto_neg):
# impl = numpy.ndarray.__neg__
class proto_inv_elemwise(omega_op):
def grad(x, gz):
raise NotImplemented
# class ineg(proto_neg, inplace):
# impl = lambda x: x.__imul__(-1)
class inv_elemwise(omega_op):
class inv_elemwise(elemwise):
impl = lambda x: 1 / x
def grad(x, gz):
return -gz
def c_foreach((x, ), (z, )):
return "z = 1 / x;"
class iinv_elemwise(omega_op, inplace):
def impl(x):
x[:] = 1 / x
iinv_elemwise = inv_elemwise.inplace_version()
# class proto_inv_elemwise(omega_op):
# def grad(x, gz):
# raise NotImplemented
# class inv_elemwise(omega_op):
# impl = lambda x: 1 / x
# class iinv_elemwise(omega_op, inplace):
# def impl(x):
# x[:] = 1 / x
## Dot product ##
......@@ -464,46 +913,116 @@ class array_copy(omega_op):
## Power ##
class proto_pow(omega_op):
class pow_elemwise(elemwise):
impl = assert_same_shapes(numpy.ndarray.__pow__)
def grad(x, s, gz):
return gz * s * (pow_elemwise(x, s-1.0))
def c_foreach((x, s), (z, )):
return "z = pow(x, s)"
class pow_elemwise(proto_pow):
impl = assert_same_shapes(numpy.ndarray.__pow__)
ipow_elemwise = pow_elemwise.inplace_version()
ipow_elemwise.impl = assert_same_shapes(numpy.ndarray.__ipow__)
# class proto_pow(omega_op):
# def grad(x, s, gz):
# return gz * s * (pow_elemwise(x, s-1.0))
class ipow_elemwise(proto_pow, inplace):
impl = assert_same_shapes(numpy.ndarray.__ipow__)
# class pow_elemwise(proto_pow):
# impl = assert_same_shapes(numpy.ndarray.__pow__)
# class ipow_elemwise(proto_pow, inplace):
# impl = assert_same_shapes(numpy.ndarray.__ipow__)
class pow_scalar_l(omega_op):
impl = tensor_scalar_op(numpy.ndarray.__pow__)
class pow_scalar_l(tensor_scalar_op):
impl = tensor_scalar_impl(lambda x, y: numpy.ndarray.__pow__(y, x))
def grad(x, s, gz):
return gz * x * (pow_scalar_l(s,x-1.0))
c_expr = "pow(a, x)"
class pow_scalar_r(omega_op):
impl = tensor_scalar_op(numpy.ndarray.__pow__)
class pow_scalar_r(tensor_scalar_op):
impl = tensor_scalar_impl(numpy.ndarray.__pow__)
def grad(x, s, gz):
return gz * s * (pow_scalar_r(x,s-1.0))
c_expr = "pow(x, a)"
class ipow_scalar_r(omega_op, inplace):
impl = tensor_scalar_op(numpy.ndarray.__ipow__)
def grad(x, s, gz):
return gz * s * (pow_scalar_r(x,s-1.0))
ipow_scalar_r = pow_scalar_r.inplace_version()
ipow_scalar_r.impl = tensor_scalar_impl(numpy.ndarray.__ipow__)
# class pow_scalar_l(omega_op):
# impl = tensor_scalar_impl(numpy.ndarray.__pow__)
# def grad(x, s, gz):
# return gz * x * (pow_scalar_l(s,x-1.0))
# class pow_scalar_r(omega_op):
# impl = tensor_scalar_impl(numpy.ndarray.__pow__)
# def grad(x, s, gz):
# return gz * s * (pow_scalar_r(x,s-1.0))
# class ipow_scalar_r(omega_op, inplace):
# impl = tensor_scalar_impl(numpy.ndarray.__ipow__)
# def grad(x, s, gz):
# return gz * s * (pow_scalar_r(x,s-1.0))
## Others ##
class minmax(omega_op):
class minmax(elemwise):
nout = 2
def impl(x):
return x.min, x.max
class fill(omega_op):
def c_alloc((x, ), (_min, _max)):
_min.data = numpy.ndarray((), x.dtype)
_max.data = numpy.ndarray((), x.dtype)
def c_init((x, ), (_min, _max)):
return """
_x_dtype min = _x[0];
_x_dtype max = _x[0];
"""
def c_foreach((x, ), (_min, _max)):
return """
if (x < min) min = x;
if (x > max) max = x;
"""
def c_finalize((x, ), (_min, _max)):
return """
_min[0] = min;
_max[0] = max;
"""
# class minmax(omega_op):
# nout = 2
# def impl(x):
# return x.min, x.max
class fill(elemwise):
impl = lambda model, value: (model * 0) + value
def c_init((model, _value), (z, )):
return "_z_dtype value = _value[0];"
def c_foreach((model, _value), (z, )):
return "z = value;"
class sum(omega_op):
impl = numpy.sum
def grad(x, gz):
return fill(x, gz)
ifill = fill.inplace_version()
# class fill(omega_op):
# impl = lambda model, value: (model * 0) + value
# Reduction: accumulates every element of x into the 0-d output _sum.
# The leading "_" keeps _sum out of the element-wise loop; c_alloc gives it
# a 0-d array of x's dtype.  NOTE(review): shadows the builtin `sum`, and
# earlier grad definitions (e.g. add_scalar.grad) call this op.
class sum(elemwise):
def c_alloc((x, ), (_sum, )):
_sum.data = numpy.ndarray((), dtype = x.data.dtype)
def c_init((x, ), (_sum, )):
return "_sum[0] = 0;"
def c_foreach((x, ), (_sum, )):
return "_sum[0] += x;"
# class sum(omega_op):
# impl = numpy.sum
# def grad(x, gz):
# return fill(x, gz)
## Array slicing ##
......
# Bootstrap the cutils_ext extension module: import it if already built,
# otherwise compile run_cthunk with scipy.weave and import the result.
try:
from cutils_ext import *
except ImportError:
from scipy import weave
# C body of run_cthunk(cthunk): validates the PyCObject produced by
# omega_op.c_thunk, then invokes the stored executor on its struct pointer.
single_runner = """
if (!PyCObject_Check(py_cthunk)) {
PyErr_SetString(PyExc_ValueError,
"Argument to run_cthunk must be a PyCObject returned by the c_thunk method of an omega_op.");
return NULL;
}
int (*fn)(void*) = reinterpret_cast<int (*)(void*)>(PyCObject_AsVoidPtr(py_cthunk));
void* it = PyCObject_GetDesc(py_cthunk);
int failure = fn(it);
if (failure) {
return NULL;
}
"""
# placeholder argument object so weave knows the function's arity
cthunk = object()
mod = weave.ext_tools.ext_module('cutils_ext')
mod.add_function(weave.ext_tools.ext_function('run_cthunk', single_runner, ['cthunk']))
mod.compile()
from cutils_ext import *
# from op import *
# from value import *
# from opt import *
# from env import *
# from prog import *
# from diff import *
# import dispatchers
from op import *
from ext import *
from lib import *
from link import *
from result import *
from env import *
from prog import *
from features import *
from opt import *
import graph
#import utils
import env
import tools
import utils
class Compiler:
    """Builds an optimized, consistency-checked Env from inputs/outputs.

    The feature set used for each compilation is the union of the features
    given at construction, those required by the optimizer, and those
    supplied per call.
    """

    def __init__(self, optimizer, features):
        self.optimizer = optimizer
        self.features = set(features)
        self.features.update(optimizer.require())

    def compile(self, inputs, outputs, features):
        merged = self.features.union(features)
        graph_env = env.Env(inputs, outputs, merged, False)
        self.optimizer.apply(graph_env)
        if not graph_env.consistent():
            raise env.InconsistencyError("The graph is inconsistent.")
        return graph_env

    def __call__(self, inputs, outputs, features):
        # Calling the compiler is shorthand for compile().
        return self.compile(inputs, outputs, features)
# def __init__(self, inputs, outputs, preprocessors, features, optimizer):
# self.inputs = inputs
# self.outputs = outputs
# self.features = features
# self.optimizer = optimizer
# features = features + [tools.EquivTool] + optimizer.require()
# features = utils.uniq_features(features)
# self.env = env.Env(inputs,
# outputs,
# features,
# False)
# if not self.env.consistent():
# raise env.InconsistencyError("The graph is inconsistent.")
# self.__optimize__()
# self.thunks = [op.thunk() for op in self.order]
# def __optimize__(self):
# self.optimizer.apply(self.env)
# self.order = self.env.toposort()
# def equiv(self, r):
# return self.env.equiv(r)
# def __getitem__(self, r):
# return self.equiv(r)
# def __setitem__(self, r, value):
# if isinstance(r, tuple):
# for a, b in zip(r, value):
# self.__setitem__(a, b)
# else:
# self.equiv(r).set_value(value)
# def __call__(self, *args):
# if args:
# for input, arg in zip(self.inputs, args):
# if arg is not None:
# input.value = arg
# for thunk, op in zip(self.thunks, self.order):
# try:
# thunk()
# except Exception, e:
# raise e.__class__("Error in " + str(op) + ": " + str(e))
# return [output.value for output in self.outputs]
# import env
# import opt
# from value import AsValue
# class Prog:
# def __init__(self, inputs, outputs, optimizer):
# self.inputs = inputs
# self.outputs = outputs
# self.env = env.Env(inputs,
# outputs,
# False,
# op_db = env.OpDb,
# changed = env.ChangeListener,
# # pr = env.PrintListener,
# scope = env.ScopeListener)
# ## self.adjustments = adjustments
# self.optimizer = optimizer
# ## if self.adjustments:
# ## self.adjustments.apply(self.env)
# if not self.env.consistent():
# raise env.InconsistencyError("The graph is inconsistent.")
# self.optimizer.apply(self.env)
# self.order = self.env.toposort()
# print "==================="
# for op in self.order:
# print op
# print "==================="
# self.thunks = [op.thunk() for op in self.order]
# def equiv(self, v):
# v = AsValue(v)
# return self.env.equiv(v)
# def __getitem__(self, v):
# return self.equiv(v).storage
# def __setitem__(self, v, value):
# if isinstance(v, tuple):
# for a, b in zip(v, value):
# self.__setitem__(a, b)
# else:
# self.equiv(v).value = value
# def __call__(self, *args):
# if args:
# for input, arg in zip(self.inputs, args):
# if arg is not None:
# input.value = arg
# for thunk, op in zip(self.thunks, self.order):
# try:
# thunk()
# except Exception, e:
# raise e.__class__("Error in " + str(op) + ": " + str(e))
# return [output.value for output in self.outputs]
# def prog(i, o):
# if not isinstance(i, (list, tuple)):
# i = [i]
# if not isinstance(o, (list, tuple)):
# o = [o]
# i = [AsValue(input) for input in i]
# o = [AsValue(output) for output in o]
# return Prog(i,
# o,
# opt.TagFilterMultiOptimizer(opt.opt_registry, None, None))
from utils import OmegaError
# Raised by dispatch handlers to signal "these argument types are not mine";
# the Dispatcher catches it and tries the next candidate.
class OmegaTypeError(OmegaError, TypeError):
pass
############################
# Dispatcher
############################
class Dispatcher(list):
    """An ordered list of handler callables tried in turn.

    Calling the dispatcher invokes each handler with the given arguments
    until one succeeds; a handler signals "not my types" by raising
    OmegaTypeError.  ``add_handler`` gives a handler top priority,
    ``fallback_handler`` lowest.  Every instance registers itself in the
    class-level ``all_dispatchers`` registry under its name.
    """
    all_dispatchers = {}

    def __init__(self, name, description):
        self.name = name
        self.description = description
        self.all_dispatchers[name] = self

    def __call__(self, *inputs, **opts):
        for handler in self:
            try:
                return handler(*inputs, **opts)
            except OmegaTypeError:
                pass
        # No handler accepted these argument types.
        if opts:
            suffix = " with options %s" % opts
        else:
            suffix = ""
        arg_types = ", ".join([arg.__class__.__name__ for arg in inputs])
        raise OmegaTypeError("No candidate found for %s(%s) %s"
                             % (self.name, arg_types, suffix))

    def add_handler(self, x):
        # Newest handler is consulted first.
        self.insert(0, x)

    def fallback_handler(self, x):
        # Fallbacks are consulted last.
        self.append(x)
# Dispatchers for all python operators
Add = Dispatcher("Add", "x + y")
Subtract = Dispatcher("Subtract", "x - y")
Multiply = Dispatcher("Multiply", "x * y")
Divide = Dispatcher("Divide", "x / y")
FloorDivide = Dispatcher("FloorDivide", "x // y")
Modulo = Dispatcher("Modulo", "x % y")
Power = Dispatcher("Power", "x ** y")
Negate = Dispatcher("Negate", "-x")
Abs = Dispatcher("Abs", "abs(x)")
LeftShift = Dispatcher("LeftShift", "x << y")
RightShift = Dispatcher("RightShift", "x >> y")
Equals = Dispatcher("Equals", "x == y")
NotEquals = Dispatcher("NotEquals", "x != y")
Less = Dispatcher("Less", "x < y")
LessOrEqual = Dispatcher("LessOrEqual", "x <= y")
Greater = Dispatcher("Greater", "x > y")
GreaterOrEqual = Dispatcher("GreaterOrEqual", "x >= y")
Contains = Dispatcher("Contains", "x in y")
BinaryOr = Dispatcher("BinaryOr", "x | y")
BinaryAnd = Dispatcher("BinaryAnd", "x & y")
BinaryXor = Dispatcher("BinaryXor", "x ^ y")
BinaryInverse = Dispatcher("BinaryInverse", "~x")
# Dispatchers for special operations
Transpose = Dispatcher("Transpose", "x.T")
# Others
Log = Dispatcher("Log", 'log(x)')
Exp = Dispatcher("Exp", 'exp(x)')
Sin = Dispatcher("Sin", 'sin(x)')
Cos = Dispatcher("Cos", 'cos(x)')
Tan = Dispatcher("Tan", 'tan(x)')
############################
# PythonOperatorSupport
############################
class PythonOperatorSupport(object):
    """
    Mixin wiring Python's operator protocol to this module's dispatchers.

    Every special method simply forwards its operand(s) to the matching
    Dispatcher instance, so subclasses participate in ordinary Python
    expressions (x + y, -x, x < y, ...).
    """

    # --- Common arithmetic operations ---
    def __add__(self, x):
        return Add(self, x)
    def __radd__(self, x):
        return Add(x, self)
    def __sub__(self, x):
        return Subtract(self, x)
    def __rsub__(self, x):
        return Subtract(x, self)
    def __mul__(self, x):
        return Multiply(self, x)
    def __rmul__(self, x):
        return Multiply(x, self)
    def __div__(self, x):
        return Divide(self, x)
    def __rdiv__(self, x):
        return Divide(x, self)
    def __floordiv__(self, x):
        return FloorDivide(self, x)
    def __rfloordiv__(self, x):
        return FloorDivide(x, self)
    def __mod__(self, x):
        return Modulo(self, x)
    def __rmod__(self, x):
        return Modulo(x, self)
    def __pow__(self, x):
        return Power(self, x)
    def __rpow__(self, x):
        return Power(x, self)
    def __neg__(self):
        return Negate(self)
    def __abs__(self):
        return Abs(self)

    # --- Less common arithmetic operations ---
    def __lshift__(self, x):
        return LeftShift(self, x)
    def __rlshift__(self, x):
        return LeftShift(x, self)
    def __rshift__(self, x):
        return RightShift(self, x)
    def __rrshift__(self, x):
        return RightShift(x, self)

    # --- Comparison operations ---
    # __eq__/__ne__ are deliberately left to Python's default (identity)
    # semantics; the Equals/NotEquals dispatchers exist but are not wired.
#    def __eq__(self, x):
#        return Equals(self, x)
#    def __ne__(self, x):
#        return NotEquals(self, x)
    def __lt__(self, x):
        return Less(self, x)
    def __le__(self, x):
        return LessOrEqual(self, x)
    def __gt__(self, x):
        return Greater(self, x)
    def __ge__(self, x):
        return GreaterOrEqual(self, x)
    def __contains__(self, x):
        return Contains(self, x)

    # --- Binary (bitwise) operations ---
    def __or__(self, x):
        return BinaryOr(self, x)
    def __ror__(self, x):
        return BinaryOr(x, self)
    def __and__(self, x):
        return BinaryAnd(self, x)
    def __rand__(self, x):
        return BinaryAnd(x, self)
    def __xor__(self, x):
        return BinaryXor(self, x)
    def __rxor__(self, x):
        return BinaryXor(x, self)
    def __invert__(self):
        return BinaryInverse(self)

    # --- Other operations ---
    @property
    def T(self):
        "Transposition, spelled x.T."
        return Transpose(self)
    @property
    def norm(self):
        "Norm, spelled x.norm."
        return Norm(self)

    # Instances are always truthy.
    def __nonzero__(self):
        return True
# Export every name defined so far in this module; under Python 2,
# globals().keys() is a plain list, as a star-import expects.
__all__ = globals().keys()
from copy import copy
import graph
## from value import Value, AsValue
from utils import ClsInit
from err import GofError, GofTypeError, PropagationError
from op import Op
from result import Result
from features import Listener, Orderings, Constraint, Tool
import utils
# Public API of this module.
__all__ = ['InconsistencyError',
           'Env']
# class AliasDict(dict):
# "Utility class to keep track of what Result has been replaced with what Result."
# def group(self, main, *keys):
# "Marks all the keys as having been replaced by the Result main."
# keys = [key for key in keys if key is not main]
# if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.")
# for key in keys:
# if self.has_key(key):
# raise Exception("Only group results that have not been grouped before.")
# if key is main:
# continue
# self.setdefault(key, main)
# def ungroup(self, main, *keys):
# "Undoes group(main, *keys)"
# keys = [key for key in keys if key is not main]
# for key in keys:
# if self[key] is main:
# del self[key]
# def __call__(self, key):
# "Returns the currently active replacement for the given key."
# next = self.get(key, None)
# while next:
# key = next
# next = self.get(next, None)
# return key
class InconsistencyError(GofError):
    """
    Raised by Env whenever one of its listeners marks the graph as
    inconsistent.
    """
class Env(graph.Graph):
    """
    An Env represents a subgraph bound by a set of input results and a set of output
    results. An op is in the subgraph iff it depends on the value of some of the Env's
    inputs _and_ some of the Env's outputs depend on it. A result is in the subgraph
    iff it is an input or an output of an op that is in the subgraph.

    The Env supports the replace operation which allows to replace a result in the
    subgraph by another, e.g. replace (x + x).out by (2 * x).out. This is the basis
    for optimization in omega.

    An Env can have listeners (instances of Listener). Each listener is informed of
    any op entering or leaving the subgraph (which happens at construction time and
    whenever there is a replacement). Features implementing Constraint or Orderings
    can in addition restrict how ops in the subgraph can be related.
    """

    ### Special ###

    def __init__(self, inputs, outputs, features = [], consistency_check = True):
        """
        Create an Env which operates on the subgraph bound by the inputs and outputs
        sets. If consistency_check is False, an illegal graph will be tolerated.
        """
        # Feature instances, indexed by their class; the four dicts below
        # are role-specific views of the same set of features.
        self._features = {}
        self._listeners = {}
        self._constraints = {}
        self._orderings = {}
        self._tools = {}
        # The inputs and outputs set bound the subgraph this Env operates on.
        self.inputs = set(inputs)
        self.outputs = set(outputs)
        for feature_class in utils.uniq_features(features):
            self.add_feature(feature_class, False)
        # All ops in the subgraph defined by inputs and outputs are cached in _ops
        self._ops = set()
        # Ditto for results
        self._results = set(self.inputs)
        # Set of all the results that are not an output of an op in the subgraph but
        # are an input of an op in the subgraph.
        # e.g. z for inputs=(x, y) and outputs=(x + (y - z),)
        self._orphans = set()
        # Maps results to ops that use them:
        # if op.inputs[i] == v then (op, i) in self._clients[v]
        self._clients = {}
        # List of functions that undo the replace operations performed, in
        # order; running them from the end recovers the initial graph.
        self.history = []
        self.__import_r__(self.outputs)
        if consistency_check:
            self.validate()

    ### Public interface ###

    def add_output(self, output):
        "Adds output to the subgraph's outputs and imports the ops computing it."
        self.outputs.add(output)
        self.__import_r__([output])

    def clients(self, r):
        "Set of all the (op, i) pairs such that op.inputs[i] is r."
        return self._clients.get(r, set())

    def checkpoint(self):
        """
        Returns an object that can be passed to self.revert in order to backtrack
        to a previous state.
        """
        return len(self.history)

    def consistent(self):
        """
        Returns True iff the subgraph is consistent and does not violate the
        constraints set by the listeners.
        """
        try:
            self.validate()
        except InconsistencyError:
            return False
        return True

    def satisfy(self, x):
        "Adds every feature class that x requires to this env."
        for feature_class in x.require():
            self.add_feature(feature_class)

    def add_feature(self, feature_class, do_import = True):
        """
        Adds an instance of feature_class to this env unless a feature of the
        same class or of a subclass is already present; a feature whose
        superclass is present replaces the superclass's instance.
        """
        if feature_class in self._features:
            return # the feature is already present
        # Bug fix: iterate over a copy — __del_feature__ mutates
        # self._features, and changing a dict's size while iterating over
        # it raises a RuntimeError.
        for other_feature_class in list(self._features):
            if issubclass(other_feature_class, feature_class):
                return # a more specific feature is already present
            elif issubclass(feature_class, other_feature_class):
                self.__del_feature__(other_feature_class)
        self.__add_feature__(feature_class, do_import)

    def __add_feature__(self, feature_class, do_import):
        # Instantiates the feature and registers it under every role it
        # fulfills (Listener, Constraint, Orderings, Tool).
        if not issubclass(feature_class, (Listener, Constraint, Orderings, Tool)):
            raise TypeError("features must be subclasses of Listener, Constraint, Orderings and/or Tools")
        feature = feature_class(self)
        if issubclass(feature_class, Listener):
            self._listeners[feature_class] = feature
            if do_import:
                # Replay the current subgraph so the new listener catches up.
                for op in self.io_toposort():
                    feature.on_import(op)
        if issubclass(feature_class, Constraint):
            self._constraints[feature_class] = feature
        if issubclass(feature_class, Orderings):
            self._orderings[feature_class] = feature
        if issubclass(feature_class, Tool):
            self._tools[feature_class] = feature
            feature.publish()
        self._features[feature_class] = feature

    def __del_feature__(self, feature_class):
        # Removes the feature from every role registry it may appear in.
        # (Loop variable renamed: it used to shadow the builtin `set`.)
        for registry in [self._features, self._constraints, self._orderings, self._tools, self._listeners]:
            try:
                del registry[feature_class]
            except KeyError:
                pass

    def get_feature(self, feature_class):
        """
        Returns the feature instance registered for feature_class or for a
        subclass of it. Raises KeyError if there is none.
        """
        try:
            return self._features[feature_class]
        except KeyError:
            for other_feature_class in self._features:
                if issubclass(other_feature_class, feature_class):
                    return self._features[other_feature_class]
            else:
                # No subclass matched: re-raise the original KeyError.
                raise

    def has_feature(self, feature_class):
        "Returns True iff a feature of the given class (or a subclass) is present."
        try:
            self.get_feature(feature_class)
        # Narrowed from a bare except: get_feature signals absence with KeyError.
        except KeyError:
            return False
        return True

    def nclients(self, r):
        "Same as len(self.clients(r))."
        return len(self.clients(r))

    def ops(self):
        "All ops within the subgraph bound by env.inputs and env.outputs."
        return self._ops

    def has_op(self, op):
        "Returns True iff op is in the subgraph."
        return op in self._ops

    def orphans(self):
        """All results not within the subgraph bound by env.inputs and env.outputs, not in
        env.inputs but required by some op."""
        return self._orphans

    def replace(self, r, new_r, consistency_check = True):
        """
        This is the main interface to manipulate the subgraph in Env.
        For every op that uses r as input, makes it use new_r instead.

        This may raise a GofTypeError if the new result violates type
        constraints for one of the target ops. In that case, no
        changes are made.

        If the replacement makes the graph inconsistent and the value
        of consistency_check is True, this function will raise an
        InconsistencyError and will undo the operation, leaving the
        graph the way it was before the call to replace.

        If consistency_check is False, the replacement will succeed
        even if there is an inconsistency. A GofTypeError will still
        be raised if there are type mismatches.
        """
        # Assert that they are Result instances.
        assert isinstance(r, Result)
        assert isinstance(new_r, Result)
        # Save where we are so we can backtrack
        if consistency_check:
            chk = self.checkpoint()
        # The copy is required so undo can know what clients to move back!
        clients = copy(self.clients(r))
        # Messy checks so we know what to do if we are replacing an output
        # result. Note that if r is an input result, we do nothing at all for
        # now (it's not clear what it means to replace an input result).
        was_output = False
        new_was_output = False
        if new_r in self.outputs:
            new_was_output = True
        if r in self.outputs:
            was_output = True
            self.outputs.remove(r)
            self.outputs.add(new_r)
        # The actual replacement operation occurs here. This might raise
        # a GofTypeError
        self.__move_clients__(clients, r, new_r)
        # This function undoes the replacement.
        def undo():
            # Restore self.outputs
            if was_output:
                if not new_was_output:
                    self.outputs.remove(new_r)
                self.outputs.add(r)
            # Move back the clients. This should never raise an error.
            self.__move_clients__(clients, new_r, r)
        self.history.append(undo)
        if consistency_check:
            try:
                self.validate()
            except InconsistencyError:
                self.revert(chk)
                raise

    def replace_all(self, d):
        """
        For (r, new_r) in d.items(), replaces r with new_r. Checks for consistency at the
        end and raises an InconsistencyError if the graph is not consistent. If an error is
        raised, the graph is restored to what it was before.
        """
        chk = self.checkpoint()
        try:
            for r, new_r in d.items():
                self.replace(r, new_r, False)
        except Exception:
            self.revert(chk)
            raise
        try:
            self.validate()
        except InconsistencyError:
            self.revert(chk)
            raise

    def results(self):
        "All results within the subgraph bound by env.inputs and env.outputs and including them"
        return self._results

    def revert(self, checkpoint):
        """
        Reverts the graph to whatever it was at the provided checkpoint (undoes all replacements).
        A checkpoint at any given time can be obtained using self.checkpoint().
        """
        while len(self.history) > checkpoint:
            f = self.history.pop()
            f()

    def supplemental_orderings(self):
        "Merges the ordering constraints contributed by all Orderings features."
        ords = {}
        for ordering in self._orderings.values():
            for op, prereqs in ordering.orderings().items():
                ords.setdefault(op, set()).update(prereqs)
        return ords

    def toposort(self):
        """
        Returns a list of ops in the order that they must be executed in order to preserve
        the semantics of the graph and respect the constraints put forward by the listeners.
        """
        ords = self.supplemental_orderings()
        order = graph.io_toposort(self.inputs, self.outputs, ords)
        return order

    def validate(self):
        """
        Runs every Constraint feature; each may raise InconsistencyError.
        Returns True when all pass.
        """
        for constraint in self._constraints.values():
            constraint.validate()
        return True

    ### Private interface ###

    def __add_clients__(self, r, all):
        # Registers the (op, i) pairs in all as clients of r.
        self._clients.setdefault(r, set()).update(all)

    def __remove_clients__(self, r, all):
        # Unregisters the (op, i) pairs in all; drops the entry once empty.
        if not all:
            return
        self._clients[r].difference_update(all)
        if not self._clients[r]:
            del self._clients[r]

    def __import_r__(self, results):
        # Imports the op computing each result (results without an owner —
        # inputs and orphans — need no import).
        for result in results:
            owner = result.owner
            if owner:
                self.__import__(result.owner)

    def __import__(self, op):
        # We import the ops in topological order. We only are interested
        # in new ops, so we use all results we know of as if they were the input set.
        # (the functions in the graph module only use the input set to
        # know where to stop going down)
        new_ops = graph.io_toposort(self.results(), op.outputs)
        for op in new_ops:
            self.satisfy(op) # add the features required by this op
            self._ops.add(op)
            self._results.update(op.outputs)
            for i, input in enumerate(op.inputs):
                self.__add_clients__(input, [(op, i)])
                if input not in self._results:
                    # This input is an orphan because if the op that
                    # produced it was in the subgraph, io_toposort
                    # would have placed it before, so we would have
                    # seen it (or it would already be in the graph)
                    self._orphans.add(input)
                    self._results.add(input)
            for listener in self._listeners.values():
                listener.on_import(op)

    def __prune_r__(self, results):
        # Tries to prune the op computing each result; env inputs are kept.
        for result in set(results):
            if result in self.inputs:
                continue
            owner = result.owner
            if owner:
                self.__prune__(owner)

    def __prune__(self, op):
        for output in op.outputs:
            # Cannot prune an op which is an output or used somewhere
            if self.clients(output) or output in self.outputs:
                return
        self._ops.remove(op)
        self._results.difference_update(op.outputs)
        for listener in self._listeners.values():
            listener.on_prune(op)
        for i, input in enumerate(op.inputs):
            self.__remove_clients__(input, [(op, i)])
        # The op's own inputs may have become prunable in turn.
        self.__prune_r__(op.inputs)

    def __move_clients__(self, clients, r, new_r):
        try:
            # Try replacing the inputs
            for op, i in clients:
                op.set_input(i, new_r, False)
        # Bug fix: this used to read `except GofTypeError, PropagationError:`
        # which in Python 2 only catches GofTypeError (binding the exception
        # to the name PropagationError); a tuple is required to catch both.
        except (GofTypeError, PropagationError):
            # Oops! Roll back the replacements already performed.
            for op, i in clients:
                op.set_input(i, r, False)
            raise
        self.__remove_clients__(r, clients)
        self.__add_clients__(new_r, clients)
        # We import the new result in the fold
        self.__import_r__([new_r])
        for listener in self._listeners.values():
            listener.on_rewire(clients, r, new_r)
        # We try to get rid of the old one
        self.__prune_r__([r])

    def __str__(self):
        return graph.as_string(self.inputs, self.outputs)
class GofError(Exception):
    """Base class for all errors raised by the gof package."""
class GofTypeError(GofError):
    """Raised when a result violates an op's type constraints."""
class GofValueError(GofError):
    """Gof-level counterpart of the builtin ValueError."""
class PropagationError(GofError):
    """Raised when propagating a change through the graph fails."""
from copy import copy
from op import Op
from lib import DummyOp
from result import Result
from features import Listener, Constraint, Orderings
from env import InconsistencyError
from utils import ClsInit
import graph
# Public API of this module.
__all__ = ['Viewer', 'Destroyer', 'DestroyHandler', 'IONames', 'mark_outputs_as_destroyed']
## mul(*3 -> sub(*1 -> zeros((), float, C), sigmoid(dot(sigmoid(dot(*1, *2 -> ones((), float, C))), transpose(*2)))), fill(isqr(*3), 1.0))
class IONames:
"""
Requires assigning a name to each of this Op's inputs and outputs.
"""
__metaclass__ = ClsInit
input_names = ()
output_names = ()
@staticmethod
def __clsinit__(cls, name, bases, dct):
for names in ['input_names', 'output_names']:
if names in dct:
x = getattr(cls, names)
if isinstance(x, str):
x = [x,]
setattr(cls, names, x)
if isinstance(x, (list, tuple)):
x = [a for a in x]
setattr(cls, names, x)
for i, varname in enumerate(x):
if not isinstance(varname, str) or hasattr(cls, varname) or varname in ['inputs', 'outputs']:
raise TypeError("In %s: '%s' is not a valid input or output name" % (cls.__name__, varname))
# Set an attribute for the variable so we can do op.x to return the input or output named "x".
setattr(cls, varname,
property(lambda op, type=names.replace('_name', ''), index=i:
getattr(op, type)[index]))
else:
print 'ERROR: Class variable %s::%s is neither list, tuple, or string' % (name, names)
raise TypeError, str(names)
else:
setattr(cls, names, ())
# def __init__(self, inputs, outputs, use_self_setters = False):
# assert len(inputs) == len(self.input_names)
# assert len(outputs) == len(self.output_names)
# Op.__init__(self, inputs, outputs, use_self_setters)
def __validate__(self):
assert len(self.inputs) == len(self.input_names)
assert len(self.outputs) == len(self.output_names)
@classmethod
def n_inputs(cls):
return len(cls.input_names)
@classmethod
def n_outputs(cls):
return len(cls.output_names)
def get_by_name(self, name):
"""
Returns the input or output which corresponds to the given name.
"""
if name in self.input_names:
return self.input_names[self.input_names.index(name)]
elif name in self.output_names:
return self.output_names[self.output_names.index(name)]
else:
raise AttributeError("No such input or output name for %s: %s" % (self.__class__.__name__, name))
# class DestroyHandler(Listener, Constraint, Orderings):
# def __init__(self, env):
# self.parent = {}
# self.children = {}
# self.destroyers = {}
# self.paths = {}
# self.dups = set()
# self.cycles = set()
# self.env = env
# for input in env.inputs:
# self.parent[input] = None
# self.children[input] = set()
# def __path__(self, r):
# path = self.paths.get(r, None)
# if path:
# return path
# rval = [r]
# r = self.parent[r]
# while r:
# rval.append(r)
# r = self.parent[r]
# rval.reverse()
# for i, x in enumerate(rval):
# self.paths[x] = rval[0:i+1]
# return rval
# def __views__(self, r):
# children = self.children[r]
# if not children:
# return set([r])
# else:
# rval = set([r])
# for child in children:
# rval.update(self.__views__(child))
# return rval
# def __users__(self, r):
# views = self.__views__(r)
# rval = set()
# for view in views:
# for op, i in self.env.clients(view):
# rval.update(op.outputs)
# return rval
# def __pre__(self, op):
# rval = set()
# if op is None:
# return rval
# keep_going = False
# for input in op.inputs:
# foundation = self.__path__(input)[0]
# destroyers = self.destroyers.get(foundation, set())
# if destroyers:
# keep_going = True
# if op in destroyers:
# users = self.__users__(foundation)
# rval.update(users)
# if not keep_going:
# return set()
# rval.update(op.inputs)
# rval.difference_update(op.outputs)
# return rval
# def __detect_cycles_helper__(self, r, seq):
# if r in seq:
# self.cycles.add(tuple(seq[seq.index(r):]))
# return
# pre = self.__pre__(r.owner)
# for r2 in pre:
# self.__detect_cycles_helper__(r2, seq + [r])
# def __detect_cycles__(self, start, just_remove=False):
# users = self.__users__(start)
# users.add(start)
# for user in users:
# for cycle in copy(self.cycles):
# if user in cycle:
# self.cycles.remove(cycle)
# if just_remove:
# return
# for user in users:
# self.__detect_cycles_helper__(user, [])
# def get_maps(self, op):
# dmap = {}
# vmap = {}
# if isinstance(op, DestroyOp):
# dmap = op.destroy_map()
# if isinstance(op, ViewOp):
# vmap = op.view_map()
# return vmap, dmap
# # return getattr(op, 'view_map', lambda:{})(), \
# # getattr(op, 'destroy_map', lambda:{})()
# def on_import(self, op):
# view_map, destroy_map = self.get_maps(op)
# for input in op.inputs:
# self.parent.setdefault(input, None)
# for output in op.outputs:
# views = view_map.get(output, None)
# destroyed = destroy_map.get(output, None)
# if destroyed:
# self.parent[output] = None
# for input in destroyed:
# path = self.__path__(input)
# self.__add_destroyer__(path + [output])
# elif views:
# if len(views) > 1: #views was inputs before?
# raise Exception("Output is a view of too many inputs.")
# self.parent[output] = views[0]
# for input in views:
# self.children[input].add(output)
# else:
# self.parent[output] = None
# self.children[output] = set()
# for output in op.outputs:
# self.__detect_cycles__(output)
# # if destroy_map:
# # print "op: ", op
# # print "ord: ", [str(x) for x in self.orderings()[op]]
# # print
# def on_prune(self, op):
# view_map, destroy_map = self.get_maps(op)
# if destroy_map:
# destroyers = []
# for i, input in enumerate(op.inputs):
# destroyers.append(self.destroyers.get(self.__path__(input)[0], {}))
# for destroyer in destroyers:
# path = destroyer.get(op, [])
# if path:
# self.__remove_destroyer__(path)
# if view_map:
# for i, input in enumerate(op.inputs):
# self.children[input].difference_update(op.outputs)
# for output in op.outputs:
# try:
# del self.paths[output]
# except:
# pass
# self.__detect_cycles__(output, True)
# for i, output in enumerate(op.outputs):
# try:
# del self.parent[output]
# except:
# pass
# del self.children[output]
# def __add_destroyer__(self, path):
# foundation = path[0]
# target = path[-1]
# op = target.owner
# destroyers = self.destroyers.setdefault(foundation, {})
# path = destroyers.setdefault(op, path)
# if len(destroyers) > 1:
# self.dups.add(foundation)
# def __remove_destroyer__(self, path):
# foundation = path[0]
# target = path[-1]
# op = target.owner
# destroyers = self.destroyers[foundation]
# del destroyers[op]
# if not destroyers:
# del self.destroyers[foundation]
# elif len(destroyers) == 1 and foundation in self.dups:
# self.dups.remove(foundation)
# def on_rewire(self, clients, r_1, r_2):
# path_1 = self.__path__(r_1)
# path_2 = self.__path__(r_2)
# prev = set()
# for op, i in clients:
# prev.update(op.outputs)
# foundation = path_1[0]
# destroyers = self.destroyers.get(foundation, {}).items()
# for op, path in destroyers:
# if r_1 in path:
# idx = path.index(r_1)
# self.__remove_destroyer__(path)
# if not (idx > 0 and path[idx - 1] in prev):
# continue
# index = path.index(r_1)
# new_path = path_2 + path[index+1:]
# self.__add_destroyer__(new_path)
# for op, i in clients:
# view_map, _ = self.get_maps(op)
# for output, inputs in view_map.items():
# if r_2 in inputs:
# assert self.parent[output] == r_1
# self.parent[output] = r_2
# self.children[r_1].remove(output)
# self.children[r_2].add(output)
# for view in self.__views__(r_1):
# try:
# del self.paths[view]
# except:
# pass
# for view in self.__views__(r_2):
# try:
# del self.paths[view]
# except:
# pass
# self.__detect_cycles__(r_1)
# self.__detect_cycles__(r_2)
# def validate(self):
# if self.dups:
# raise InconsistencyError("The following values are destroyed more than once: %s" % self.dups)
# elif self.cycles:
# raise InconsistencyError("There are cycles: %s" % self.cycles)
# else:
# return True
# def orderings(self):
# ords = {}
# for foundation, destroyers in self.destroyers.items():
# for op in destroyers.keys():
# ords.setdefault(op, set()).update([user.owner for user in self.__users__(foundation) if user not in op.outputs])
# return ords
class DestroyHandler(Listener, Constraint, Orderings):
    """
    Env feature policing in-place (destructive) operations.

    It tracks which results are views of which others (parent/children),
    which ops destroy which underlying storage ("foundations"), flags
    storage destroyed by more than one op (dups) and dependency cycles
    created by in-place updates (cycles), and contributes orderings so a
    destroyer runs after every other user of the storage it clobbers.
    """
    def __init__(self, env):
        # parent[r]: the result r is a view of (absent when r owns its
        # storage); children[r]: the results that view r directly.
        self.parent = {}
        self.children = {}
        # destroyers[foundation]: {op: path} for each op destroying
        # foundation, where path is the view chain from foundation down to
        # the destroyed output.
        self.destroyers = {}
        # Memoization cache for __path__.
        self.paths = {}
        # Foundations destroyed by more than one op.
        self.dups = set()
        # Tuples of results forming dependency cycles.
        self.cycles = set()
        self.env = env
        for input in env.inputs:
            # self.parent[input] = None
            self.children[input] = set()
    def __path__(self, r):
        # Chain of views from the storage's foundation down to r (memoized
        # in self.paths for every prefix).
        path = self.paths.get(r, None)
        if path:
            return path
        rval = [r]
        r = self.parent.get(r, None) ### ???
        while r:
            rval.append(r)
            r = self.parent.get(r, None)
        rval.reverse()
        for i, x in enumerate(rval):
            self.paths[x] = rval[0:i+1]
        return rval
    def __views__(self, r):
        # r together with every result that (transitively) views r.
        children = self.children[r]
        if not children:
            return set([r])
        else:
            rval = set([r])
            for child in children:
                rval.update(self.__views__(child))
            return rval
    def __users__(self, r):
        # Outputs of every op that consumes r or any view of r.
        views = self.__views__(r)
        rval = set()
        for view in views:
            for op, i in self.env.clients(view):
                rval.update(op.outputs)
        return rval
    def __pre__(self, op):
        # Results that must be available before op may run, accounting for
        # destruction: a destroyer additionally waits on every user of the
        # storage it destroys.
        rval = set()
        if op is None:
            return rval
        keep_going = False
        for input in op.inputs:
            foundation = self.__path__(input)[0]
            destroyers = self.destroyers.get(foundation, set())
            if destroyers:
                keep_going = True
            if op in destroyers:
                users = self.__users__(foundation)
                rval.update(users)
        # if not keep_going:
        #     return set()
        rval.update(op.inputs)
        rval.difference_update(op.outputs)
        return rval
    def __detect_cycles_helper__(self, r, seq):
        # Depth-first walk over __pre__ edges; a revisit of r within seq
        # means a cycle, which gets recorded.
        # print "!! ", r, seq
        if r in seq:
            self.cycles.add(tuple(seq[seq.index(r):]))
            return
        pre = self.__pre__(r.owner)
        for r2 in pre:
            self.__detect_cycles_helper__(r2, seq + [r])
    def __detect_cycles__(self, start, just_remove=False):
        # Recomputes the cycles involving start and its users; with
        # just_remove, only forgets the now-stale ones.
        # print "!!! ", start
        users = self.__users__(start)
        users.add(start)
        for user in users:
            for cycle in copy(self.cycles):
                if user in cycle:
                    self.cycles.remove(cycle)
        if just_remove:
            return
        for user in users:
            self.__detect_cycles_helper__(user, [])
    def get_maps(self, op):
        # (view_map, destroy_map) of op; empty dicts when op is not a
        # Viewer / Destroyer.
        dmap = {}
        vmap = {}
        if isinstance(op, Destroyer):
            dmap = op.destroy_map()
        if isinstance(op, Viewer):
            vmap = op.view_map()
        return vmap, dmap
        # return getattr(op, 'view_map', lambda:{})(), \
        #        getattr(op, 'destroy_map', lambda:{})()
    def on_import(self, op):
        # Registers op's view/destroy relationships, then rechecks cycles
        # around each of its outputs.
        view_map, destroy_map = self.get_maps(op)
        # for input in op.inputs:
        #     self.parent.setdefault(input, None)
        for i, output in enumerate(op.outputs):
            views = view_map.get(output, None)
            destroyed = destroy_map.get(output, None)
            if destroyed:
                # self.parent[output] = None
                # Accept a bare Result as shorthand for a one-element list.
                if isinstance(destroyed, Result):
                    destroyed = [destroyed]
                for input in destroyed:
                    path = self.__path__(input)
                    self.__add_destroyer__(path + [output])
            elif views:
                if isinstance(views, Result):
                    views = [views]
                if len(views) > 1: #views was inputs before?
                    raise Exception("Output is a view of too many inputs.")
                self.parent[output] = views[0]
                for input in views:
                    self.children[input].add(output)
            # else:
            #     self.parent[output] = None
            self.children[output] = set()
        for output in op.outputs:
            self.__detect_cycles__(output)
        # if destroy_map:
        #     print "op: ", op
        #     print "ord: ", [str(x) for x in self.orderings()[op]]
        #     print
    def on_prune(self, op):
        # Unregisters op: drops its destroy entries, view links and cached
        # paths, and removes stale cycle records.
        view_map, destroy_map = self.get_maps(op)
        if destroy_map:
            destroyers = []
            for i, input in enumerate(op.inputs):
                destroyers.append(self.destroyers.get(self.__path__(input)[0], {}))
            for destroyer in destroyers:
                path = destroyer.get(op, [])
                if path:
                    self.__remove_destroyer__(path)
        if view_map:
            for i, input in enumerate(op.inputs):
                self.children[input].difference_update(op.outputs)
        for output in op.outputs:
            try:
                del self.paths[output]
            except:
                pass
            self.__detect_cycles__(output, True)
        for i, output in enumerate(op.outputs):
            try:
                del self.parent[output]
            except:
                pass
            del self.children[output]
    def __add_destroyer__(self, path):
        # Records that path[-1].owner destroys the storage of path[0].
        foundation = path[0]
        target = path[-1]
        op = target.owner
        destroyers = self.destroyers.setdefault(foundation, {})
        path = destroyers.setdefault(op, path)
        if len(destroyers) > 1:
            self.dups.add(foundation)
    def __remove_destroyer__(self, path):
        # Reverse of __add_destroyer__; clears the dup flag when only one
        # destroyer remains.
        foundation = path[0]
        target = path[-1]
        op = target.owner
        destroyers = self.destroyers[foundation]
        del destroyers[op]
        if not destroyers:
            del self.destroyers[foundation]
        elif len(destroyers) == 1 and foundation in self.dups:
            self.dups.remove(foundation)
    def on_rewire(self, clients, r_1, r_2):
        # The given clients stopped reading r_1 and now read r_2: rebase
        # the affected destroy paths and view links, then recheck cycles
        # around both results.
        path_1 = self.__path__(r_1)
        path_2 = self.__path__(r_2)
        prev = set()
        for op, i in clients:
            prev.update(op.outputs)
        foundation = path_1[0]
        destroyers = self.destroyers.get(foundation, {}).items()
        for op, path in destroyers:
            if r_1 in path:
                idx = path.index(r_1)
                self.__remove_destroyer__(path)
                if not (idx > 0 and path[idx - 1] in prev):
                    continue
                index = path.index(r_1)
                new_path = path_2 + path[index+1:]
                self.__add_destroyer__(new_path)
        for op, i in clients:
            view_map, _ = self.get_maps(op)
            for output, inputs in view_map.items():
                if r_2 in inputs:
                    assert self.parent.get(output, None) == r_1
                    self.parent[output] = r_2
                    self.children[r_1].remove(output)
                    self.children[r_2].add(output)
        # Cached paths through either result are now stale.
        for view in self.__views__(r_1):
            try:
                del self.paths[view]
            except:
                pass
        for view in self.__views__(r_2):
            try:
                del self.paths[view]
            except:
                pass
        self.__detect_cycles__(r_1)
        self.__detect_cycles__(r_2)
    def validate(self):
        # Constraint hook: reject double destruction and dependency cycles.
        if self.dups:
            raise InconsistencyError("The following values are destroyed more than once: %s" % self.dups)
        elif self.cycles:
            raise InconsistencyError("There are cycles: %s" % self.cycles)
        else:
            return True
    def orderings(self):
        # Orderings hook: each destroyer runs after every other user of the
        # storage it destroys.
        ords = {}
        for foundation, destroyers in self.destroyers.items():
            for op in destroyers.keys():
                ords.setdefault(op, set()).update([user.owner for user in self.__users__(foundation) if user not in op.outputs])
        return ords
class Viewer:
    """
    Marks an operation for which one or more outputs may share storage
    with one or more inputs, so mutating one can be observed through the
    other. All inputs are assumed to be left intact.
    """
    def view_map(self):
        """
        Returns a dictionary mapping each output to the list of inputs it
        is a view of (with which it might share internal structures).
        The default declares the first output to be a view of the first
        input.
        """
        viewed = self.inputs[0]
        return {self.out: [viewed]}
class Destroyer:
    """
    Marks an operation that acts in place on one or several of its
    inputs: running the Op may change the data those inputs contain.
    """
    # Envs containing a Destroyer must track destruction via DestroyHandler.
    __require__ = DestroyHandler
    def destroy_map(self):
        """
        Returns a dictionary mapping each output to the list of inputs
        that computing it destroys. The default declares the first input
        to be overwritten by the first output.
        """
        overwritten = self.inputs[0]
        return {self.out: [overwritten]}
class Return(DummyOp, Destroyer):
    """
    Dummy op standing for the act of returning its input value to an end
    user. It "destroys" its input to prevent any other Op from
    overwriting it.
    """
def mark_outputs_as_destroyed(outputs):
    "Wraps each output in a Return op so no other Op may overwrite it."
    marked = []
    for output in outputs:
        marked.append(Return(output).out)
    return marked
# class BuildableFromInputs:
# @classmethod
# def from_inputs(cls, *inputs):
# return cls(inputs, self.gen_outputs())
# def gen_outputs(self):
# raise NotImplementedError
from copy import copy
from op import Op
import result
import graph
import utils
from random import shuffle
# Public API of this module (commented entries are currently disabled).
__all__ = ['Feature',
           'Listener',
           'Constraint',
           'Orderings',
           'Tool',
#           'Preprocessor',
           'EquivTool',
           'InstanceFinder',
           'PrintListener',
           'ChangeListener',
#           'DestroyPreprocessor',
#           'DestroyHandler'
           ]
class Feature(object):
    """Base class for everything that can be attached to an Env."""
    def __init__(self, env):
        # The env this feature is attached to.
        self.env = env
class Listener(Feature):
    """Feature notified of every change to the env's subgraph."""
    def on_import(self, op):
        "Called when op enters the subgraph. Default: no-op."
    def on_prune(self, op):
        "Called when op leaves the subgraph. Default: no-op."
    def on_rewire(self, clients, r, new_r):
        "Called when the given clients switch input r for new_r. Default: no-op."
class Constraint(Feature):
    """Feature that may veto states of the env's subgraph."""
    def validate(self):
        "Returns True when the subgraph is acceptable. Default: always True."
        return True
class Orderings(Feature):
    """Feature contributing extra scheduling constraints."""
    def orderings(self):
        "Returns a {op: prerequisites} mapping. Default: no constraints."
        return {}
class Tool(Feature):
    """Feature that can publish extra functionality on the env."""
    def publish(self):
        "Installs this tool's entry points on self.env. Default: no-op."
# class Preprocessor(Feature):
# def transform(self, inputs, outputs):
# return inputs, outputs
# def __call__(self, inputs, outputs):
# return self.transform(inputs, outputs)
# class Optimization(object):
# def require(self):
# return []
# def apply(self, env):
# pass
# def __call__(self, env):
# return self.apply(env)
# Optimization
# * require <tool_class>*
# * apply
# Prog
# * __init__
# * inputs
# * outputs
# * listeners, constraints, orderings
# * dispatched by isinstance Listener, etc.
# * {tool_class: preferred_implementation, ...}
class EquivTool(Listener, Tool, dict):
    """
    Keeps track of which Result each replaced Result has been superseded
    by. Published on the env as `env.equiv`; calling the tool with a
    result follows the replacement chain to the currently active result.
    """
    def on_rewire(self, clients, r, new_r):
        repl = self(new_r)
        if repl is r:
            # new_r had previously been replaced by r: this rewire undoes it.
            self.ungroup(r, new_r)
        elif repl is not new_r:
            raise Exception("Improper use of EquivTool!")
        else:
            self.group(new_r, r)
    def publish(self):
        self.env.equiv = self
    def group(self, main, *keys):
        "Marks all the keys as having been replaced by the Result main."
        keys = [key for key in keys if key is not main]
        # Modernized: `in` instead of the deprecated dict.has_key.
        if main in self:
            raise Exception("Only group results that have not been grouped before.")
        for key in keys:
            if key in self:
                raise Exception("Only group results that have not been grouped before.")
            if key is main:
                continue
            self.setdefault(key, main)
    def ungroup(self, main, *keys):
        "Undoes group(main, *keys)"
        keys = [key for key in keys if key is not main]
        for key in keys:
            if self[key] is main:
                del self[key]
    def __call__(self, key):
        "Returns the currently active replacement for the given key."
        # Follow the chain of replacements until reaching a result that has
        # not itself been replaced. (Local renamed: it shadowed builtin `next`.)
        repl = self.get(key, None)
        while repl:
            key = repl
            repl = self.get(repl, None)
        return key
class InstanceFinder(Listener, Tool, dict):
    """Dict mapping each Op subclass to the set of its instances in the env."""
    def __init__(self, env):
        self.env = env
    def all_bases(self, cls):
        # NOTE(review): `Op` is not imported in this module's visible header;
        # confirm where it is expected to come from.
        return utils.all_bases(cls, lambda cls: issubclass(cls, Op))
    def on_import(self, op):
        # Index the op under its own class and every Op ancestor.
        for base in self.all_bases(op.__class__):
            self.setdefault(base, set()).add(op)
    def on_prune(self, op):
        # Drop the op from every bucket, discarding buckets that empty out.
        for base in self.all_bases(op.__class__):
            bucket = self[base]
            bucket.remove(op)
            if not bucket:
                del self[base]
    def __query__(self, cls):
        candidates = list(self.get(cls, []))
        shuffle(candidates) # this helps a lot for debugging because the order of the replacements will vary
        while candidates:
            candidate = candidates.pop()
            # only yield ops that are still part of the env
            if candidate in self.env.ops():
                yield candidate
    def query(self, cls):
        return self.__query__(cls)
    def publish(self):
        self.env.get_instances_of = self.query
class PrintListener(Listener):
def __init__(self, env, active = True):
self.env = env
self.active = active
if active:
print "-- initializing"
def on_import(self, op):
if self.active:
print "-- importing: %s" % graph.as_string(self.env.inputs, op.outputs)
def on_prune(self, op):
if self.active:
print "-- pruning: %s" % graph.as_string(self.env.inputs, op.outputs)
def on_rewire(self, clients, r, new_r):
if self.active:
if r.owner is None:
rg = id(r) #r.name
else:
rg = graph.as_string(self.env.inputs, r.owner.outputs)
if new_r.owner is None:
new_rg = id(new_r) #new_r.name
else:
new_rg = graph.as_string(self.env.inputs, new_r.owner.outputs)
print "-- moving from %s to %s" % (rg, new_rg)
class ChangeListener(Listener):
    """Boolean flag that becomes True whenever the env changes in any way."""
    def __init__(self, env):
        # Store env for consistency with the other Feature subclasses
        # (the original discarded it).
        self.env = env
        self.change = False
    def on_import(self, op):
        self.change = True
    def on_prune(self, op):
        self.change = True
    def on_rewire(self, clients, r, new_r):
        self.change = True
    def __call__(self, value = "get"):
        """With no argument, return the flag; otherwise set it to value."""
        if value == "get":
            return self.change
        else:
            self.change = value
# class SuperFinder(Listener, Tool, dict):
# def __init__(self, env, props):
# self.env = env
# self.props = props
# def on_import(self, op):
# for prop, value in self.props(op).items():
# self.setdefault(prop, {}).setdefault(value, set()).add(op)
# def on_prune(self, op):
# for prop, value in self.props(op).items():
# self[prop][value].remove(op)
# if len(self[prop][value]) == 0:
# del self[prop][value]
# if len(self[prop]) == 0:
# del self[prop]
# def __query__(self, order, template):
# all = []
# for prop, value in template.items():
# all += [x for x in self.get(prop, {}).get(value, set())]
# # If not None, the order option requires the order listener to be included in the env under the name 'order'
# if order == 'o->i':
# all.sort(lambda op1, op2: self.env.order[op1].__cmp__(self.env.order[op2]))
# elif order == 'i->o':
# all.sort(lambda op1, op2: self.env.order[op2].__cmp__(self.env.order[op1]))
# while all:
# next = all.pop()
# if next in self.env.ops():
# yield next
# def query(self, **template):
# return self.__query__(None, template)
# def query_downstream(self, **template):
# return self.__query__('i->o', template)
# def query_upstream(self, **template):
# return self.__query__('o->i', template)
# def publish(self):
# self.env.query = self.query
from copy import copy
from result import Result, BrokenLink, BrokenLinkError
from op import Op
import utils
# Public API of the graph-traversal helper module.
__all__ = ['inputs',
           'results_and_orphans', 'results', 'orphans',
           'ops',
           'clone', 'clone_get_equiv',
           'io_toposort',
           'as_string',
           'Graph']
def inputs(o, repair = False):
    """
    o -> list of output Results
    Returns the set of inputs necessary to compute the outputs in o
    such that input.owner is None. If repair is True, a BrokenLink
    encountered along the way triggers op.refresh() and a retry instead
    of propagating BrokenLinkError.
    """
    found = set()
    def visit(r):
        if isinstance(r, BrokenLink):
            raise BrokenLinkError
        owner = r.owner
        if owner is None:
            found.add(r)
            return
        # Index-based iteration on purpose: refresh() may replace the
        # entries of owner.inputs, and we must re-read the slot after it.
        for idx in range(len(owner.inputs)):
            try:
                visit(owner.inputs[idx])
            except BrokenLinkError:
                if not repair:
                    raise
                owner.refresh()
                visit(owner.inputs[idx])
    for output in o:
        visit(output)
    return found
def results_and_orphans(i, o):
    """
    i -> list of input Results
    o -> list of output Results
    Returns the pair (results, orphans). The former is the set of
    Results that are involved in the subgraph that lies between i and
    o. This includes i, o, orphans(i, o) and all results of all
    intermediary steps from i to o. The second element of the returned
    pair is orphans(i, o).
    """
    found = set(o)
    found.update(i)
    dangling_trails = []
    def walk(r, trail):
        if isinstance(r, BrokenLink):
            raise BrokenLinkError
        if r in i:
            # the trail connects an output to an input: keep all of it
            found.update(trail)
        elif r.owner is None:
            # hit an ownerless result that is not a declared input
            dangling_trails.append(trail)
        else:
            for r2 in r.owner.inputs:
                walk(r2, trail + [r2])
    for output in o:
        walk(output, [])
    orphans = set()
    for trail in dangling_trails:
        # the first result on the trail outside the subgraph is the orphan
        for r in trail:
            if r not in found:
                orphans.add(r)
                break
    return found, orphans
def ops(i, o):
    """
    i -> list of input Results
    o -> list of output Results
    Returns the set of ops that are contained within the subgraph
    that lies between i and o, including the owners of the Results in
    o and intermediary ops between i and o, but not the owners of the
    Results in i.
    """
    found, orphan_set = results_and_orphans(i, o)
    # Owners of inputs and orphans are by definition outside the subgraph.
    return set(r.owner for r in found
               if r not in i and r not in orphan_set)
def results(i, o):
    """
    i -> list of input Results
    o -> list of output Results
    Returns the set of Results that are involved in the subgraph
    that lies between i and o. This includes i, o, orphans(i, o)
    and all values of all intermediary steps from i to o.
    """
    found, _ = results_and_orphans(i, o)
    return found
def orphans(i, o):
    """
    i -> list of input Results
    o -> list of output Results
    Returns the set of Results which one or more Results in o depend
    on but are neither in i nor in the subgraph that lies between
    i and o.
    e.g. orphans([x], [(x+y).out]) => [y]
    """
    _, orphan_set = results_and_orphans(i, o)
    return orphan_set
def clone(i, o):
    """
    i -> list of input Results
    o -> list of output Results
    Copies the subgraph contained between i and o and returns the
    outputs of that copy (corresponding to o). The input Results in
    the list are _not_ copied and the new graph refers to the
    originals.
    """
    # the equivalence dictionary is discarded; see clone_get_equiv
    return clone_get_equiv(i, o)[0]
def clone_get_equiv(i, o, copy_inputs = False):
    """
    i -> list of input Results
    o -> list of output Results
    copy_inputs -> if True, the Results in i are copied as well and the
        copies are recorded in equiv (originally this parameter was
        accepted but ignored)
    Returns (new_o, equiv) where new_o are the outputs of a copy of
    the whole subgraph bounded by i and o and equiv is a dictionary
    that maps the original ops and results found in the subgraph to
    their copy (akin to deepcopy's memo). See clone for more details.
    """
    d = {}
    if copy_inputs:
        for input in i:
            d[input] = copy(input)
    # Copy the ops in a separate dict: the original iterated d.items()
    # while inserting result copies into d, which only worked because
    # Python 2's items() snapshots the dict.
    op_copies = {}
    for op in ops(i, o):
        op_copies[op] = copy(op)
    d.update(op_copies)
    for old_op, new_op in op_copies.items():
        for old_output, new_output in zip(old_op.outputs, new_op.outputs):
            d[old_output] = new_output
        # note: loop variable renamed from `i`, which shadowed the parameter
        for idx, input in enumerate(new_op.inputs):
            owner = input.owner
            if owner in op_copies:
                # rewire to the copied producer's matching output
                new_op._inputs[idx] = op_copies[owner].outputs[input._index]
            elif input in d:
                # copied graph input (only reachable with copy_inputs=True)
                new_op._inputs[idx] = d[input]
    # Return a tuple as documented (the original returned a 2-element list,
    # which unpacks and indexes the same way).
    return [d[output] for output in o], d
def io_toposort(i, o, orderings = None):
    """
    i -> list of input Results
    o -> list of output Results
    orderings -> {op: [requirements for op]} (defaults to no extra orderings)
    Returns an ordered list of Ops that belong in the subgraph between
    i and o which respects the following constraints:
     - all inputs in i are assumed to be already computed
     - the Ops that compute an Op's inputs must be computed before it
     - the orderings specified in the optional orderings parameter must be satisfied
    Note that this function does not take into account ordering information
    related to destructive operations or other special behavior.
    """
    # BUG FIX: the default used to be a mutable {} and the requirement sets
    # were shallow-copied, so setdefault(...).update(...) mutated the
    # caller's sets. Build fresh sets instead.
    prereqs_d = {}
    if orderings:
        for op, reqs in orderings.items():
            prereqs_d[op] = set(reqs)
    all_ops = ops(i, o)
    for op in all_ops:
        prereqs_d.setdefault(op, set()).update(
            input.owner for input in op.inputs
            if input.owner and input.owner in all_ops)
    return utils.toposort(prereqs_d)
def as_string(i, o):
    """
    i -> list of input Results
    o -> list of output Results
    Returns a string representation of the subgraph between i and o. If the same
    Op is used by several other ops, the first occurrence will be marked as
    '*n -> description' and all subsequent occurrences will be marked as '*n',
    where n is an id number (ids are attributed in an unspecified order and only
    exist for viewing convenience).
    """
    # First pass: find ops whose output feeds more than one consumer so they
    # can be printed once and referenced by '*n' afterwards.
    multi = set()
    seen = set()
    for op in ops(i, o):
        for input in op.inputs:
            op2 = input.owner
            if input in i or op2 is None:
                continue
            if op2 in seen:
                multi.add(op2)
            else:
                seen.add(op2)
    multi = list(multi)
    done = set()
    def multi_index(x):
        # 1-based id of a shared op; 999 is the fallback for a repeated op
        # that was not detected as shared above.
        try:
            return multi.index(x) + 1
        except ValueError:  # was a bare except: only index() can fail here
            return 999
    def describe(x, first = False):
        if isinstance(x, Result):
            done.add(x)
            if x.owner is not None and x not in i:
                op = x.owner
                idx = op.outputs.index(x)
                if idx:
                    # secondary outputs are rendered as 'op.k'
                    s = describe(op, first) + "." + str(idx)
                else:
                    s = describe(op, first)
                return s
            else:
                # inputs and ownerless results are shown by identity
                return str(id(x))
        elif isinstance(x, Op):
            if x in done:
                return "*%i" % multi_index(x)
            else:
                done.add(x)
                if not first and hasattr(x, 'name') and x.name is not None:
                    return x.name
                s = x.__class__.__name__ + "(" + ", ".join([describe(v) for v in x.inputs]) + ")"
                if x in multi:
                    return "*%i -> %s" % (multi_index(x), s)
                else:
                    return s
        else:
            raise TypeError("Cannot print type: %s" % x.__class__)
    return "[" + ", ".join([describe(x, True) for x in o]) + "]"
# Op.__str__ = lambda self: as_string(inputs(self.outputs), self.outputs)[1:-1]
# Result.__str__ = lambda self: as_string(inputs([self]), [self])[1:-1]
class Graph:
    """
    Convenience wrapper pairing a list of input Results with a list of
    output Results and exposing the module-level graph helpers as methods.
    """
    def __init__(self, inputs, outputs):
        self.inputs = inputs
        self.outputs = outputs
    def ops(self):
        return ops(self.inputs, self.outputs)
    def values(self):
        # BUG FIX: there is no module-level `values` function; the helper
        # that returns the subgraph's results is named `results`.
        return results(self.inputs, self.outputs)
    def orphans(self):
        return orphans(self.inputs, self.outputs)
    def io_toposort(self):
        return io_toposort(self.inputs, self.outputs)
    def toposort(self):
        return self.io_toposort()
    def clone(self):
        # inputs are shared with the original graph by design (see clone)
        o = clone(self.inputs, self.outputs)
        return Graph(self.inputs, o)
    def __str__(self):
        return as_string(self.inputs, self.outputs)
from op import Op
from result import Result #, HolderResult
from utils import ClsInit, Keyword
import opt
import env
import features
import ext
# Public API of the Python-op module: mode-stack helpers, Result/Op bases
# and the C/Python dual-implementation machinery.
__all__ = ['UNCOMPUTED',
           'UNDEFINED',
           'current_mode',
           'set_mode',
           'build_mode',
           'eval_mode',
           'build_eval_mode',
           'pop_mode',
           'PythonR',
           'DummyOp',
           'DummyRemover',
           'PythonOp',
           'PythonOpt',
           'COp',
           'DualImplOp']
# Sentinel keywords: UNCOMPUTED marks a Result whose value has not been
# produced yet, UNDEFINED a value that does not exist. (Second argument is
# False -- presumably a flag on utils.Keyword; confirm its meaning there.)
UNCOMPUTED = Keyword("UNCOMPUTED", False)
UNDEFINED = Keyword("UNDEFINED", False)
class ForbidConstantOverwrite(features.Listener, features.Constraint):
    """
    Env constraint that tracks destructive ops and refuses to validate when
    one of them would overwrite a Result marked constant.
    """
    def __init__(self, env):
        self.env = env
        self.bad = set()  # destructive ops that would clobber a constant
    def root_inputs(self, input):
        """Follow view chains back to the Result(s) owning the storage."""
        owner = input.owner
        if owner and isinstance(owner, ext.Viewer):
            view_map = owner.view_map()
            if input in view_map:
                answer = []
                for input2 in view_map[input]:
                    answer += owner.root_inputs(input2)
                return answer
            else:
                # BUG FIX: this case used to fall through and return None;
                # an input that is not itself a view is its own root
                # (mirrors PythonOp.root_inputs).
                return [input]
        else:
            return [input]
    def on_import(self, op):
        if isinstance(op, ext.Destroyer):
            for output, inputs in op.destroy_map().items():
                for input in inputs:
                    for root_input in self.root_inputs(input):
                        if getattr(root_input, 'constant', False):
                            self.bad.add(op)
                            return
    def on_prune(self, op):
        if op in self.bad:
            self.bad.remove(op)
    def on_rewire(self, clients, r, new_r):
        # Rewiring may change which roots an op destroys: re-evaluate each client.
        for op, i in clients:
            self.on_prune(op)
            self.on_import(op)
    def validate(self):
        if self.bad:
            raise env.InconsistencyError("The following ops overwrite a constant value: %s" % self.bad)
        else:
            return True
class PythonR(Result):
    """
    Result holding an arbitrary Python value in `data`. A constant PythonR
    rejects further assignment; `up_to_date` is cleared by destructive ops
    and restored on each assignment.
    """
    __slots__ = ['data', 'constant', 'up_to_date']
    def __init__(self, x = None, constant = False):
        # temporarily writable so the initial assignment goes through
        self.constant = False
        self.set_value(x)
        self.constant = constant
        self.up_to_date = True
    def set_value(self, value):
        """Store value in data, unwrapping PythonR and mapping None to UNCOMPUTED."""
        if self.constant:
            raise Exception("This Result is a constant. Its value cannot be changed.")
        if isinstance(value, PythonR):
            self.set_value(value.data)
            return
        if value is None or value is UNCOMPUTED:
            self.data = UNCOMPUTED
        else:
            self.data = value
        self.up_to_date = True
    def __str__(self):
        return str(self.data)
    def __repr__(self):
        return repr(self.data)
    def perform(self):
        # Delegate to the producing op, if any.
        owner = self.owner
        if owner:
            owner.perform()
    def compute(self):
        # Recursively compute missing inputs, then perform.
        owner = self.owner
        if owner:
            owner.compute()
class PythonOp(Op):
    """
    Op whose implementation is a plain Python function (the `impl` method,
    turned into a staticmethod by the metaclass hook below).

    A class-wide mode stack (`__mode__`) controls what constructing an
    instance returns:
      - 'eval': perform immediately and return the raw data
      - 'build_eval': perform immediately and return the output Result(s)
      - any other mode (e.g. 'build'): __new__ falls through and the
        constructor call returns None.
        NOTE(review): __copy__ below expects the constructor to return a
        Result (it reads `op.owner`), which suggests the 'build' path was
        meant to return op.out / op.outputs -- confirm intended behavior.
    """
    __metaclass__ = ClsInit
    # Class-wide mode stack; the active mode is the last element.
    __mode__ = ['build_eval']
    # Number of outputs; subclasses with several outputs override this.
    nout = 1
    @staticmethod
    def __clsinit__(cls, name, bases, dct):
        # make impl a static method (unwrap the unbound method if needed)
        impl = cls.impl
        if hasattr(cls.impl, 'im_func'):
            impl = impl.im_func
        cls.impl = staticmethod(impl)
    def __new__(cls, *inputs, **kwargs):
        # Build the op by hand, then decide what the constructor call
        # returns based on the current mode. Because the returned object is
        # never an instance of cls (data, Result, list or None), Python
        # will not invoke __init__ a second time.
        op = Op.__new__(cls)
        op.__init__(*inputs)
        mode = kwargs.get('mode', None) or cls.current_mode()
        if mode == 'eval':
            op.perform()
            if op.nout == 1:
                return op.out.data
            else:
                return [output.data for output in op.outputs]
        elif mode == 'build_eval':
            op.perform()
            if op.nout == 1:
                return op.out
            else:
                return op.outputs
    def __init__(self, *inputs):
        Op.__init__(self, inputs, self.gen_outputs())
    def __validate__(self):
        # All inputs must be PythonR Results.
        for input in self.inputs:
            assert isinstance(input, PythonR)
    @classmethod
    def current_mode(cls):
        # Top of the mode stack.
        return cls.__mode__[-1]
    @classmethod
    def set_mode(cls, mode):
        # Push a mode; pop_mode restores the previous one.
        cls.__mode__.append(mode)
    @classmethod
    def build_mode(cls):
        cls.set_mode('build')
    @classmethod
    def eval_mode(cls):
        cls.set_mode('eval')
    @classmethod
    def build_eval_mode(cls):
        cls.set_mode('build_eval')
    @classmethod
    def pop_mode(cls):
        # The initial mode must always remain on the stack.
        if len(cls.__mode__) == 1:
            raise Exception("There's only one mode left on the stack.")
        else:
            cls.__mode__.pop()
    def gen_outputs(self):
        # Fresh, uncomputed output Results for this op.
        return [PythonR() for i in xrange(self.nout)]
    def root_inputs(self, input):
        """Follow view chains back to the Result(s) owning the storage."""
        owner = input.owner
        if owner and isinstance(owner, ext.Viewer):
            view_map = owner.view_map()
            if input in view_map:
                answer = []
                for input2 in view_map[input]:
                    answer += owner.root_inputs(input2)
                return answer
            else:
                return [input]
        else:
            return [input]
    def input_is_up_to_date(self, input):
        # An input is usable only if every root it views is up to date.
        answer = True
        for input in self.root_inputs(input):
            answer &= input.up_to_date
        return answer
    def input_is_constant(self, input):
        # An input is constant if any root it views is constant.
        answer = False
        for input in self.root_inputs(input):
            answer |= input.constant
        return answer
#     def input_is_up_to_date(self, input):
#         if not input.up_to_date:
#             return False
#         owner = input.owner
#         if owner and isinstance(owner, ext.Viewer):
#             view_map = owner.view_map()
#             if input in view_map:
#                 answer = True
#                 for input2 in view_map[input]:
#                     answer &= owner.input_is_up_to_date(input2)
#                 return answer
#         return True
    def check_input(self, input):
        # Reject inputs that were never computed or were invalidated by a
        # destructive op.
        if input.data is UNCOMPUTED:
            raise ValueError("Uncomputed input: %s in %s" % (input, self))
        if not self.input_is_up_to_date(input):
            raise ValueError("Input is out of date: %s in %s" % (input, self))
    def perform(self):
        """Run impl on the input data and store the results in the outputs,
        first validating inputs and invalidating those this op destroys."""
        exc = set()
        if isinstance(self, ext.Destroyer):
            for output, inputs in self.destroy_map().items():
                exc.update(inputs)
                for input in inputs:
                    if self.input_is_constant(input):
                        raise ValueError("Input is constant: %s" % input)
        for input in exc:
            self.check_input(input)
            # will be overwritten by this op: mark stale for other readers
            input.up_to_date = False
        for input in self.inputs:
            if input not in exc:
                self.check_input(input)
        try:
            results = self._impl()
        except Exception, e:
            print "Error in %s: %s" % (self, e)
            raise
        if self.nout == 1:
            self.out.set_value(results)
        else:
            assert self.nout == len(results)
            for result, output in zip(results, self.outputs):
                output.set_value(result)
    def _perform(self):
        # Like perform() but without any input checking (used by linkers).
        results = self._impl()
        if self.nout == 1:
            self.out.set_value(results)
        else:
            assert self.nout == len(results)
            for result, output in zip(results, self.outputs):
                output.set_value(result)
    def compute(self):
        # Recursively compute uncomputed inputs, then perform this op.
        for input in self.inputs:
            if input.data is UNCOMPUTED:
                if input.owner:
                    input.owner.compute()
                else:
                    raise Exception("Uncomputed input: %s in %s" % (input, self))
        self.perform()
    def _impl(self):
        # Unwrap the input Results and call the user implementation.
        return self.impl(*[input.data for input in self.inputs])
    def impl(*args):
        # Placeholder; __clsinit__ turns subclass overrides into staticmethods.
        raise NotImplementedError("This op has no implementation.")
    __require__ = ForbidConstantOverwrite
    def __copy__(self):
        """
        Copies the inputs list shallowly and copies all the outputs
        because of the one owner per output restriction.
        """
#         new_inputs = copy(op.inputs)
#         # We copy the outputs because they are tied to a single Op.
#         new_outputs = [copy(output) for output in op.outputs]
        # build_mode/pop_mode are the module-level aliases defined after
        # this class; they resolve at call time.
        build_mode()
        op = self.__class__(*self.inputs)
        pop_mode()
#         op._inputs = new_inputs
#         op._outputs = new_outputs
#         for i, output in enumerate(op.outputs):
#             # We adjust _owner and _index manually since the copies
#             # point to the previous op (self).
#             output._owner = op
#             output._index = i
        # NOTE(review): `op` here is expected to be a Result (or a list of
        # them for nout > 1), whose .owner is the copied op -- see the class
        # docstring about the 'build' mode return value.
        if isinstance(op, (list, tuple)):
            return op[0].owner
        return op.owner
# Module-level aliases so callers can manipulate the shared PythonOp mode
# stack without referencing the class.
current_mode = PythonOp.current_mode
set_mode = PythonOp.set_mode
build_mode = PythonOp.build_mode
eval_mode = PythonOp.eval_mode
build_eval_mode = PythonOp.build_eval_mode
pop_mode = PythonOp.pop_mode
class PythonOpt(opt.Optimizer):
    """Optimizer wrapper that runs the wrapped optimizer in build mode."""
    def __init__(self, opt):
        self.opt = opt
    def optimize(self, env):
        PythonOp.build_mode()
        # BUG FIX: if the wrapped optimizer raised, pop_mode was never
        # called and the mode stack leaked a 'build' entry.
        try:
            self.opt.optimize(env)
        finally:
            PythonOp.pop_mode()
class DummyOp(Op):
    """Single-input placeholder op whose thunk does nothing."""
    def __init__(self, input):
        Op.__init__(self, [input], [Result()])
    def thunk(self):
        def do_nothing():
            pass
        return do_nothing
# Optimization that strips DummyOp nodes out of an env.
DummyRemover = opt.OpRemover(DummyOp)
# literals_db = {}
# def literal(x):
# if x in literals_db:
# return literals_db.get(x)
# else:
# ret = PythonR(x, constant = True)
# liberals_db[x] = ret
# return ret
class COp(Op):
    """Base class for ops backed by a C implementation."""
    def thunk(self):
        # NOTE(review): `cc` is not imported in this module -- presumably a
        # C-compilation helper module; confirm. The compiled thunk is not
        # returned here either, unlike PythonOp-style thunks.
        cc.compile([self])
    def c_libs(self):
        # Libraries to link against; subclasses override.
        return []
    def c_imports(self):
        # Headers/modules required by the C implementation; subclasses override.
        return []
    def c_impl(self):
        raise NotImplementedError("Provide the operation's behavior here.")
class DualImplOp(PythonOp, COp):
    """Op with both a C and a Python implementation, selectable per thunk."""
    # default implementation used when thunk() is called without a language
    language = 'c'
    supported_languages = 'c', 'python'
    def thunk(self, language = None):
        """
        Returns a thunk that does the operation on the inputs and stores the
        results in the outputs. The language parameter defaults to self.language
        and determines which implementation to use.
        """
        if not language:
            language = self.language
        if language == 'c':
            return COp.thunk(self)
        elif language == 'python':
            return PythonOp.thunk(self)
        elif language == 'all':
            return [self.thunk(lang) for lang in self.supported_languages]
        else:
            raise ValueError("language should be any of %s or 'all', not '%s'" % (self.supported_languages, language))
    def compare_implementations(self,
                                samples,
                                setter = lambda res, v: res.set_value(v),
                                cmp = lambda x, y: x == y):
        """
        Compares the different implementations of this operation on a
        list of input values to verify that they behave the same. The
        input values are put in the Result instances using the setter
        function (defaults to set_value). The output lists are
        compared using the cmp predicate (defaults to ==).
        """
        for sample in samples:
            # BUG FIX: the setter parameter was accepted but ignored; the
            # loop hardcoded input.set_value(v).
            for input, v in zip(self.inputs, sample):
                setter(input, v)
            self.thunk('python')()
            # we must copy the outputs because they will be overwritten
            results_py = [copy(output).extract() for output in self.outputs]
            # we redo the assignment because the Op might be destructive,
            # in which case the inputs might not be correct anymore
            for input, v in zip(self.inputs, sample):
                setter(input, v)
            self.thunk('c')()
            results_c = [copy(output).extract() for output in self.outputs]
            assert cmp(results_py, results_c)
def perform_linker(env, target = None):
    """
    Returns a callable that executes every op of env in topological order
    through its unchecked _perform method. Serializing the thunks to a
    target file is not supported.
    """
    steps = [op._perform for op in env.toposort()]
    def run_all():
        for step in steps:
            step()
    if target:
        raise NotImplementedError("Cannot write thunk representation to a file.")
    return run_all
def cthunk_linker(env):
    """
    Returns a callable executing env's ops in topological order, batching
    consecutive C thunks into a single cthunk loop and falling back to
    op.perform for ops without a cthunk.
    """
    order = env.toposort()
    thunks = []
    cstreak = []
    def append_cstreak():
        # BUG FIX: this used to rebind `cstreak = []`, which made the name
        # local to this function and raised UnboundLocalError on the very
        # first call. Clear the shared list in place instead.
        if cstreak:
            thunks.append(cutils.create_cthunk_loop(*cstreak))
            del cstreak[:]
    def ret():
        for thunk in thunks:
            thunk()
    for op in order:
        if hasattr(op, 'cthunk'):
            cstreak.append(op.cthunk())
        else:
            append_cstreak()
            thunks.append(op.perform)
    # BUG FIX: a trailing streak of C thunks was silently dropped.
    append_cstreak()
    if len(thunks) == 1:
        return thunks[0]
    else:
        return ret
from copy import copy
import graph
from env import Env, EnvListener
class PrintListener(EnvListener):
def __init__(self, env, active = True):
self.env = env
self.active = active
if active:
print "-- initializing"
def on_import(self, op):
if self.active:
print "-- importing: %s" % graph.as_string(self.env.inputs, op.outputs)
def on_prune(self, op):
if self.active:
print "-- pruning: %s" % graph.as_string(self.env.inputs, op.outputs)
def on_rewire(self, clients, v, new_v):
if self.active:
if v.owner is None:
vg = v.name
else:
vg = graph.as_string(self.env.inputs, v.owner.outputs)
if new_v.owner is None:
new_vg = new_v.name
else:
new_vg = graph.as_string(self.env.inputs, new_v.owner.outputs)
print "-- moving from %s to %s" % (vg, new_vg)
class ChangeListener(EnvListener):
    """Boolean flag that becomes True whenever the env changes in any way."""
    def __init__(self, env):
        # Store env for consistency with the other EnvListener subclasses
        # (the original discarded it).
        self.env = env
        self.change = False
    def on_import(self, op):
        self.change = True
    def on_prune(self, op):
        self.change = True
    def on_replace(self, v, new_v):
        self.change = True
    def __call__(self, value = "get"):
        """With no argument, return the flag; otherwise set it to value."""
        if value == "get":
            return self.change
        else:
            self.change = value
class InstanceFinder(EnvListener, dict):
    """Dict mapping each Op subclass to the set of its instances in the env."""
    def __init__(self, env):
        self.env = env
    def all_bases(self, cls):
        """Return cls and all of its ancestors that are Op subclasses."""
        # BUG FIX: `set(cls)` tried to iterate the class object itself and
        # `rval.add(...)` inserted an unhashable list; seed the set with
        # [cls] and merge the recursive results with update().
        rval = set([cls])
        for base in cls.__bases__:
            rval.update(self.all_bases(base))
        # NOTE(review): `Op` is not imported in this module's visible
        # header; confirm where it is expected to come from.
        return [c for c in rval if issubclass(c, Op)]
    def on_import(self, op):
        for base in self.all_bases(op.__class__):
            self.setdefault(base, set()).add(op)
    def on_prune(self, op):
        for base in self.all_bases(op.__class__):
            self[base].remove(op)
            if not self[base]:
                del self[base]
    def __query__(self, cls):
        all = [x for x in self.get(cls, [])]
        while all:
            next = all.pop()
            # only yield ops that are still part of the env
            if next in self.env.ops():
                yield next
    def query(self, cls):
        return self.__query__(cls)
# class GraphOrder(EnvListener, dict):
# def init(self, graph):
# self.graph = graph
# def __adjust__(self, op, minimum):
# if not op or self[op] >= minimum:
# return
# self[op] = minimum
# for output in op.outputs:
# for client, i in output.clients:
# self.__adjust__(client, minimum + 1)
# def on_import(self, op):
# order = 0
# for input in op.inputs:
# if input not in self.graph.inputs:
# order = max(order, self[input.owner] + 1)
# self[op] = order
# def on_prune(self, op):
# del self[op]
# def on_replace(self, v, new_v):
# self.__adjust__(new_v.owner, self.get(v.owner, 0))
class SuperFinder(EnvListener, dict):
    """
    Indexes the env's ops by arbitrary (property, value) pairs supplied by
    the `props` callback, and answers template queries, optionally in a
    graph-order-dependent sequence.
    """
    def __init__(self, env, props):
        self.env = env
        self.props = props
    def on_import(self, op):
        # Register op under each (property, value) pair it exposes.
        for prop, value in self.props(op).items():
            by_value = self.setdefault(prop, {})
            by_value.setdefault(value, set()).add(op)
    def on_prune(self, op):
        # Unregister op, pruning buckets that empty out.
        for prop, value in self.props(op).items():
            bucket = self[prop][value]
            bucket.remove(op)
            if not bucket:
                del self[prop][value]
            if not self[prop]:
                del self[prop]
    def __query__(self, order, template):
        candidates = []
        for prop, value in template.items():
            candidates.extend(self.get(prop, {}).get(value, set()))
        # If not None, the order option requires the order listener to be
        # included in the env under the name 'order'
        if order == 'o->i':
            candidates.sort(lambda op1, op2: self.env.order[op1].__cmp__(self.env.order[op2]))
        elif order == 'i->o':
            candidates.sort(lambda op1, op2: self.env.order[op2].__cmp__(self.env.order[op1]))
        while candidates:
            candidate = candidates.pop()
            # only yield ops that are still part of the env
            if candidate in self.env.ops():
                yield candidate
    def query(self, **template):
        return self.__query__(None, template)
    def query_downstream(self, **template):
        return self.__query__('i->o', template)
    def query_upstream(self, **template):
        return self.__query__('o->i', template)
# class DupListener(EnvListener):
# def __init__(self, env):
# self.to_cid = {}
# self.to_obj = {}
# self.env = env
# for i, input in enumerate(env.inputs):
# self.to_cid[input] = i
# self.to_obj[i] = input
# def init(self, env):
# self.env = env
# for i, input in enumerate(env.inputs):
# self.to_cid[input] = i
# self.to_obj[i] = input
# def on_import(self, op):
# cid = (op.__class__, tuple([self.to_cid[input] for input in op.inputs]))
# self.to_cid[op] = cid
# self.to_obj.setdefault(cid, op)
# for i, output in enumerate(op.outputs):
# ocid = (i, cid)
# self.to_cid[output] = ocid
# self.to_obj.setdefault(ocid, output)
# def on_prune(self, op):
# # we don't delete anything
# return
# def apply(self, env):
# if env is not self.env:
# raise Exception("The DupListener merge optimization can only apply to the env it is listening to.")
# def fn(op):
# op2 = self.to_obj[self.to_cid[op]]
# if op is not op2:
# return [(o, o2) for o, o2 in zip(op.outputs, op2.outputs)]
# env.walk_from_outputs(fn)
# def __call__(self):
# self.apply(self.env)
class DestroyHandler(EnvListener):
    """
    Tracks view and destroy relationships between Results in the env:
    `parent`/`children` encode view chains, `destroyers` maps each storage
    foundation to the ops that overwrite it, `cycles`/`dups` record
    inconsistencies, and orderings() derives the read-before-destroy
    scheduling constraints.
    """
    def __init__(self, env):
        self.parent = {}       # view -> Result it is a view of (or None)
        self.children = {}     # Result -> set of views onto it
        self.destroyers = {}   # foundation -> {op: view path it destroys through}
        self.paths = {}        # memoized foundation->result view paths
        self.dups = set()      # foundations destroyed more than once
        self.cycles = set()    # detected dependency cycles
        self.env = env
        for input in env.inputs:
            self.parent[input] = None
            self.children[input] = set()
    def __path__(self, r):
        """Return the view chain from r's storage foundation down to r, memoized."""
        path = self.paths.get(r, None)
        if path:
            return path
        rval = [r]
        r = self.parent[r]
        while r:
            rval.append(r)
            r = self.parent[r]
        rval.reverse()
        # cache every prefix of the chain as well
        for i, x in enumerate(rval):
            self.paths[x] = rval[0:i+1]
        return rval
    def __views__(self, r):
        """Return r and every Result that transitively views r."""
        children = self.children[r]
        if not children:
            return set([r])
        else:
            rval = set([r])
            for child in children:
                rval.update(self.__views__(child))
            return rval
    def __users__(self, r):
        """Return the outputs of every op that reads r or any view of r."""
        views = self.__views__(r)
        rval = set()
        for view in views:
            for op, i in self.env.clients(view):
                rval.update(op.outputs)
        return rval
    def __pre__(self, op):
        """Results that must be computed before op, given pending destroys."""
        rval = set()
        if op is None:
            return rval
        keep_going = False
        for input in op.inputs:
            foundation = self.__path__(input)[0]
            destroyers = self.destroyers.get(foundation, set())
            if destroyers:
                keep_going = True
                if op in destroyers:
                    # every reader of the foundation must run before op destroys it
                    users = self.__users__(foundation)
                    rval.update(users)
        if not keep_going:
            return set()
        rval.update(op.inputs)
        rval.difference_update(op.outputs)
        return rval
    def __detect_cycles_helper__(self, r, seq):
        if r in seq:
            self.cycles.add(tuple(seq[seq.index(r):]))
            return
        pre = self.__pre__(r.owner)
        for r2 in pre:
            self.__detect_cycles_helper__(r2, seq + [r])
    def __detect_cycles__(self, start, just_remove=False):
        users = self.__users__(start)
        users.add(start)
        # drop all cycles that touch the affected results, then re-detect
        for user in users:
            for cycle in copy(self.cycles):
                if user in cycle:
                    self.cycles.remove(cycle)
        if just_remove:
            return
        for user in users:
            self.__detect_cycles_helper__(user, [])
    def get_maps(self, op):
        # BUG FIX: the fallbacks used to be `lambda x: {}` but were called
        # with no argument, raising TypeError for ops without the methods.
        return getattr(op, 'view_map', lambda: {})(), \
            getattr(op, 'destroy_map', lambda: {})()
    def on_import(self, op):
        view_map, destroy_map = self.get_maps(op)
        for input in op.inputs:
            self.parent.setdefault(input, None)
        for i, output in enumerate(op.outputs):
            views = view_map.get(output, None)
            destroyed = destroy_map.get(output, None)
            if destroyed:
                self.parent[output] = None
                for input in destroyed:
                    path = self.__path__(input)
                    self.__add_destroyer__(path + [output])
            elif views:
                # BUG FIX: this branch referenced an undefined name
                # `inputs`; the list of viewed inputs is `views`.
                if len(views) > 1:
                    raise Exception("Output is a view of too many inputs.")
                self.parent[output] = views[0]
                for input in views:
                    self.children[input].add(output)
            else:
                self.parent[output] = None
                self.children[output] = set()
                # NOTE(review): destroyed/view outputs get no children
                # entry here, yet __views__ reads self.children[r]
                # unconditionally -- confirm whether the assignment should
                # apply to every output.
        for output in op.outputs:
            self.__detect_cycles__(output)
#         if destroy_map:
#             print "op: ", op
#             print "ord: ", [str(x) for x in self.orderings()[op]]
#             print
    def on_prune(self, op):
        view_map, destroy_map = self.get_maps(op)
        if destroy_map:
            destroyers = []
            for i, input in enumerate(op.inputs):
                destroyers.append(self.destroyers.get(self.__path__(input)[0], {}))
            for destroyer in destroyers:
                path = destroyer.get(op, [])
                if path:
                    self.__remove_destroyer__(path)
        if view_map:
            for i, input in enumerate(op.inputs):
                self.children[input].difference_update(op.outputs)
        for output in op.outputs:
            try:
                del self.paths[output]
            except KeyError:  # was a bare except: only the del can fail
                pass
            self.__detect_cycles__(output, True)
        for i, output in enumerate(op.outputs):
            del self.parent[output]
            del self.children[output]
    def __add_destroyer__(self, path):
        """Record that path[-1].owner destroys the foundation path[0]."""
        foundation = path[0]
        target = path[-1]
        op = target.owner
        destroyers = self.destroyers.setdefault(foundation, {})
        path = destroyers.setdefault(op, path)
        if len(destroyers) > 1:
            self.dups.add(foundation)
    def __remove_destroyer__(self, path):
        "Undoes __add_destroyer__(path)."
        foundation = path[0]
        target = path[-1]
        op = target.owner
        destroyers = self.destroyers[foundation]
        del destroyers[op]
        if not destroyers:
            del self.destroyers[foundation]
        elif len(destroyers) == 1 and foundation in self.dups:
            self.dups.remove(foundation)
    def on_rewire(self, clients, r_1, r_2):
        path_1 = self.__path__(r_1)
        path_2 = self.__path__(r_2)
        prev = set()
        for op, i in clients:
            prev.update(op.outputs)
        foundation = path_1[0]
        destroyers = self.destroyers.get(foundation, {}).items()
        # re-root any destroy path that went through r_1 onto r_2's path
        for op, path in destroyers:
            if r_1 in path:
                idx = path.index(r_1)
                self.__remove_destroyer__(path)
                if not (idx > 0 and path[idx - 1] in prev):
                    continue
                index = path.index(r_1)
                new_path = path_2 + path[index+1:]
                self.__add_destroyer__(new_path)
        # move views of r_1 held by the rewired clients over to r_2
        for op, i in clients:
            view_map, _ = self.get_maps(op)
            for output, inputs in view_map.items():
                if r_1 in inputs:
                    assert self.parent[output] == r_1
                    self.parent[output] = r_2
                    self.children[r_1].remove(output)
                    self.children[r_2].add(output)
        # invalidate memoized paths on both sides before re-detecting cycles
        for view in self.__views__(r_1):
            try:
                del self.paths[view]
            except KeyError:  # was a bare except: only the del can fail
                pass
        for view in self.__views__(r_2):
            try:
                del self.paths[view]
            except KeyError:  # was a bare except: only the del can fail
                pass
        self.__detect_cycles__(r_1)
        self.__detect_cycles__(r_2)
    def validate(self):
        # NOTE(review): InconsistencyError is not imported in this module --
        # presumably env.InconsistencyError; confirm the intended source.
        if self.dups:
            raise InconsistencyError("The following values are destroyed more than once: %s" % self.dups)
        elif self.cycles:
            raise InconsistencyError("There are cycles: %s" % self.cycles)
        else:
            return True
    def orderings(self):
        """{destroyer op: ops that must run before it} for the env scheduler."""
        ords = {}
        for foundation, destroyers in self.destroyers.items():
            for op in destroyers.keys():
                ords.setdefault(op, set()).update([user.owner for user in self.__users__(foundation) if user not in op.outputs])
        return ords
"""
Contains the Op class, which is the base interface for all operations
compatible with gof's graph manipulation routines.
"""
from result import Result, BrokenLink
from utils import ClsInit, all_bases, all_bases_collect
from copy import copy
__all__ = ['Op']
class Op(object):
"""
Op represents a computation on the storage in its 'inputs' slot,
the results of which are stored in the Result instances in the
'outputs' slot. The owner of each Result in the outputs list must
be set to this Op and thus any Result instance is in the outputs
list of at most one Op, its owner. It is the responsibility of the
Op to ensure that it owns its outputs and it is encouraged (though
not required) that it creates them.
After construction, self.inputs and self.outputs should only be
modified through the set_input and set_output methods.
"""
__slots__ = ['_inputs', '_outputs']
__require__ = []
inputs = property(lambda self: self._inputs, doc = "The list of this Op's input Results.")
outputs = property(lambda self: self._outputs, doc = "The list of this Op's output Results.")
"""
If true, self.default_output() or self.out can be used to access
self.outputs[0]
"""
has_default_output = True
out = property(lambda self: self.default_output(), doc = "Same as self.outputs[0] if this Op's has_default_output field is True.")
def __init__(self, inputs, outputs, use_self_setters = False):
"""
Initializes the '_inputs' and '_outputs' slots and sets the
owner of all outputs to self.
If use_self_setters is False, Op::set_input and Op::set_output
are used, which do the minimum checks and manipulations. Else,
the user defined set_input and set_output functions are
called (in any case, all inputs and outputs are initialized
to None).
"""
self._inputs = [None] * len(inputs)
self._outputs = [None] * len(outputs)
if use_self_setters:
for i, input in enumerate(inputs):
self.set_input(i, input, validate = False)
for i, output in enumerate(outputs):
self.set_output(i, output, validate = False)
self.validate()
else:
for i, input in enumerate(inputs):
Op.set_input(self, i, input, validate = False)
for i, output in enumerate(outputs):
Op.set_output(self, i, output, validate = False)
self.validate()
self.validate()
def default_output(self):
"""
Returns the default output of this Op instance, typically self.outputs[0].
"""
if self.has_default_output:
return self.outputs[0]
else:
raise AttributeError("Op does not have a default output.")
def set_input(self, i, input, allow_changes = False, validate = True):
"""
Sets the ith input of self.inputs to input. i must be an
integer in the range from 0 to len(self.inputs) - 1 and input
must be a Result instance. The method may raise a GofTypeError
or a GofValueError accordingly to the semantics of the Op, if
the new input is of the wrong type or has the wrong
properties.
If i > len(self.inputs), an IndexError must be raised. If i ==
len(self.inputs), it is allowed for the Op to extend the list
of inputs if it is a vararg Op, else an IndexError should be
raised.
For a vararg Op, it is also allowed to have the input
parameter set to None for 0 <= i < len(self.inputs), in which
case the rest of the inputs will be shifted left. In any other
situation, a ValueError should be raised.
In some cases, set_input may change some outputs: for example,
a change of an input from float to double might require the
output's type to also change from float to double. If
allow_changes is True, set_input is allowed to perform those
changes and must return a list of pairs, each pair containing
the old output and the output it was replaced with (they
_must_ be different Result instances). See Op::set_output for
important information about replacing outputs. If
allow_changes is False and some change in the outputs is
required for the change in input to be correct, a
PropagationError must be raised.
This default implementation sets the ith input to input and
changes no outputs. It returns None.
"""
previous = self.inputs[i]
self.inputs[i] = input
if validate:
try:
self.validate()
except:
self.set_input(i, previous, True, False)
def set_output(self, i, output, validate = True):
"""
Sets the ith output to output. The previous output, which is
being replaced, must be invalidated using Result::invalidate.
The new output must not already have an owner, or its owner must
be self. It cannot be a broken link, unless it used to be at this
spot, in which case it can be reinstated.
For Ops that have vararg output lists, see the regulations in
Op::set_input.
"""
if isinstance(output.owner, BrokenLink) \
and output.owner.owner is self \
and output.owner.index == i:
output.revalidate()
else:
output.set_owner(self, i) # this checks for an already existing owner
previous = self.outputs[i]
if previous:
previous.invalidate()
self.outputs[i] = output
if validate:
try:
self.validate()
except:
self.set_output(i, previous, False)
    def refresh(self, allow_changes = False):
        """
        This function attempts to repair all inputs that are broken
        links by calling set_input on the new Result that replaced
        them. Note that if a set_input operation invalidates one or
        more outputs, new broken links might appear in the other ops
        that use this op's outputs.

        It is possible that the new inputs are inconsistent with this
        op, in which case an exception will be raised and the previous
        inputs (and outputs) will be restored.

        refresh returns a list of (old_output, new_output) pairs
        detailing the changes, if any.
        """
        # (index, old_input, dirt) for every repaired input, recorded so a
        # later failure can undo everything already done.
        backtrack = []
        try:
            for i, input in enumerate(self.inputs):
                link = input.owner
                if isinstance(link, BrokenLink):
                    # The replacement Result sits at the broken link's
                    # remembered (owner, index) slot.
                    current = link.owner.outputs[link.index]
                    dirt = self.set_input(i, current, allow_changes)
                    backtrack.append((i, input, dirt))
        except:
            # Restore the inputs and outputs that were successfully changed.
            for i, input, dirt in backtrack:
                self.inputs[i] = input
                if dirt:
                    # Undo each (old, new) output replacement reported
                    # by set_input.
                    for old, new in dirt:
                        new.invalidate()
                        old.revalidate()
                        self.outputs[self.outputs.index(new)] = old
            raise
        # Flatten the per-input change lists into the single list of
        # (old_output, new_output) pairs promised by the docstring.
        all_dirt = []
        for i, input, dirt in backtrack:
            if dirt:
                all_dirt += dirt
        return all_dirt
def perform(self):
"""
Performs the computation on the inputs and stores the results
in the outputs. This function should check for the validity of
the inputs and raise appropriate errors for debugging (for
executing without checks, override _perform).
An Op may define additional ways to perform the computation
that are more efficient (e.g. a piece of C code or a C struct
with direct references to the inputs and outputs), but
perform() should always be available in order to have a
consistent interface to execute graphs.
"""
raise NotImplementedError
def _perform(self):
"""
Performs the computation on the inputs and stores the results
in the outputs, like perform(), but is not required to check
the existence or the validity of the inputs.
"""
return self.perform()
@classmethod
def require(cls):
"""
Returns a set of Feature subclasses that must be used by any
Env manipulating this kind of op. For instance, a Destroyer
requires ext.DestroyHandler to guarantee that various
destructive operations don't interfere.
By default, this collates the __require__ field of this class
and the __require__ fields of all classes that are directly or
indirectly superclasses to this class into a set.
"""
r = set()
bases = all_bases(cls, lambda cls: hasattr(cls, '__require__'))
bases.append(cls)
for base in bases:
req = base.__require__
if isinstance(req, (list, tuple)):
r.update(req)
else:
r.add(req)
return r
def validate(self):
"""
This class's __validate__ function will be called, as well as
the __validate__ functions of all base classes down the class
tree. If you do not want to execute __validate__ from the base
classes, set the class variable __validate_override__ to True.
"""
vfns = all_bases_collect(self.__class__, 'validate')
for vfn in vfns:
vfn(self)
    def __copy__(self):
        """
        Copies the inputs list shallowly and copies all the outputs
        because of the one owner per output restriction.
        """
        new_inputs = copy(self.inputs)
        # We copy the outputs because they are tied to a single Op.
        new_outputs = [copy(output) for output in self.outputs]
        # NOTE(review): the subclass constructor may build its own
        # inputs/outputs; the slots are overwritten just below so the
        # copies made here win regardless.
        op = self.__class__(new_inputs, new_outputs)
        op._inputs = new_inputs
        op._outputs = new_outputs
        for i, output in enumerate(op.outputs):
            # We adjust _owner and _index manually since the copies
            # point to the previous op (self).
            output._owner = op
            output._index = i
        return op
def __deepcopy__(self, memo):
"""
Not implemented. Use gof.graph.clone(inputs, outputs) to copy
a subgraph.
"""
raise NotImplementedError("Use gof.graph.clone(inputs, outputs) to copy a subgraph.")
from op import Op
from env import InconsistencyError
import utils
import unify
import features
import ext
class Optimizer:
    """
    Base class for graph optimizations. Subclasses override apply();
    optimize() first makes the env satisfy this optimizer's feature
    requirements, then applies it.
    """

    # Feature classes an Env must use before this optimizer may run.
    __require__ = ()

    def apply(self, env):
        """Transform env in place. The base implementation is a no-op."""
        pass

    def optimize(self, env):
        """Equip env with the required features, then run apply()."""
        env.satisfy(self)
        self.apply(env)

    @classmethod
    def require(cls):
        """
        Returns a list of EnvFeature subclasses that must be used by
        any Env manipulating this kind of op. For instance, a
        Destroyer requires features.DestroyHandler to guarantee that
        various destructive operations don't interfere.
        """
        needed = set()
        for klass in utils.all_bases(cls, lambda c: hasattr(c, '__require__')) + [cls]:
            entry = klass.__require__
            if isinstance(entry, (list, tuple)):
                needed.update(entry)
            else:
                needed.add(entry)
        return needed

    def __call__(self, env):
        """Calling an optimizer is shorthand for optimize(env)."""
        self.optimize(env)

DummyOpt = Optimizer()
class SeqOptimizer(Optimizer, list):
    """An optimizer that runs a sequence of optimizers in list order."""

    def apply(self, env):
        # Run each contained optimizer's full optimize() pass in turn.
        for member in self:
            member.optimize(env)

    def __str__(self):
        return "SeqOpt(%s)" % list.__str__(self)

    def __repr__(self):
        return "SeqOpt(%s)" % list.__repr__(self)
class LocalOptimizer(Optimizer):
    """
    Optimizer that visits ops one at a time. Subclasses override
    apply_on_op (and optionally candidates) instead of apply.
    """

    def candidates(self, env):
        """Ops to consider; defaults to every op in the env."""
        return env.ops()

    def apply_on_op(self, env, op):
        """Transform a single op. Subclasses must override this."""
        raise Exception("Please override this function.")

    def apply(self, env):
        # An earlier transformation may have removed a candidate from
        # the env, so membership is re-checked before visiting it.
        for candidate in self.candidates(env):
            if env.has_op(candidate):
                self.apply_on_op(env, candidate)
class OpSpecificOptimizer(LocalOptimizer):
    """LocalOptimizer whose candidates are the instances of one Op class."""
    # InstanceFinder indexes the env's ops by class, providing
    # get_instances_of.
    __require__ = features.InstanceFinder
    # Subclasses set this to the Op class they operate on.
    opclass = Op
    def candidates(self, env):
        return env.get_instances_of(self.opclass)
class OpSubOptimizer(Optimizer):
__require__ = features.InstanceFinder
def __init__(self, op1, op2):
if not op1.has_default_output:
raise TypeError("OpSubOptimizer must be used with Op instances that have a default output.")
# note: op2 must have the same input signature as op1
self.op1 = op1
self.op2 = op2
def apply(self, env):
candidates = env.get_instances_of(self.op1)
for op in candidates:
try:
# note: only replaces the default 'out' port if it exists
r = self.op2(*op.inputs)
if isinstance(r, Op):
r = r.out
env.replace(op.out, r)
except InconsistencyError, e:
print "Warning: OpSubOpt failed to transform %s into %s: %s" % (op, self.op2, str(e)) # warning is for debug
pass
class OpRemover(Optimizer):
__require__ = features.InstanceFinder
def __init__(self, opclass):
self.opclass = opclass
def apply(self, env):
candidates = env.get_instances_of(self.opclass)
for op in candidates:
try:
assert len(op.inputs) == len(op.outputs)
for input, output in zip(op.inputs, op.outputs):
env.replace(output, input)
except InconsistencyError, e:
print "Warning: OpRemover failed to remove %s: %s" % (op, str(e)) # warning is for debug
pass
class PatternOptimizer(OpSpecificOptimizer):
    """
    Replaces all occurrences of the first pattern by the second pattern.
    """
    def __init__(self, in_pattern, out_pattern):
        # Patterns are nested (OpClass, arg_pattern, ...) sequences;
        # strings act as unification variables and anything else is
        # matched by equality.
        self.in_pattern = in_pattern
        self.out_pattern = out_pattern
        # The root Op class of in_pattern drives candidate selection
        # (via OpSpecificOptimizer.candidates).
        self.opclass = self.in_pattern[0]
        self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n"
    def apply_on_op(self, env, op):
        # Recursively matches 'pattern' against 'expr' (a Result),
        # threading the unification state u; returns the updated state
        # or False on mismatch.
        def match(pattern, expr, u, first = False):
            if isinstance(pattern, (list, tuple)):
                # Interior ops with more than one client are left alone,
                # since rewriting them would affect other consumers.
                if not issubclass(expr.owner.__class__, pattern[0]) or (not first and env.nclients(expr.owner) > 1):
                    return False
                if len(pattern) - 1 != len(expr.owner.inputs):
                    return False
                for p, v in zip(pattern[1:], expr.owner.inputs):
                    u = match(p, v, u)
                    if not u:
                        return False
            elif isinstance(pattern, str):
                # A variable must either be fresh or already bound to
                # this very expr.
                v = unify.Var(pattern)
                if u[v] is not v and u[v] is not expr:
                    return False
                else:
                    u = u.merge(expr, v)
            else:
                # Literal: must compare equal to the expression.
                if pattern != expr:
                    return False
                return u
            return u
        # Instantiates out_pattern under the unification u.
        def build(pattern, u):
            if isinstance(pattern, (list, tuple)):
                return pattern[0](*[build(p, u) for p in pattern[1:]])
            elif isinstance(pattern, str):
                return u[unify.Var(pattern)]
            else:
                return pattern
        u = match(self.in_pattern, op.out, unify.Unification(), True)
        if u:
            try:
                # note: only replaces the default 'out' port if it exists
                new = build(self.out_pattern, u)
                if isinstance(new, Op):
                    new = new.out
                env.replace(op.out, new)
            except InconsistencyError, e:
                print "Warning: '%s' failed to apply on %s: %s" % (self, op, str(e)) # warning is for debug
                pass
    def __str__(self):
        # Renders a pattern as OpName(arg, ...) for the docstring/repr.
        def pattern_to_str(pattern):
            if isinstance(pattern, (list, tuple)):
                return "%s(%s)" % (pattern[0].__name__, ", ".join([pattern_to_str(p) for p in pattern[1:]]))
            else:
                return str(pattern)
        return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern))
class MergeOptimizer(Optimizer):
    """
    Collapses structurally identical computations: two ops of the same
    class applied to inputs with the same canonical ids are merged,
    the later op's outputs being replaced by the first's.
    """
    def apply(self, env):
        # cid maps each result/op to a canonical id; inv_cid inverts it.
        cid = {}
        inv_cid = {}
        # Graph inputs and orphans get their enumeration index as id.
        for i, r in enumerate(env.inputs.union(env.orphans())):
            cid[r] = i
            inv_cid[i] = r
        for op in env.io_toposort():
            # An op's signature is its class plus the canonical ids of
            # its inputs.
            op_cid = (op.__class__, tuple([cid[input] for input in op.inputs]))
            dup = inv_cid.get(op_cid, None)
            if dup is None:
                # First op with this signature: register it and give each
                # of its outputs a canonical id of (position, signature).
                cid[op] = op_cid
                inv_cid[op_cid] = op
                for i, output in enumerate(op.outputs):
                    ref = (i, op_cid)
                    cid[output] = ref
                    inv_cid[ref] = output
            else:
                # Duplicate: reroute its outputs to the first op's
                # corresponding outputs.
                for output, other_output in zip(op.outputs, dup.outputs):
                    #print "replacing: %s %s" % (repr(output.owner), repr(other_output.owner))
                    env.replace(output, other_output)
def MergeOptMerge(opt):
    """
    Sandwiches opt between two merge passes. The same MergeOptimizer
    instance is reused on both sides of the sequence.
    """
    merge_pass = MergeOptimizer()
    return SeqOptimizer([merge_pass, opt, merge_pass])
class MultiOptimizer(Optimizer):
    """
    Holds named optimizers with relative ordering constraints
    ('after'/'before') and applies them in a topologically sorted
    order computed lazily by utils.toposort.
    """

    def __init__(self, **opts):
        self._opts = []
        self.ord = {}
        self.name_to_opt = {}
        self.up_to_date = True
        # Bug fix: iterating the kwargs dict directly yields only the
        # names; items() is required to get (name, optimizer) pairs.
        for name, opt in opts.items():
            self.register(name, opt, after = [], before = [])

    def register(self, name, opt, **relative):
        """
        Registers opt under name. The keyword arguments 'after' and
        'before' (single items or lists) constrain its position in the
        application order. Re-registering a name raises.
        """
        self.name_to_opt[name] = opt
        after = relative.get('after', [])
        if not isinstance(after, (list, tuple)):
            after = [after]
        before = relative.get('before', [])
        if not isinstance(before, (list, tuple)):
            before = [before]
        # Invalidate the cached sorted list.
        self.up_to_date = False
        if name in self.ord:
            raise Exception("Cannot redefine optimization: '%s'" % name)
        self.ord[name] = set(after)
        for postreq in before:
            self.ord.setdefault(postreq, set()).add(name)

    def get_opts(self):
        # Recompute the sorted optimizer list only when registrations
        # changed since the last refresh.
        if not self.up_to_date:
            self.refresh()
        return self._opts

    def refresh(self):
        self._opts = [self.name_to_opt[name] for name in utils.toposort(self.ord)]
        self.up_to_date = True

    def apply(self, env):
        for opt in self.opts:
            opt.apply(env)

    opts = property(get_opts)
class TaggedMultiOptimizer(MultiOptimizer):
    """
    MultiOptimizer whose optimizers also carry a set of tags (always
    including their own name), enabling tag-based filtering.
    """

    def __init__(self, **opts):
        self.tags = {}
        MultiOptimizer.__init__(self, **opts)

    def register(self, name, opt, tags = (), **relative):
        """Registers opt with the given tags plus its own name as a tag."""
        # Default changed from the mutable [] to () (equivalent here,
        # and avoids the shared-mutable-default pitfall).
        tag_set = set(tags)
        tag_set.add(name)
        self.tags[opt] = tag_set
        MultiOptimizer.register(self, name, opt, **relative)

    def filter(self, whitelist, blacklist):
        """Opts with at least one whitelisted tag and no blacklisted tag."""
        return [opt for opt in self.opts
                if self.tags[opt].intersection(whitelist)
                and not self.tags[opt].intersection(blacklist)]

    def whitelist(self, *tags):
        """Opts having at least one of the given tags."""
        return [opt for opt in self.opts if self.tags[opt].intersection(tags)]

    def blacklist(self, *tags):
        """Opts having none of the given tags."""
        return [opt for opt in self.opts if not self.tags[opt].intersection(tags)]
class TagFilterMultiOptimizer(Optimizer):
    """
    Applies a tag-filtered subset of a TaggedMultiOptimizer. With a
    whitelist present, an optimizer runs iff it matches the whitelist
    and avoids the blacklist; with no whitelist, only the blacklist
    filters.
    """

    def __init__(self, all, whitelist = None, blacklist = None):
        self.all = all
        if whitelist is None:
            self.whitelist = None
        else:
            self.whitelist = set(whitelist)
        if blacklist is None:
            self.blacklist = set()
        else:
            self.blacklist = set(blacklist)

    def use_whitelist(self, use = True):
        """Turns whitelist mode on by materializing an empty whitelist."""
        if use and self.whitelist is None:
            self.whitelist = set()

    def allow(self, *tags):
        """Adds tags to the whitelist (if any) and drops them from the blacklist."""
        if self.whitelist is not None:
            self.whitelist.update(tags)
        self.blacklist.difference_update(tags)

    def deny(self, *tags):
        """Removes tags from the whitelist (if any) and blacklists them."""
        if self.whitelist is not None:
            self.whitelist.difference_update(tags)
        self.blacklist.update(tags)

    def dont_care(self, *tags):
        """Forgets the tags entirely: neither whitelisted nor blacklisted."""
        if self.whitelist is not None:
            self.whitelist.difference_update(tags)
        self.blacklist.difference_update(tags)

    def opts(self):
        """The optimizers selected by the current white/black lists."""
        if self.whitelist is None:
            return self.all.blacklist(*[tag for tag in self.blacklist])
        return self.all.filter(self.whitelist, self.blacklist)

    def apply(self, env):
        for opt in self.opts():
            opt.apply(env)
# import compile
import env
import link
from features import EquivTool
class Prog:
    """
    Compiled program: wraps (inputs, outputs) in an Env, optimizes it,
    and links it into a callable 'perform' thunk.
    """
    def __init__(self, inputs, outputs, optimizer, linker, features = []):
        self.optimizer = optimizer
        self.linker = linker
        features = set(features)
        # EquivTool tracks the mapping from the caller's Results to
        # their optimized equivalents (used by equiv/__getitem__).
        features.add(EquivTool)
        self.env = env.Env(inputs, outputs, features) #, False)
        self.optimizer.optimize(self.env)
        # The linker turns the optimized env into the executable thunk.
        self.perform = self.linker(self.env)
        self.outputs = outputs
#     def __optimize__(self):
#         self.optimizer.apply(self.env)
#         self.order = self.env.toposort()
    def equiv(self, r):
        # The optimized Result equivalent to the user's Result r.
        return self.env.equiv(r)
    def __getitem__(self, r):
        return self.equiv(r)
    def __setitem__(self, r, value):
        # A tuple of results may be assigned a matching tuple of values.
        if isinstance(r, tuple):
            for a, b in zip(r, value):
                self.__setitem__(a, b)
        else:
            self.equiv(r).set_value(value)
    def __call__(self, *args):
        # NOTE(review): *args is accepted but ignored here; inputs are
        # presumably set beforehand via __setitem__ — confirm.
        self.perform()
        # Copy the computed values back onto the caller's output Results.
        for output in self.outputs:
            output.set_value(self[output])
        return self.outputs
#         return [output for output in self.env.outputs]
#         if args:
#             for input, arg in zip(self.inputs, args):
#                 if arg is not None:
#                     input.value = arg
#         for thunk, op in zip(self.thunks, self.order):
#             try:
#                 thunk()
#             except Exception, e:
#                 raise e.__class__("Error in " + str(op) + ": " + str(e))
#         return [output.value for output in self.outputs]
"""
Contains the Result class, which is the base interface for a
value that is the input or the output of an Op.
"""
from err import GofError
__all__ = ['Result', 'BrokenLink', 'BrokenLinkError']
class BrokenLink:
    """
    Stands in as the owner of a Result that was replaced by another
    Result. It remembers the original owner and output index so the
    link can be reinstated, and tests False (py2 __nonzero__) so code
    can write 'if result.owner' to detect live links.
    """
    __slots__ = ['owner', 'index']

    def __init__(self, owner, index):
        # The Op that used to own the Result, and the output slot it held.
        self.owner = owner
        self.index = index

    def __nonzero__(self):
        # py2 truth protocol: a BrokenLink is always falsy.
        return False
class BrokenLinkError(GofError):
    """
    Error involving a BrokenLink owner. (No raise site is visible in
    this chunk; presumably raised when a broken link is used where a
    live one is required.)
    """
    pass
############################
# Result
############################
class Result(object):
    """
    The Result class represents a datum for use in a graph of Ops. It
    has two slots:
    - owner: represents the Op which computes this Result. It is
             assumed to be an instance of Op. If owner raises an
             AttributeError, the Result is assumed to be an input.
    - index: the index this Result holds in its owner's outputs.

    Result has no __init__ or __new__ routine. It is the Op's
    responsibility to set the owner field of its results.

    The Result class is abstract. It must be subclassed to support the
    types of data needed for computation.

    A Result instance should be immutable: indeed, if some aspect of a
    Result is changed, operations that use it might suddenly become
    invalid. Instead, a new Result instance should be instanciated
    with the correct properties and the invalidate method should be
    called on the Result which is replaced (this will make its owner a
    BrokenLink instance, which behaves like False in conditional
    expressions).
    """
    __slots__ = ['_owner', '_index']
    def get_owner(self):
        # Lazy initialization: the slot starts unset, so default it to
        # None (meaning: this Result is a graph input).
        if not hasattr(self, '_owner'):
            self._owner = None
        return self._owner
    owner = property(get_owner, doc = "The Op of which this Result is an output or None if there is no such Op.")
    def set_owner(self, owner, index):
        """
        Binds this Result to owner as its index-th output. Raises
        ValueError if the Result already belongs to a different owner,
        or to a different output slot of the same owner.
        """
        if self.owner is not None:
            if self.owner is not owner:
                raise ValueError("Result %s already has an owner." % self)
            # Bug fix: this previously read self.index, which does not
            # exist (the slot is _index and no 'index' property is
            # defined), so an AttributeError was raised instead of the
            # intended ValueError.
            elif self._index != index:
                raise ValueError("Result %s was already mapped to a different index." % self)
        self._owner = owner
        self._index = index
    def invalidate(self):
        """
        Marks this Result as replaced: its owner becomes a BrokenLink
        that remembers the original (owner, index) pair. Invalidating
        twice is a no-op; invalidating an ownerless Result raises.
        """
        if self.owner is None:
            raise Exception("Cannot invalidate a Result instance with no owner.")
        elif not isinstance(self.owner, BrokenLink):
            self._owner = BrokenLink(self._owner, self._index)
            del self._index
    def revalidate(self):
        """Undoes invalidate(): restores the original owner and index."""
        if isinstance(self.owner, BrokenLink):
            owner, index = self._owner.owner, self._owner.index
            self._owner = owner
            self._index = index
    def set_value(self, value):
        """
        Copies the provided value in this Result. It is not required to
        implement this method.
        """
        raise NotImplementedError("This Result does not support set_value.")
# def extract(self):
# """
# Returns a representation of this datum for use in Op.impl.
# Successive calls to extract should always return the same object.
# """
# raise NotImplementedError
# def sync(self):
# """
# After calling Op.impl, synchronizes the Result instance with the
# new contents of the storage. This might usually not be necessary.
# """
# raise NotImplementedError
# def c_libs(self):
# """
# Returns a list of libraries that must be included to work with
# this Result.
# """
# raise NotImplementedError
# def c_imports(self):
# """
# Returns a list of strings representing headers to import when
# building a C interface that uses this Result.
# """
# raise NotImplementedError
# def c_declare(self):
# """
# Returns code which declares and initializes a C variable in
# which this Result can be held.
# """
# raise NotImplementedError
# def pyo_to_c(self):
# raise NotImplementedError
# def c_to_pyo(self):
# raise NotImplementedError
############################
# Utilities
############################
# class SelfContainedResult(Result):
# """
# This represents a Result which acts as its own data container. It
# is recommended to subclass this if you wish to be able to use the
# Result in normal computations as well as working with a graph
# representation.
# """
# # def extract(self):
# # """Returns self."""
# # return self
# # def sync(self):
# # """Does nothing."""
# # pass
# class HolderResult(Result):
# """
# HolderResult adds a 'data' slot which is meant to contain the
# object used by the Op implementation. It is recommended to subclass
# this if you want to be able to use the exact same object at
# different points in a computation.
# """
# __slots__ = ['data']
# # def extract(self):
# # """Returns self.data."""
# # return self.data
# # def sync(self):
# # """
# # Does nothing. Override if you have additional fields or
# # functionality in your subclass which need to be computed from
# # the data.
# # """
# # pass
from copy import copy
from utils import *
################################
class Variable:
    """
    Serves as a base class of variables for the purpose of unification.
    Behavior for unifying various types of variables should be added as
    overloadings of the 'unify' function.
    """

    def __init__(self, name = "?"):
        self.name = name

    def __str__(self):
        fields = ", ".join(["%s=%s" % pair for pair in self.__dict__.items()])
        return self.__class__.__name__ + "(" + fields + ")"

    def __repr__(self):
        return str(self)
class FreeVariable(Variable):
    """
    This Variable can take any value.
    """
    # No extra state: behaves exactly like Variable; its type selects
    # the unconstrained unify_walk rules registered below.
    pass
class BoundVariable(Variable):
    """
    A Variable fixed to exactly one value, accessible through the
    'value' field.
    """
    def __init__(self, name, value):
        self.name = name
        self.value = value
class OrVariable(Variable):
    """
    A Variable restricted to a finite list of admissible values,
    accessible through the 'options' field.
    """
    def __init__(self, name, options):
        self.name = name
        self.options = options
class NotVariable(Variable):
    """
    A Variable that may take any value except a finite list of
    forbidden ones, accessible through the 'not_options' field.
    """
    def __init__(self, name, not_options):
        self.name = name
        self.not_options = not_options
class VariableInList: # not a subclass of Variable
    """
    Wraps an inner Variable meant to be matched against a list: the
    inner Variable is unified to an OrVariable over the list's values.
    E.g. unifying VariableInList(FreeVariable('x')) with [1, 2, 3]
    unifies 'x' to OrVariable('?', [1, 2, 3]).
    """
    def __init__(self, variable):
        # The wrapped Variable that will receive the OrVariable binding.
        self.variable = variable
################################
# Cache of interned variables, keyed on (constructor type, name).
_all = {}

def var_lookup(vartype, name, *args, **kwargs):
    """
    Memoized Variable constructor: returns the cached instance for
    (vartype, name) when one exists, otherwise builds
    vartype(name, *args, **kwargs) and caches it.

    Bug fix: **kwargs was previously accepted but silently dropped
    when calling the constructor.
    """
    sig = (vartype, name)
    if sig in _all:
        return _all[sig]
    else:
        v = vartype(name, *args, **kwargs)
        _all[sig] = v
        return v
# Memoized shorthand constructors: the same (type, name) pair always
# yields the same Variable instance, courtesy of var_lookup's cache.
Var = partial(var_lookup, FreeVariable)
V = Var
OrV = partial(var_lookup, OrVariable)
NV = partial(var_lookup, NotVariable)
################################
class Unification:
    """
    Represents one consistent way of unifying a group of variables
    with each other or with tangible values.
    """

    def __init__(self, inplace = False):
        """
        If inplace is False, the merge method will return a new Unification
        that is independant from the previous one (which allows backtracking).
        """
        # Maps var -> (best, pool): 'pool' is the set of all variables
        # unified together, 'best' the Variable summarizing the values
        # they may share.
        self.unif = {}
        self.inplace = inplace

    def merge(self, new_best, *vars):
        """
        Links all the specified vars together under the representative
        new_best. Returns self when inplace, otherwise an independent
        copy carrying the merged state.
        """
        if self.inplace:
            target = self
        else:
            # Copy the unification data so the original can be kept for
            # backtracking.
            target = Unification(self.inplace)
            for var, entry in self.unif.items():
                target.unif[var] = entry
        # The merged pool starts with the given variables plus new_best...
        merged_pool = set(vars)
        merged_pool.add(new_best)
        # ...and absorbs the pre-existing pool of each of its members.
        for var in set(merged_pool):
            best, pool = target.unif.get(var, (var, set()))
            merged_pool.update(pool)
        # Every variable in the merged pool now shares the same entry.
        for var in merged_pool:
            target.unif[var] = (new_best, merged_pool)
        return target

    def __getitem__(self, v):
        """
        For a variable v, returns the Variable representing the
        tightest known set of possible values (v itself when it is
        unconstrained).
        """
        return self.unif.get(v, (v, None))[0]
################################
def unify_walk(a, b, U):
    """
    unify_walk(a, b, U) returns an Unification where a and b are unified, given the
    unification that already exists in the Unification U. If the unification fails,
    it returns False.

    There are two ways to expand the functionality of unify_walk. The first way is:
    @comm_guard(type_of_a, type_of_b)
    def unify_walk(a, b, U):
        ...
    A function defined as such will be executed whenever the types of a and b
    match the declaration. Note that comm_guard automatically guarantees that
    your function is commutative: it will try to match the types of a, b or b, a.
    It is recommended to define unify_walk in that fashion for new types of Variable
    because different types of Variable interact a lot with each other, e.g.
    when unifying an OrVariable with a NotVariable, etc. You can return the
    special marker FALL_THROUGH to indicate that you want to relay execution
    to the next match of the type signature. The definitions of unify_walk are tried
    in the reverse order of their declaration.

    Another way is to override __unify_walk__ in an user-defined class.

    Limitations: cannot embed a Variable in another (the functionality could
    be added if required)

    Here is a list of unification rules with their associated behavior:
    """
    # Base rule: only identical values of the same class unify.
    if a.__class__ != b.__class__:
        return False
    if a == b:
        return U
    return False
# ---------------------------------------------------------------------
# unify_walk dispatch table. Each @comm_guard registration handles one
# commutative pair of types. Definitions are tried in REVERSE order of
# declaration (see the unify_walk docstring), so the generic rules at
# the bottom run first and FALL_THROUGH relays to the earlier, more
# specific ones. Do not reorder these definitions.
# ---------------------------------------------------------------------
@comm_guard(FreeVariable, ANY_TYPE)
def unify_walk(fv, o, U):
    """
    FreeV is unified to BoundVariable(other_object)
    """
    v = BoundVariable("?", o)
    return U.merge(v, fv)
@comm_guard(BoundVariable, ANY_TYPE)
def unify_walk(bv, o, U):
    """
    The unification succeed iff BV.value == other_object
    """
    if bv.value == o:
        return U
    else:
        return False
@comm_guard(OrVariable, ANY_TYPE)
def unify_walk(ov, o, U):
    """
    The unification succeeds iff other_object in OrV.options
    """
    if o in ov.options:
        v = BoundVariable("?", o)
        return U.merge(v, ov)
    else:
        return False
@comm_guard(NotVariable, ANY_TYPE)
def unify_walk(nv, o, U):
    """
    The unification succeeds iff other_object not in NV.not_options
    """
    if o in nv.not_options:
        return False
    else:
        v = BoundVariable("?", o)
        return U.merge(v, nv)
@comm_guard(FreeVariable, Variable)
def unify_walk(fv, v, U):
    """
    Both variables are unified.
    """
    # Use v's current best representative so fv joins the right pool.
    v = U[v]
    return U.merge(v, fv)
@comm_guard(BoundVariable, Variable)
def unify_walk(bv, v, U):
    """
    V is unified to BV.value
    """
    return unify_walk(v, bv.value, U)
@comm_guard(OrVariable, OrVariable)
def unify_walk(a, b, U):
    """
    OrV(list1) == OrV(list2) == OrV(intersection(list1, list2))
    """
    opt = intersection(a.options, b.options)
    if not opt:
        return False
    elif len(opt) == 1:
        # A single remaining option collapses to a bound variable.
        v = BoundVariable("?", opt[0])
    else:
        v = OrVariable("?", opt)
    return U.merge(v, a, b)
@comm_guard(NotVariable, NotVariable)
def unify_walk(a, b, U):
    """
    NV(list1) == NV(list2) == NV(union(list1, list2))
    """
    opt = union(a.not_options, b.not_options)
    v = NotVariable("?", opt)
    return U.merge(v, a, b)
@comm_guard(OrVariable, NotVariable)
def unify_walk(o, n, U):
    """
    OrV(list1) == NV(list2) == OrV(list1 \ list2)
    """
    opt = [x for x in o.options if x not in n.not_options]
    if not opt:
        return False
    elif len(opt) == 1:
        v = BoundVariable("?", opt[0])
    else:
        v = OrVariable("?", opt)
    return U.merge(v, o, n)
@comm_guard(VariableInList, (list, tuple))
def unify_walk(vil, l, U):
    """
    Unifies VIL's inner Variable to OrV(list).
    """
    v = vil.variable
    ov = OrVariable("?", l)
    return unify_walk(v, ov, U)
@comm_guard((list, tuple), (list, tuple))
def unify_walk(l1, l2, U):
    """
    Tries to unify each corresponding pair of elements from l1 and l2.
    """
    if len(l1) != len(l2):
        return False
    for x1, x2 in zip(l1, l2):
        # Thread the unification through every element pair.
        U = unify_walk(x1, x2, U)
        if U is False:
            return False
    return U
@comm_guard(dict, dict)
def unify_walk(d1, d2, U):
    """
    Tries to unify values of corresponding keys.
    """
    # Keys present in only one dict are ignored: only shared keys must
    # unify. (has_key is py2 dict API.)
    for (k1, v1) in d1.items():
        if d2.has_key(k1):
            U = unify_walk(v1, d2[k1], U)
            if U is False:
                return False
    return U
@comm_guard(ANY_TYPE, ANY_TYPE)
def unify_walk(a, b, U):
    """
    Checks for the existence of the __unify_walk__ method for one of
    the objects.
    """
    if not isinstance(a, Variable) and not isinstance(b, Variable) \
            and hasattr(a, "__unify_walk__"):
        return a.__unify_walk__(b, U)
    else:
        return FALL_THROUGH
@comm_guard(Variable, ANY_TYPE)
def unify_walk(v, o, U):
    """
    This simply checks if the Var has an unification in U and uses it
    instead of the Var. If the Var is already its tighest unification,
    falls through.
    """
    best_v = U[v]
    if v is not best_v:
        return unify_walk(o, best_v, U) # reverse argument order so if o is a Variable this block of code is run again
    else:
        return FALL_THROUGH # call the next version of unify_walk that matches the type signature
################################
class FVar:
    """
    Deferred function application for use inside patterns: calling
    the FVar with a Unification builds every stored argument through
    unify_build, then applies fn to the results.
    """
    def __init__(self, fn, *args):
        self.fn = fn
        self.args = args

    def __call__(self, u):
        built = [unify_build(arg, u) for arg in self.args]
        return self.fn(*built)
################################
# ---------------------------------------------------------------------
# unify_merge dispatch table: builds the merged value of two already
# unified structures under the Unification U. Same comm_guard dispatch
# mechanism and reverse-declaration-order semantics as unify_walk.
# ---------------------------------------------------------------------
def unify_merge(a, b, U):
    # Base rule: keep the left operand.
    return a
@comm_guard(Variable, ANY_TYPE)
def unify_merge(v, o, U):
    # An unconstrained variable merges to itself.
    return v
@comm_guard(BoundVariable, ANY_TYPE)
def unify_merge(bv, o, U):
    # A bound variable merges to its concrete value.
    return bv.value
@comm_guard(VariableInList, (list, tuple))
def unify_merge(vil, l, U):
    # Rebuild each list element under U.
    return [unify_merge(x,x,U) for x in l]
@comm_guard((list, tuple), (list, tuple))
def unify_merge(l1, l2, U):
    # Merge the sequences element-wise.
    return [unify_merge(x1, x2, U) for x1, x2 in zip(l1, l2)]
@comm_guard(dict, dict)
def unify_merge(d1, d2, U):
    # Union of keys: shared keys merge pairwise, unique keys merge
    # with themselves. (has_key is py2 dict API.)
    d = d1.__class__()
    for k1, v1 in d1.items():
        if d2.has_key(k1):
            d[k1] = unify_merge(v1, d2[k1], U)
        else:
            d[k1] = unify_merge(v1, v1, U)
    for k2, v2 in d2.items():
        if not d1.has_key(k2):
            d[k2] = unify_merge(v2, v2, U)
    return d
@comm_guard(FVar, ANY_TYPE)
def unify_merge(vs, o, U):
    # A deferred function is evaluated under the unification.
    return vs(U)
@comm_guard(ANY_TYPE, ANY_TYPE)
def unify_merge(a, b, U):
    # Delegate to a user-defined __unify_merge__ when available.
    if not isinstance(a, Variable) and not isinstance(b, Variable) \
            and hasattr(a, "__unify_merge__"):
        return a.__unify_merge__(b, U)
    else:
        return FALL_THROUGH
@comm_guard(Variable, ANY_TYPE)
def unify_merge(v, o, U):
    """
    This simply checks if the Var has an unification in U and uses it
    instead of the Var. If the Var is already its tighest unification,
    falls through.
    """
    best_v = U[v]
    if v is not best_v:
        return unify_merge(o, best_v, U) # reverse argument order so if o is a Variable this block of code is run again
    else:
        return FALL_THROUGH # call the next version of unify_walk that matches the type signature
################################
def unify_build(x, U):
    # Builds the concrete value of x under U by merging x with itself.
    return unify_merge(x, x, U)
################################
def unify(a, b):
    """
    Convenience entry point: unifies a and b starting from a fresh
    Unification. Returns (merged_value, U) on success and
    (None, False) on failure.
    """
    U = unify_walk(a, b, Unification())
    if not U:
        return None, False
    return unify_merge(a, b, U), U
################################
# Smoke-test demo (python 2 prints): exercises unification of dicts,
# Or/Not variables and conflicting bindings.
if __name__ == "__main__":
    vx = NotVariable("x", ["big", "bones"])
    vy = OrVariable("y", ["hello", "big"])
    vz = V("z")
    va = V("a")
    vl = VariableInList(vz)
    pattern1 = dict(hey=vx, ulala=va, a=1)
    pattern2 = dict(hey=vy, ulala=10, b=2)
#     pattern1 = ["hello", "big", "bones"]
#     pattern2 = vl
#     pattern1 = [vx]#, "big", "bones"]
#     pattern2 = [vy]#, vy, vz]
    U = unify_walk(pattern1, pattern2, Unification())
    if U:
        # Print the resolved bindings of each variable under U.
        print U[va]
        print U[vx]
        print U[vy]
        print U[vz]
        print unify_merge(pattern1, pattern2, U)
    else:
        print "no match"
    # NOTE(review): (1, 2) vs (va, va) binds va twice with conflicting
    # values, so this walk likely fails and U may be False here — confirm.
    U = unify_walk((1, 2), (va, va), Unification())
    print U[va]
# import op
# import result
class OmegaError(Exception):
    """Package-level exception type. (No raise site is visible in this chunk.)"""
    pass
def all_bases(cls, accept):
    """
    Returns every class in cls's inheritance graph (cls included) for
    which accept(klass) is true. Order is unspecified because the
    classes are gathered in a set.
    """
    seen = set([cls])
    for base in cls.__bases__:
        seen.update(all_bases(base, accept))
    return [klass for klass in seen if accept(klass)]
def all_bases_collect(cls, raw_name):
    """
    Collects the __<raw_name>__ attribute defined directly on cls and,
    unless a class sets __<raw_name>_override__ to a true value, on
    every class up its base hierarchy. Returns a set of the collected
    attributes.
    """
    collected = set()
    dunder = "__%s__" % raw_name
    if dunder in cls.__dict__: # don't use hasattr
        collected.add(getattr(cls, dunder))
    stop_flag = "__%s_override__" % raw_name
    if not cls.__dict__.get(stop_flag, False):
        for base in cls.__bases__:
            collected.update(all_bases_collect(base, raw_name))
    return collected
def uniq_features(_features, *_rest):
    """
    Concatenates the given feature collections and drops every feature
    class that is a base of another feature still pending, so only the
    most derived features survive. The result follows pop order
    (a reversed scan of the concatenation).
    """
    pending = list(_features)
    for extra in _rest:
        pending += list(extra)
    kept = []
    while pending:
        feature = pending.pop()
        # Skip this feature if a (sub)class of it is still pending.
        for other in pending:
            if issubclass(other, feature):
                break
        else:
            kept.append(feature)
    return kept
def partial(func, *args, **keywords):
    """
    Pre-applies positional and keyword arguments to func (a backport
    of functools.partial). The returned callable carries func, args
    and keywords attributes, mirroring functools.partial.
    """
    def newfunc(*fargs, **fkeywords):
        # Call-time keywords override the pre-applied ones.
        merged = keywords.copy()
        merged.update(fkeywords)
        return func(*(args + fargs), **merged)
    newfunc.func = func
    newfunc.args = args
    newfunc.keywords = keywords
    return newfunc
class ClsInit(type):
    """Class initializer for Op subclasses"""
    def __init__(cls, name, bases, dct):
        """Validate and initialize the Op subclass 'cls'
        This function:
        - changes class attributes input_names and output_names to be lists if they are single strings.
        """
        type.__init__(cls, name, bases, dct)
        # NOTE(review): cls is passed explicitly as the first argument,
        # so __clsinit__ is presumably stored as a staticmethod (not a
        # classmethod) on the classes using this metaclass — confirm.
        cls.__clsinit__(cls, name, bases, dct)
def toposort(prereqs_d):
    """
    Sorts prereqs_d.keys() topologically. prereqs_d[x] contains all the elements
    that must come before x in the ordering.

    Raises Exception when no complete ordering exists (cycles, missing
    keys, or orderings referencing unknown elements).
    """
    order = []
    finished = set()
    # Invert the prerequisite map: dependents[p] = nodes that wait on p.
    dependents = {}
    for node, prereqs in prereqs_d.items():
        for prereq in prereqs:
            dependents.setdefault(prereq, set()).add(node)
    # Kahn's algorithm, level by level: start from prerequisite-free nodes.
    frontier = set(node for node in prereqs_d if not prereqs_d[node])
    while frontier:
        current = frontier
        frontier = set()
        for node in current:
            finished.add(node)
            order.append(node)
        for node in current:
            for dependent in dependents.get(node, []):
                # Schedule a dependent once all its prerequisites are done.
                if not prereqs_d[dependent].difference(finished):
                    frontier.add(dependent)
    if len(prereqs_d) != len(order):
        raise Exception("Cannot sort topologically: there might be cycles, " + \
                        "prereqs_d does not have a key for each element or " + \
                        "some orderings contain invalid elements.")
    return order
# def schedule(**kwargs):
# after = kwargs.get('after', [])
# if not isinstance(after, (list, tuple)):
# after = [after]
# before = kwargs.get('before', [])
# if not isinstance(before, (list, tuple)):
# before = [before]
# def decorate(fn):
# name = fn.__name__
# fn.prereqs_d = {}
# for postreq in after:
# prereqs_d[postreq] = name
# for prereq in before:
# prereqs_d[name] = prereq
# return fn
# return decorate
# def after(*others):
# return schedule(after = others)
# def before(*others):
# return schedule(before = others)
# class TopoList(list):
# def add(self, item, **kwargs):
# after = kwargs.get('after', [])
# if not isinstance(after, (list, tuple)):
# after = [after]
# before = kwargs.get('before', [])
# if not isinstance(before, (list, tuple)):
# before = [before]
class Keyword:
    """A named sentinel with an explicit truth value.

    Used below for symbolic constants: ABORT/RETRY/FAILURE are falsy so
    guard results can be tested with ``if result:``; ANY_TYPE and
    FALL_THROUGH are truthy markers.
    """
    def __init__(self, name, nonzero=True):
        # name: display label; nonzero: the sentinel's truth value.
        self.name = name
        self.nonzero = nonzero
    def __nonzero__(self):
        # Python 2 truth protocol.
        return self.nonzero
    # Bug fix: Python 3 uses __bool__; without this alias every Keyword is
    # truthy there, silently breaking ``if not ABORT``-style checks.
    __bool__ = __nonzero__
    def __str__(self):
        return "<%s>" % self.name
    def __repr__(self):
        return "<%s>" % self.name
# Falsy sentinels: guard machinery can test outcomes with ``if result:``.
ABORT = Keyword("ABORT", False)
RETRY = Keyword("RETRY", False)
FAILURE = Keyword("FAILURE", False)
# Types treated as plain scalar-like values (Keyword included so the
# sentinels themselves count as simple).
simple_types = (int, float, str, bool, None.__class__, Keyword)
# Truthy markers used by comm_guard / type_guard dispatch below:
# ANY_TYPE matches any argument type; returning FALL_THROUGH from a guarded
# function delegates to the previously-defined function of the same name.
ANY_TYPE = Keyword("ANY_TYPE")
FALL_THROUGH = Keyword("FALL_THROUGH")
def comm_guard(type1, type2):
    """Decorator implementing commutative binary dispatch.

    The decorated function handles an argument pair matching
    (type1, type2) in either order (arguments are swapped into place);
    any other pair, or a FALL_THROUGH result, is forwarded to the
    previously-defined function of the same name found in the decoratee's
    module globals. ANY_TYPE matches anything.
    """
    def wrap(f):
        # Python 2: the chain of same-named handlers lives in the
        # decoratee's module globals.
        previous = f.func_globals[f.__name__]
        def new_f(arg1, arg2, *rest):
            direct = (type1 is ANY_TYPE or isinstance(arg1, type1)) \
                     and (type2 is ANY_TYPE or isinstance(arg2, type2))
            if not direct:
                swapped = (type1 is ANY_TYPE or isinstance(arg2, type1)) \
                          and (type2 is ANY_TYPE or isinstance(arg1, type2))
                if not swapped:
                    return previous(arg1, arg2, *rest)
                arg1, arg2 = arg2, arg1
            result = f(arg1, arg2, *rest)
            if result is FALL_THROUGH:
                return previous(arg1, arg2, *rest)
            return result
        new_f.__name__ = f.__name__
        def typename(t):
            # Render a type spec (Keyword, tuple of types, or type).
            if isinstance(t, Keyword):
                return str(t)
            if isinstance(t, (tuple, list)):
                return "(" + ", ".join([member.__name__ for member in t]) + ")"
            return t.__name__
        new_f.__doc__ = str(previous.__doc__) + "\n" + ", ".join([typename(t) for t in (type1, type2)]) + "\n" + str(f.__doc__ or "")
        return new_f
    return wrap
def type_guard(type1):
    """Decorator implementing single-argument type dispatch.

    The decorated function handles a first argument matching type1
    (ANY_TYPE matches anything); any other argument, or a FALL_THROUGH
    result, is forwarded to the previously-defined function of the same
    name found in the decoratee's module globals.
    """
    def wrap(f):
        # Python 2: the chain of same-named handlers lives in the
        # decoratee's module globals.
        previous = f.func_globals[f.__name__]
        def new_f(arg1, *rest):
            if type1 is not ANY_TYPE and not isinstance(arg1, type1):
                return previous(arg1, *rest)
            result = f(arg1, *rest)
            if result is FALL_THROUGH:
                return previous(arg1, *rest)
            return result
        new_f.__name__ = f.__name__
        def typename(t):
            # Render a type spec (Keyword, tuple of types, or type).
            if isinstance(t, Keyword):
                return str(t)
            if isinstance(t, (tuple, list)):
                return "(" + ", ".join([member.__name__ for member in t]) + ")"
            return t.__name__
        new_f.__doc__ = str(previous.__doc__) + "\n" + ", ".join([typename(t) for t in (type1,)]) + "\n" + str(f.__doc__ or "")
        return new_f
    return wrap
#ifndef _OMEGA_H
#define _OMEGA_H
//#include whatever defines PyArrayObject
template<typename T>
struct TMat_t
{
    T * __restrict__ d;/**< pointer to element (0,0) */
    size_t M; /**< number of rows */
    size_t N; /**< number of columns */
    size_t m; /**< row stride, in elements */
    size_t n; /**< column stride, in elements */
    bool invalid; /**< true when o is not a 2-d array of elements of size sizeof(T) */
    /** Wrap a 2-d PyArrayObject without copying.
     *
     * Bug fix: the data pointer and stride divisors previously used
     * `double` unconditionally (`(double*) o->data`, `/ sizeof(double)`),
     * mis-addressing elements for any T != double. The `invalid` flag
     * already compares elsize against sizeof(T), showing the struct is
     * meant to be generic in T.
     */
    TMat_t(const PyArrayObject *o) :
        d((T*) o->data),
        M((o->nd==2) ? o->dimensions[0] : 0),
        N((o->nd==2) ? o->dimensions[1] : 0),
        m((o->nd==2) ? o->strides[0] / sizeof(T) : 0),
        n((o->nd==2) ? o->strides[1] / sizeof(T) : 0),
        invalid((o->nd !=2) || (o->descr->elsize != sizeof(T)))
    {
    }
    /** unsafe element access */
    const T & operator()(size_t i, size_t j) const
    {
        return d[ i * m + j*n];
    }
    /** unsafe element access */
    T & operator()(size_t i, size_t j)
    {
        return d[ i * m + j*n];
    }
    /** safe element access (bounds assert, then index) */
    const T & at(size_t i, size_t j) const
    {
        return d[ assert((i < M) && (j < N)), i * m + j*n];
    }
    /** safe element access (bounds assert, then index) */
    T & at(size_t i, size_t j)
    {
        return d[ assert((i < M) && (j < N)), i * m + j*n];
    }
};
#endif
from scipy.weave import c_spec, standard_array_spec
class omega_type_converter_extension:
    """Mixin extending scipy.weave type converters with struct-based
    code generation.

    Concrete converters combine this mixin with a weave converter class,
    which supplies ``self.name`` and ``self.template_vars()``.
    """
    def provides(self):
        """
        Returns a list of (c_type, name, init_code) tuples that represent variables
        the type converter provides to the user's code.
        """
        tvars = self.template_vars()
        return [(tvars['c_type'], tvars['name'], tvars['var_convert'])]
    def format_provide(self, x):
        # x is a (c_type, name, init_code) triple -> one C declaration line.
        return '%s %s = %s;\n' % x
    def declaration_code(self, templatize = 0, inline = 0):
        """Emit the lookup of the Python variable plus one declaration per
        provided C variable."""
        tvars = self.template_vars(inline=inline)
        lines = ['%(py_var)s = %(var_lookup)s;\n' % tvars]
        for export in self.provides():
            lines.append(self.format_provide(export))
        return ''.join(lines)
    def struct_init_code(self):
        # Hold a reference to the Python object for the struct's lifetime.
        return "Py_INCREF(py_%s);" % self.name
    def struct_cleanup_code(self):
        return "Py_DECREF(py_%s);" % self.name
    def struct_members_code(self):
        members = ["PyObject* py_%s;\n" % self.name]
        for c_type, name, init in self.provides():
            members.append('%s_type %s;\n' % (name, name))
        return ''.join(members)
    def struct_import_code(self):
        assignments = ["__STRUCT_P->py_%s = py_%s;\n" % (self.name, self.name)]
        for c_type, name, init in self.provides():
            assignments.append('__STRUCT_P->%s = %s;\n' % (name, name))
        return ''.join(assignments)
    def struct_support_code(self):
        return ""
    def struct_typedefs(self):
        # One typedef per provided variable so struct members can name it.
        return ''.join("typedef %s %s_type;\n" % (c_type, name)
                       for c_type, name, init in self.provides())
# Thin concrete converters: each pairs the omega struct-codegen mixin with
# the corresponding scipy.weave scalar converter; all behavior comes from
# the two bases.
class int_converter(omega_type_converter_extension, c_spec.int_converter):
    pass
class float_converter(omega_type_converter_extension, c_spec.float_converter):
    pass
class complex_converter(omega_type_converter_extension, c_spec.complex_converter):
    pass
class unicode_converter(omega_type_converter_extension, c_spec.unicode_converter):
    """Unicode converter that additionally provides N<name>, the string
    length obtained via PyUnicode_GET_SIZE."""
    def provides(self):
        tvars = self.template_vars()
        # Bug fix: the base-class provides() was called unbound without
        # `self`, which raises a TypeError as soon as this converter is
        # used.
        return omega_type_converter_extension.provides(self) + [('int', 'N%(name)s' % tvars, 'PyUnicode_GET_SIZE(%(py_var)s)' % tvars)]
# More thin concrete converters (see the scalar ones above): the omega
# struct-codegen mixin combined with the matching scipy.weave converter.
class string_converter(omega_type_converter_extension, c_spec.string_converter):
    pass
class list_converter(omega_type_converter_extension, c_spec.list_converter):
    pass
class dict_converter(omega_type_converter_extension, c_spec.dict_converter):
    pass
class tuple_converter(omega_type_converter_extension, c_spec.tuple_converter):
    pass
class file_converter(omega_type_converter_extension, c_spec.file_converter):
    pass
class instance_converter(omega_type_converter_extension, c_spec.instance_converter):
    pass
class array_converter(omega_type_converter_extension, standard_array_spec.array_converter):
    """numpy ndarray converter: besides the PyArrayObject* itself it
    provides dimension/stride/rank variables and a typed data pointer,
    plus inline accessor helpers for up to 4-d indexing."""
    def provides(self):
        """Return the (c_type, name, init_code) triples exposed to user code:
        the array object, N<name> (dimensions), S<name> (strides),
        D<name> (rank) and <name> (typed data pointer)."""
        tvars = self.template_vars()
        ret = []
        ret.append((tvars['c_type'], tvars['array_name'], tvars['var_convert']))
        ret.append(('npy_intp*', 'N%(name)s' % tvars, '%(array_name)s->dimensions' % tvars))
        ret.append(('npy_intp*', 'S%(name)s' % tvars, '%(array_name)s->strides' % tvars))
        ret.append(('int', 'D%(name)s' % tvars, '%(array_name)s->nd' % tvars))
        ret.append(('%(num_type)s*' % tvars, '%(name)s' % tvars, '(%(num_type)s*) %(array_name)s->data' % tvars))
        return ret
    def declaration_code(self, templatize = 0, inline = 0):
        """Emit lookup + declarations, inserting the numpy dtype check
        after the array variable is declared but before anything
        dereferences it."""
        tvars = self.template_vars(inline=inline)
        tvars['cap_name'] = self.name.upper()
        prov = self.provides()
        code = '%(py_var)s = %(var_lookup)s;\n' % tvars
        code += "\n".join(self.format_provide(export) for export in prov[:1])
        # Type check must run before the dimension/stride/data provides.
        code += '\nconversion_numpy_check_type(%(array_name)s,%(num_typecode)s,"%(name)s");\n' % tvars
        code += "\n".join(self.format_provide(export) for export in prov[1:])
        return code
    def struct_support_code(self, templatize = 0, inline = 0):
        """Emit stride-aware element accessors <NAME>1..<NAME>4 for 1- to
        4-dimensional indexing into the array's data buffer."""
        tvars = self.template_vars(inline=inline)
        cap_name = self.name.upper()
        tvars['cap_name'] = cap_name
        code = 'inline %(num_type)s& %(cap_name)s1(int i) { return (*((%(num_type)s*)(%(array_name)s->data + (i)*S%(name)s[0])));}\n' \
        'inline %(num_type)s& %(cap_name)s2(int i, int j) { return (*((%(num_type)s*)(%(array_name)s->data + (i)*S%(name)s[0] + (j)*S%(name)s[1])));}\n' \
        'inline %(num_type)s& %(cap_name)s3(int i, int j, int k) { return (*((%(num_type)s*)(%(array_name)s->data + (i)*S%(name)s[0] + (j)*S%(name)s[1] + (k)*S%(name)s[2])));}\n' \
        'inline %(num_type)s& %(cap_name)s4(int i, int j, int k, int l) { return (*((%(num_type)s*)(%(array_name)s->data + (i)*S%(name)s[0] + (j)*S%(name)s[1] + (k)*S%(name)s[2] + (l)*S%(name)s[3])));}\n'
        return code % tvars
    def struct_typedefs(self):
        # Adds <name>_dtype (the numpy element type) to the mixin's typedefs.
        tvars = self.template_vars()
        return omega_type_converter_extension.struct_typedefs(self) + "\n" + "typedef %(num_type)s %(name)s_dtype;" % tvars
#        return "\n".join(["typedef %s %s_type;" % (c_type, name)])
#     def struct_template_types(self):
#         tvars = self.template_vars()
#         return [("typename %s_type" % name, c_type) for c_type, name, init in self.provides()] + [("typename %s_dtype" % self.name, tvars['num_type'])]
# Default converter registry handed to weave. NOTE(review): presumably
# tried in order, with array_converter first so ndarrays are claimed
# before the generic instance_converter — confirm against weave's
# converter-matching rules.
default = [array_converter(),
           int_converter(),
           float_converter(),
           complex_converter(),
           unicode_converter(),
           string_converter(),
           list_converter(),
           dict_converter(),
           tuple_converter(),
           file_converter(),
           instance_converter()]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论