提交 bc3f8854 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

small cleanup

上级 a0a9ab02
import gof
from core import *
from opt import *
from compile import *
from grad import *
...@@ -11,11 +11,11 @@ def experimental_linker(env, target = None): ...@@ -11,11 +11,11 @@ def experimental_linker(env, target = None):
def fetch(op): def fetch(op):
try: try:
factory = op.c_thunk_factory() factory = op.c_thunk_factory()
print "yea %s" % op # print "yea %s" % op
thunk = factory() thunk = factory()
return lambda: cutils.run_cthunk(thunk) return lambda: cutils.run_cthunk(thunk)
except NotImplementedError: except NotImplementedError:
print "nope %s" % op # print "nope %s" % op
return op._perform return op._perform
order = env.toposort() order = env.toposort()
for op in order: for op in order:
......
import os
import sys
import gof import gof
from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode, pop_mode, UNCOMPUTED, UNDEFINED, PythonR from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode, pop_mode, UNCOMPUTED, UNDEFINED, PythonR
...@@ -250,14 +253,15 @@ class omega_op(gof.PythonOp): ...@@ -250,14 +253,15 @@ class omega_op(gof.PythonOp):
instantiate.customize.add_library(lib) instantiate.customize.add_library(lib)
mod.add_function(instantiate) mod.add_function(instantiate)
mod.compile(location = 'compiled') module_dir = os.path.expanduser('~/.omega/compiled')
mod.compile(location = module_dir)
module = __import__("compiled.%s" % module_name, {}, {}, [module_name]) sys.path.insert(0, module_dir)
module = __import__("%s" % module_name) #, {}, {}, [module_name])
sys.path = sys.path[1:]
def creator(): def creator():
return module.instantiate(*[x.data for x in self.inputs + self.outputs]) return module.instantiate(*[x.data for x in self.inputs + self.outputs])
# def creator():
# return weave.inline(code, d.keys(), local_dict = d, global_dict = {}, support_code = struct, type_converters = converters)
return creator return creator
def c_thunk(self): def c_thunk(self):
...@@ -268,26 +272,6 @@ class omega_op(gof.PythonOp): ...@@ -268,26 +272,6 @@ class omega_op(gof.PythonOp):
cutils.run_cthunk(thunk) cutils.run_cthunk(thunk)
# def elemwise_wrap_old(beforeloop, inloop, afterloop, loop_vars, writable_loop_vars):
# return """
# %(beforeloop)s
# for (int i = 0; i < N_%(v1)s[0]; i++) {
# for (int j = 0; j < N_%(v1)s[1]; j++) {
# %(idefs)s
# %(odefs)s
# %(inloop)s
# }
# }
# %(afterloop)s
# """ % dict(v1 = (loop_vars + writable_loop_vars)[0],
# idefs = "\n".join(["_%s_dtype %s = _%s2(i, j);" % (loop_var, loop_var, loop_var.upper())
# for loop_var in loop_vars]),
# odefs = "\n".join(["_%s_dtype& %s = _%s2(i, j);" % (writable_loop_var, writable_loop_var, writable_loop_var.upper())
# for writable_loop_var in writable_loop_vars]),
# beforeloop = beforeloop,
# inloop = inloop,
# afterloop = afterloop)
def elemwise_loopcode(loopcode, init_template, next_template, acquire_template, cleanup_template, loop_vars, writable_loop_vars, aliases): def elemwise_loopcode(loopcode, init_template, next_template, acquire_template, cleanup_template, loop_vars, writable_loop_vars, aliases):
all_loop_vars = loop_vars + writable_loop_vars all_loop_vars = loop_vars + writable_loop_vars
...@@ -511,12 +495,6 @@ class elemwise(omega_op): ...@@ -511,12 +495,6 @@ class elemwise(omega_op):
during = self._c_foreach() during = self._c_foreach()
after = self._c_finalize() after = self._c_finalize()
# # Sanity check - apart from loop vars, variables are shared in the before/during/after parts
# if before and spec_b != spec_d:
# raise Exception("The input signature of c_init differs from the input signature of c_foreach.")
# if after and spec_a != spec_d:
# raise Exception("The input signature of c_finalize differs from the input signature of c_foreach.")
(inames, onames) = self.variable_names() (inames, onames) = self.variable_names()
(linames, lonames) = self.loop_variables() (linames, lonames) = self.loop_variables()
...@@ -531,18 +509,12 @@ class elemwise(omega_op): ...@@ -531,18 +509,12 @@ class elemwise(omega_op):
behavior = elemwise_wrap(before, during, after, behavior = elemwise_wrap(before, during, after,
[name for name in linames if name not in aliases], [name for name in linames if name not in aliases],
lonames, lonames,
# [mangle(iname) for iname in inames if iname.endswith("_i") and not iname in aliases],
# [mangle(oname) for oname in onames if oname.endswith("_i")],
aliases) aliases)
# inames = [mangle(name) for name in inames]
# onames = [mangle(name) for name in onames]
return cgen(self.__class__.__name__, behavior, inames + onames, self.inputs + self.outputs, converters) return cgen(self.__class__.__name__, behavior, inames + onames, self.inputs + self.outputs, converters)
@classmethod @classmethod
def inplace_version(cls, dmap = {0: 0}): def inplace_version(cls, dmap = {0: 0}):
# (inames, onames), _1, _2, _3 = inspect.getargspec(cls._c_foreach)
inames, onames = cls.variable_names() inames, onames = cls.variable_names()
linames, lonames = cls.loop_variables() linames, lonames = cls.loop_variables()
for i, oname in enumerate(onames): for i, oname in enumerate(onames):
...@@ -909,14 +881,14 @@ class dot(omega_op): ...@@ -909,14 +881,14 @@ class dot(omega_op):
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1])) if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{ {
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n"); fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead (1)\\n");
return 1; return 1;
} }
if ((Sx[0] < 1) || (Sx[1] < 1) if ((Sx[0] < 1) || (Sx[1] < 1)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] < 1) || (Sy[1] < 1)
|| (Sz[0] < 1) || (Sz[1] < 1)) || (Sz[0] < 1) || (Sz[1] < 1))
{ {
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n"); fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead (2)\\n");
return 1; return 1;
//return mat_gemm_general(a, A, B, b, C); //return mat_gemm_general(a, A, B, b, C);
} }
...@@ -953,7 +925,7 @@ class dot(omega_op): ...@@ -953,7 +925,7 @@ class dot(omega_op):
case 0x110: %(gemm)s(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break; case 0x110: %(gemm)s(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: %(gemm)s(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break; case 0x111: %(gemm)s(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: default:
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n"); fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead (3)\\n");
return 1; return 1;
}; };
/* v 1 */ /* v 1 */
......
...@@ -11,7 +11,6 @@ import graph ...@@ -11,7 +11,6 @@ import graph
__all__ = ['Viewer', 'Destroyer', 'DestroyHandler', 'IONames', 'mark_outputs_as_destroyed'] __all__ = ['Viewer', 'Destroyer', 'DestroyHandler', 'IONames', 'mark_outputs_as_destroyed']
## mul(*3 -> sub(*1 -> zeros((), float, C), sigmoid(dot(sigmoid(dot(*1, *2 -> ones((), float, C))), transpose(*2)))), fill(isqr(*3), 1.0))
class IONames: class IONames:
""" """
...@@ -77,259 +76,6 @@ class IONames: ...@@ -77,259 +76,6 @@ class IONames:
# class DestroyHandler(Listener, Constraint, Orderings):
# def __init__(self, env):
# self.parent = {}
# self.children = {}
# self.destroyers = {}
# self.paths = {}
# self.dups = set()
# self.cycles = set()
# self.env = env
# for input in env.inputs:
# self.parent[input] = None
# self.children[input] = set()
# def __path__(self, r):
# path = self.paths.get(r, None)
# if path:
# return path
# rval = [r]
# r = self.parent[r]
# while r:
# rval.append(r)
# r = self.parent[r]
# rval.reverse()
# for i, x in enumerate(rval):
# self.paths[x] = rval[0:i+1]
# return rval
# def __views__(self, r):
# children = self.children[r]
# if not children:
# return set([r])
# else:
# rval = set([r])
# for child in children:
# rval.update(self.__views__(child))
# return rval
# def __users__(self, r):
# views = self.__views__(r)
# rval = set()
# for view in views:
# for op, i in self.env.clients(view):
# rval.update(op.outputs)
# return rval
# def __pre__(self, op):
# rval = set()
# if op is None:
# return rval
# keep_going = False
# for input in op.inputs:
# foundation = self.__path__(input)[0]
# destroyers = self.destroyers.get(foundation, set())
# if destroyers:
# keep_going = True
# if op in destroyers:
# users = self.__users__(foundation)
# rval.update(users)
# if not keep_going:
# return set()
# rval.update(op.inputs)
# rval.difference_update(op.outputs)
# return rval
# def __detect_cycles_helper__(self, r, seq):
# if r in seq:
# self.cycles.add(tuple(seq[seq.index(r):]))
# return
# pre = self.__pre__(r.owner)
# for r2 in pre:
# self.__detect_cycles_helper__(r2, seq + [r])
# def __detect_cycles__(self, start, just_remove=False):
# users = self.__users__(start)
# users.add(start)
# for user in users:
# for cycle in copy(self.cycles):
# if user in cycle:
# self.cycles.remove(cycle)
# if just_remove:
# return
# for user in users:
# self.__detect_cycles_helper__(user, [])
# def get_maps(self, op):
# dmap = {}
# vmap = {}
# if isinstance(op, DestroyOp):
# dmap = op.destroy_map()
# if isinstance(op, ViewOp):
# vmap = op.view_map()
# return vmap, dmap
# # return getattr(op, 'view_map', lambda:{})(), \
# # getattr(op, 'destroy_map', lambda:{})()
# def on_import(self, op):
# view_map, destroy_map = self.get_maps(op)
# for input in op.inputs:
# self.parent.setdefault(input, None)
# for output in op.outputs:
# views = view_map.get(output, None)
# destroyed = destroy_map.get(output, None)
# if destroyed:
# self.parent[output] = None
# for input in destroyed:
# path = self.__path__(input)
# self.__add_destroyer__(path + [output])
# elif views:
# if len(views) > 1: #views was inputs before?
# raise Exception("Output is a view of too many inputs.")
# self.parent[output] = views[0]
# for input in views:
# self.children[input].add(output)
# else:
# self.parent[output] = None
# self.children[output] = set()
# for output in op.outputs:
# self.__detect_cycles__(output)
# # if destroy_map:
# # print "op: ", op
# # print "ord: ", [str(x) for x in self.orderings()[op]]
# # print
# def on_prune(self, op):
# view_map, destroy_map = self.get_maps(op)
# if destroy_map:
# destroyers = []
# for i, input in enumerate(op.inputs):
# destroyers.append(self.destroyers.get(self.__path__(input)[0], {}))
# for destroyer in destroyers:
# path = destroyer.get(op, [])
# if path:
# self.__remove_destroyer__(path)
# if view_map:
# for i, input in enumerate(op.inputs):
# self.children[input].difference_update(op.outputs)
# for output in op.outputs:
# try:
# del self.paths[output]
# except:
# pass
# self.__detect_cycles__(output, True)
# for i, output in enumerate(op.outputs):
# try:
# del self.parent[output]
# except:
# pass
# del self.children[output]
# def __add_destroyer__(self, path):
# foundation = path[0]
# target = path[-1]
# op = target.owner
# destroyers = self.destroyers.setdefault(foundation, {})
# path = destroyers.setdefault(op, path)
# if len(destroyers) > 1:
# self.dups.add(foundation)
# def __remove_destroyer__(self, path):
# foundation = path[0]
# target = path[-1]
# op = target.owner
# destroyers = self.destroyers[foundation]
# del destroyers[op]
# if not destroyers:
# del self.destroyers[foundation]
# elif len(destroyers) == 1 and foundation in self.dups:
# self.dups.remove(foundation)
# def on_rewire(self, clients, r_1, r_2):
# path_1 = self.__path__(r_1)
# path_2 = self.__path__(r_2)
# prev = set()
# for op, i in clients:
# prev.update(op.outputs)
# foundation = path_1[0]
# destroyers = self.destroyers.get(foundation, {}).items()
# for op, path in destroyers:
# if r_1 in path:
# idx = path.index(r_1)
# self.__remove_destroyer__(path)
# if not (idx > 0 and path[idx - 1] in prev):
# continue
# index = path.index(r_1)
# new_path = path_2 + path[index+1:]
# self.__add_destroyer__(new_path)
# for op, i in clients:
# view_map, _ = self.get_maps(op)
# for output, inputs in view_map.items():
# if r_2 in inputs:
# assert self.parent[output] == r_1
# self.parent[output] = r_2
# self.children[r_1].remove(output)
# self.children[r_2].add(output)
# for view in self.__views__(r_1):
# try:
# del self.paths[view]
# except:
# pass
# for view in self.__views__(r_2):
# try:
# del self.paths[view]
# except:
# pass
# self.__detect_cycles__(r_1)
# self.__detect_cycles__(r_2)
# def validate(self):
# if self.dups:
# raise InconsistencyError("The following values are destroyed more than once: %s" % self.dups)
# elif self.cycles:
# raise InconsistencyError("There are cycles: %s" % self.cycles)
# else:
# return True
# def orderings(self):
# ords = {}
# for foundation, destroyers in self.destroyers.items():
# for op in destroyers.keys():
# ords.setdefault(op, set()).update([user.owner for user in self.__users__(foundation) if user not in op.outputs])
# return ords
class DestroyHandler(Listener, Constraint, Orderings): class DestroyHandler(Listener, Constraint, Orderings):
def __init__(self, env): def __init__(self, env):
...@@ -635,13 +381,3 @@ class Return(DummyOp, Destroyer): ...@@ -635,13 +381,3 @@ class Return(DummyOp, Destroyer):
def mark_outputs_as_destroyed(outputs): def mark_outputs_as_destroyed(outputs):
return [Return(output).out for output in outputs] return [Return(output).out for output in outputs]
# class BuildableFromInputs:
# @classmethod
# def from_inputs(cls, *inputs):
# return cls(inputs, self.gen_outputs())
# def gen_outputs(self):
# raise NotImplementedError
...@@ -21,9 +21,7 @@ __all__ = ['UNCOMPUTED', ...@@ -21,9 +21,7 @@ __all__ = ['UNCOMPUTED',
'DummyRemover', 'DummyRemover',
'PythonOp', 'PythonOp',
'PythonOpt', 'PythonOpt',
'COp', 'make_static']
'make_static',
'DualImplOp']
UNCOMPUTED = Keyword("UNCOMPUTED", False) UNCOMPUTED = Keyword("UNCOMPUTED", False)
...@@ -190,19 +188,6 @@ class PythonOp(Op): ...@@ -190,19 +188,6 @@ class PythonOp(Op):
answer |= input.constant answer |= input.constant
return answer return answer
# def input_is_up_to_date(self, input):
# if not input.up_to_date:
# return False
# owner = input.owner
# if owner and isinstance(owner, ext.Viewer):
# view_map = owner.view_map()
# if input in view_map:
# answer = True
# for input2 in view_map[input]:
# answer &= owner.input_is_up_to_date(input2)
# return answer
# return True
def check_input(self, input): def check_input(self, input):
if input.data is UNCOMPUTED: if input.data is UNCOMPUTED:
raise ValueError("Uncomputed input: %s in %s" % (input, self)) raise ValueError("Uncomputed input: %s in %s" % (input, self))
...@@ -259,10 +244,6 @@ class PythonOp(Op): ...@@ -259,10 +244,6 @@ class PythonOp(Op):
@classmethod @classmethod
def set_impl(cls, impl): def set_impl(cls, impl):
make_static(cls, 'impl') make_static(cls, 'impl')
# impl = cls.impl
# if hasattr(cls.impl, 'im_func'):
# impl = impl.im_func
# cls.impl = staticmethod(impl)
def impl(*args): def impl(*args):
raise NotImplementedError("This op has no implementation.") raise NotImplementedError("This op has no implementation.")
...@@ -373,99 +354,3 @@ class DummyOp(Op): ...@@ -373,99 +354,3 @@ class DummyOp(Op):
DummyRemover = opt.OpRemover(DummyOp) DummyRemover = opt.OpRemover(DummyOp)
# literals_db = {}
# def literal(x):
# if x in literals_db:
# return literals_db.get(x)
# else:
# ret = PythonR(x, constant = True)
# liberals_db[x] = ret
# return ret
class COp(Op):
def thunk(self):
cc.compile([self])
def c_libs(self):
return []
def c_imports(self):
return []
def c_impl(self):
raise NotImplementedError("Provide the operation's behavior here.")
class DualImplOp(PythonOp, COp):
language = 'c'
supported_languages = 'c', 'python'
def thunk(self, language = None):
"""
Returns a thunk that does the operation on the inputs and stores the
results in the outputs. The language parameter defaults to self.language
and determines which implementation to use.
"""
if not language:
language = self.language
if language == 'c':
return COp.thunk(self)
elif language == 'python':
return PythonOp.thunk(self)
elif language == 'all':
return [self.thunk(lang) for lang in self.supported_languages]
else:
raise ValueError("language should be any of %s or 'all', not '%s'" % (self.supported_languages, language))
def compare_implementations(self,
samples,
setter = lambda res, v: res.set_value(v),
cmp = lambda x, y: x == y):
"""
Compares the different implementations of this operation on a
list of input values to verify that they behave the same. The
input values are put in the Result instances using the setter
function (defaults to set_value). The output lists are
compared using the cmp predicate (defaults to ==).
"""
for sample in samples:
for input, v in zip(self.inputs, sample):
input.set_value(v)
self.thunk('python')()
# we must copy the outputs because they will be overwritten
results_py = [copy(output).extract() for output in self.outputs]
# we redo the assignment because the Op might be destructive,
# in which case the inputs might not be correct anymore
for input, v in zip(self.inputs, sample):
input.set_value(v)
self.thunk('c')()
results_c = [copy(output).extract() for output in self.outputs]
assert cmp(results_py, results_c)
...@@ -54,12 +54,10 @@ class SeqOptimizer(Optimizer, list): ...@@ -54,12 +54,10 @@ class SeqOptimizer(Optimizer, list):
optimizer.optimize(env) optimizer.optimize(env)
def __str__(self): def __str__(self):
#return list.__str__(self)
return "SeqOpt(%s)" % list.__str__(self) return "SeqOpt(%s)" % list.__str__(self)
def __repr__(self): def __repr__(self):
return list.__repr__(self) return list.__repr__(self)
#return "SeqOpt(%s)" % list.__repr__(self)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论