removed dead code, installed ResultBase to gof

上级 b0b514e4
...@@ -2,7 +2,8 @@ import unittest, os, sys ...@@ -2,7 +2,8 @@ import unittest, os, sys
if __name__ == '__main__': if __name__ == '__main__':
suite = None suite = None
for filename in os.listdir('.'): filenames = os.listdir('.') + ['gof.'+s for s in os.listdir('gof')]
for filename in filenames:
if filename[-3:] == '.py': if filename[-3:] == '.py':
modname = filename[:-3] modname = filename[:-3]
tests = unittest.TestLoader().loadTestsFromModule(__import__(modname)) tests = unittest.TestLoader().loadTestsFromModule(__import__(modname))
......
...@@ -12,7 +12,7 @@ from scipy import weave ...@@ -12,7 +12,7 @@ from scipy import weave
import gof import gof
from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode
from gof import pop_mode, is_result from gof import pop_mode, is_result, ResultBase
import type_spec import type_spec
import cutils import cutils
...@@ -84,138 +84,6 @@ def _compile_dir(): ...@@ -84,138 +84,6 @@ def _compile_dir():
sys.path.append(cachedir) sys.path.append(cachedir)
return cachedir return cachedir
class ResultBase(object):
"""Base class for storing Op inputs and outputs
Attributes:
_role - None or (owner, index) or BrokenLink
_data - anything
constant - Boolean
Properties:
role - (rw)
owner - (ro)
index - (ro)
data - (rw)
replaced - (rw) : True iff _role is BrokenLink
computed - (ro) : True iff contents of data are fresh
Abstract Methods:
data_filter
"""
class BrokenLink:
"""The owner of a Result that was replaced by another Result"""
__slots__ = ['old_role']
def __init__(self, role): self.old_role = role
def __nonzero__(self): return False
class BrokenLinkError(Exception):
"""Exception thrown when an owner is a BrokenLink"""
class AbstractFunction(Exception):
"""Exception thrown when an abstract function is called"""
__slots__ = ['_role', '_data', 'constant']
def __init__(self, role=None, data=None, constant=False):
self._role = role
self.constant = constant
if data is None: #None is not filtered
self._data = None
else:
try:
self._data = self.data_filter(data)
except ResultBase.AbstractFunction:
self._data = data
#role is pair: (owner, outputs_position)
def __get_role(self):
return self._role
def __set_role(self, role):
owner, index = role
if self._role is not None:
# this is either an error or a no-op
_owner, _index = self._role
if _owner is not owner:
raise ValueError("Result %s already has an owner." % self)
if _index != index:
raise ValueError("Result %s was already mapped to a different index." % self)
return # because _owner is owner and _index == index
self._role = role
role = property(__get_role, __set_role)
#owner is role[0]
def __get_owner(self):
if self._role is None: return None
if self.replaced: raise ResultBase.BrokenLinkError()
return self._role[0]
owner = property(__get_owner,
doc = "Op of which this Result is an output, or None if role is None")
#index is role[1]
def __get_index(self):
if self._role is None: return None
if self.replaced: raise ResultBase.BrokenLinkError()
return self._role[1]
index = property(__get_index,
doc = "position of self in owner's outputs, or None if role is None")
# assigning to self.data will invoke self.data_filter(value) if that
# function is defined
def __get_data(self):
return self._data
def __set_data(self, data):
if self.replaced: raise ResultBase.BrokenLinkError()
if self.constant: raise Exception('cannot set constant ResultBase')
try:
self._data = self.data_filter(data)
except ResultBase.AbstractFunction: #use default behaviour
self._data = data
data = property(__get_data, __set_data,
doc = "The storage associated with this result")
def data_filter(self, data):
"""(abstract) Return an appropriate _data based on data."""
raise ResultBase.AbstractFunction()
# replaced
def __get_replaced(self): return isinstance(self._role, ResultBase.BrokenLink)
def __set_replaced(self, replace):
if replace == self.replaced: return
if replace:
self._role = ResultBase.BrokenLink(self._role)
else:
self._role = self._role.old_role
replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?")
# computed
#TODO: think about how to handle this more correctly
computed = property(lambda self: self._data is not None)
#################
# NumpyR Compatibility
#
up_to_date = property(lambda self: True)
def refresh(self): pass
def set_owner(self, owner, idx):
self.role = (owner, idx)
def set_value(self, value):
self.data = value #may raise exception
class _test_ResultBase(unittest.TestCase):
def setUp(self):
build_eval_mode()
numpy.random.seed(44)
def tearDown(self):
pop_mode()
def test_0(self):
r = ResultBase()
class Numpy2(ResultBase): class Numpy2(ResultBase):
"""Result storing a numpy ndarray""" """Result storing a numpy ndarray"""
__slots__ = ['_dtype', '_shape', ] __slots__ = ['_dtype', '_shape', ]
...@@ -986,92 +854,14 @@ def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None): ...@@ -986,92 +854,14 @@ def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None):
return normal_f(x, y) return normal_f(x, y)
return f return f
if 0:
class NumpyR(gof.ResultValue):
"""The class for storing ndarray return values from omega ops.
The class provides additional functionality compared to the normal
ResultValue:
- operator overloads that correspond to omega ops such as add() and scale()
- special attributes that make it behave like an ndarray when passed to
numpy functions.
Attributes:
__array__ - alias of self.data.__array_struct__
__array_struct__ - alias of self.data.__array_struct__
Methods:
set_value() -
"""
# The following attributes make NumpyR instances look like normal ndarray
# instances to many numpy functions, such as argmax(), dot(), svd(), sum(),
# etc. These are documented in the numpy book.
__array__ = property(lambda self: self.data.__array__ )
__array_struct__ = property(lambda self: self.data.__array_struct__ )
def set_value_filter(self, value): return numpy.asarray(value)
def set_value_inplace(self, value):
if value is UNCOMPUTED:
raise ValueError()
else:
if 0 == len(self.data.shape):
self.data.itemset(value) # for scalars
else:
self.data[:] = value # for matrices
self.refresh()
self.up_to_date = True
def refresh(self):
if self.data is not UNCOMPUTED:
self.spec = (numpy.ndarray, self.data.dtype, self.data.shape)
def alloc(self):
shape, dtype = self.spec[2], self.spec[1]
self.data = numpy.ndarray(shape, dtype=dtype)
def __add__(self, y): return add(self, y)
def __radd__(self, x): return add(x, self)
def __iadd__(self, y): return add_inplace(self, y)
def __sub__(self, y): return sub(self, y)
def __rsub__(self, x): return sub(x, self)
def __isub__(self, y): return sub_inplace(self, y)
def __mul__(self, y): return mul(self, y)
def __rmul__(self, x): return mul(x, self)
def __imul__(self, y): return mul_inplace(self, y)
def __div__(self, y): return div(self, y)
def __rdiv__(self, x): return div(x, self)
def __idiv__(self, y): return div_inplace(self, y)
def __pow__(self, y): return pow(self, y)
def __rpow__(self, x): return pow(x, self)
def __ipow__(self, y): return pow_inplace(self, y)
def __neg__(self): return neg(self)
T = property(lambda self: transpose(self))
Tc = property(lambda self: transpose_copy(self))
def __copy__(self): return array_copy(self)
def __getitem__(self, item): return get_slice(self, item)
def __getslice__(self, *args): return get_slice(self, slice(*args))
from grad import Grad from grad import Undefined
def wrap_producer(f): def wrap_producer(f):
class producer(omega_op): class producer(omega_op):
impl = f impl = f
def grad(*args): def grad(*args):
return [Grad.Undefined] * (len(args) - 1) return [Undefined] * (len(args) - 1)
producer.__name__ = f.__name__ producer.__name__ = f.__name__
def ret(dim, dtype = 'float', order = 'C'): def ret(dim, dtype = 'float', order = 'C'):
return producer(dim, dtype, order) return producer(dim, dtype, order)
...@@ -1811,11 +1601,11 @@ class sum(elemwise): ...@@ -1811,11 +1601,11 @@ class sum(elemwise):
class ones_like(elemwise): class ones_like(elemwise):
impl = numpy.ones_like impl = numpy.ones_like
def grad(x, gz): return Grad.Undefined def grad(x, gz): return Undefined
class zeros_like(elemwise): class zeros_like(elemwise):
impl = numpy.zeros_like impl = numpy.zeros_like
def grad(x, gz): return Grad.Undefined def grad(x, gz): return Undefined
## Array slicing ## ## Array slicing ##
......
import env
import tools
import utils
class Compiler:
""" What is this? Please document.
"""
def __init__(self, optimizer, features):
self.features = set(features)
self.features.update(optimizer.require())
self.optimizer = optimizer
def compile(self, inputs, outputs, features):
features = self.features.union(features)
e = env.Env(inputs, outputs, features, False)
self.optimizer.apply(e)
if not e.consistent():
raise env.InconsistencyError("The graph is inconsistent.")
return e
def __call__(self, inputs, outputs, features):
return self.compile(inputs, outputs, features)
# def __init__(self, inputs, outputs, preprocessors, features, optimizer):
# self.inputs = inputs
# self.outputs = outputs
# self.features = features
# self.optimizer = optimizer
# features = features + [tools.EquivTool] + optimizer.require()
# features = utils.uniq_features(features)
# self.env = env.Env(inputs,
# outputs,
# features,
# False)
# if not self.env.consistent():
# raise env.InconsistencyError("The graph is inconsistent.")
# self.__optimize__()
# self.thunks = [op.thunk() for op in self.order]
# def __optimize__(self):
# self.optimizer.apply(self.env)
# self.order = self.env.toposort()
# def equiv(self, r):
# return self.env.equiv(r)
# def __getitem__(self, r):
# return self.equiv(r)
# def __setitem__(self, r, value):
# if isinstance(r, tuple):
# for a, b in zip(r, value):
# self.__setitem__(a, b)
# else:
# self.equiv(r).set_value(value)
# def __call__(self, *args):
# if args:
# for input, arg in zip(self.inputs, args):
# if arg is not None:
# input.value = arg
# for thunk, op in zip(self.thunks, self.order):
# try:
# thunk()
# except Exception, e:
# raise e.__class__("Error in " + str(op) + ": " + str(e))
# return [output.value for output in self.outputs]
# import env
# import opt
# from value import AsValue
# class Prog:
# def __init__(self, inputs, outputs, optimizer):
# self.inputs = inputs
# self.outputs = outputs
# self.env = env.Env(inputs,
# outputs,
# False,
# op_db = env.OpDb,
# changed = env.ChangeListener,
# # pr = env.PrintListener,
# scope = env.ScopeListener)
# ## self.adjustments = adjustments
# self.optimizer = optimizer
# ## if self.adjustments:
# ## self.adjustments.apply(self.env)
# if not self.env.consistent():
# raise env.InconsistencyError("The graph is inconsistent.")
# self.optimizer.apply(self.env)
# self.order = self.env.toposort()
# print "==================="
# for op in self.order:
# print op
# print "==================="
# self.thunks = [op.thunk() for op in self.order]
# def equiv(self, v):
# v = AsValue(v)
# return self.env.equiv(v)
# def __getitem__(self, v):
# return self.equiv(v).storage
# def __setitem__(self, v, value):
# if isinstance(v, tuple):
# for a, b in zip(v, value):
# self.__setitem__(a, b)
# else:
# self.equiv(v).value = value
# def __call__(self, *args):
# if args:
# for input, arg in zip(self.inputs, args):
# if arg is not None:
# input.value = arg
# for thunk, op in zip(self.thunks, self.order):
# try:
# thunk()
# except Exception, e:
# raise e.__class__("Error in " + str(op) + ": " + str(e))
# return [output.value for output in self.outputs]
# def prog(i, o):
# if not isinstance(i, (list, tuple)):
# i = [i]
# if not isinstance(o, (list, tuple)):
# o = [o]
# i = [AsValue(input) for input in i]
# o = [AsValue(output) for output in o]
# return Prog(i,
# o,
# opt.TagFilterMultiOptimizer(opt.opt_registry, None, None))
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
from copy import copy from copy import copy
from op import Op from op import Op
from lib import DummyOp from lib import DummyOp
from result import Result
from features import Listener, Constraint, Orderings from features import Listener, Constraint, Orderings
from env import InconsistencyError from env import InconsistencyError
from utils import ClsInit from utils import ClsInit
......
from copy import copy from copy import copy
from result import Result, BrokenLink, BrokenLinkError from result import BrokenLink, BrokenLinkError
from op import Op from op import Op
import utils import utils
......
from op import Op from op import Op
from result import Result, is_result from result import is_result, ResultBase
from utils import ClsInit, Keyword, AbstractFunctionError from utils import ClsInit, Keyword, AbstractFunctionError
import opt import opt
import env import env
...@@ -15,7 +15,6 @@ __all__ = [ 'UNDEFINED', ...@@ -15,7 +15,6 @@ __all__ = [ 'UNDEFINED',
'eval_mode', 'eval_mode',
'build_eval_mode', 'build_eval_mode',
'pop_mode', 'pop_mode',
#'ResultValue',
'DummyOp', 'DummyOp',
'DummyRemover', 'DummyRemover',
'PythonOp', 'PythonOp',
...@@ -98,83 +97,6 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint): ...@@ -98,83 +97,6 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
if 0:
class ResultValue(Result):
"""Augment Result to wrap a computed value.
Attributes:
data -
spec -
constant -
up_to_date -
Properties:
Methods:
set_value_filter - ABSTRACT
set_value_inplace - ABSTRACT
alloc - ABSTRACT
Notes:
"""
__slots__ = ['data', 'spec', 'constant', 'up_to_date']
def __init__(self, x = None, constant = False):
self.constant = False
self.set_value(x) # allow set_value before constant = True
self.constant = constant
self.up_to_date = True
self.refresh() # to set spec
def __str__(self): return str(self.data)
def __repr__(self): return repr(self.data)
#TODO: document this function, what does it do?
def refresh(self): self.spec = id(self.data)
####################################################
#
# Functionality provided by this class
#
def set_value(self, value):
if self.constant:
raise Exception("This Result is a constant. Its value cannot be changed.")
if value is None or value is UNCOMPUTED:
self.data = UNCOMPUTED
elif is_result(value):
self.set_value(value.data)
else:
try:
self.data = self.set_value_filter(value)
except AbstractFunctionError, e:
self.data = value
self.up_to_date = True
self.refresh()
####################################################
#
# Pure virtual functions for subclasses to implement
#
# Perform error checking or automatic conversion of value, and return the
# result (which will be stored as self.data)
# Called by: set_value()
def set_value_filter(self, value): raise AbstractFunctionError()
# For mutable data types, overwrite the current contents with value
# Also, call refresh and set up_to_date = True
def set_value_inplace(self, value): raise AbstractFunctionError()
# Instantiate data (according to spec)
def alloc(self): raise AbstractFunctionError()
class DestroyHandler(features.Listener, features.Constraint, features.Orderings): class DestroyHandler(features.Listener, features.Constraint, features.Orderings):
def __init__(self, env): def __init__(self, env):
...@@ -279,14 +201,14 @@ class DestroyHandler(features.Listener, features.Constraint, features.Orderings) ...@@ -279,14 +201,14 @@ class DestroyHandler(features.Listener, features.Constraint, features.Orderings)
if destroyed: if destroyed:
# self.parent[output] = None # self.parent[output] = None
if isinstance(destroyed, Result): if is_result(destroyed):
destroyed = [destroyed] destroyed = [destroyed]
for input in destroyed: for input in destroyed:
path = self.__path__(input) path = self.__path__(input)
self.__add_destroyer__(path + [output]) self.__add_destroyer__(path + [output])
elif views: elif views:
if isinstance(views, Result): if is_result(views):
views = [views] views = [views]
if len(views) > 1: #views was inputs before? if len(views) > 1: #views was inputs before?
raise Exception("Output is a view of too many inputs.") raise Exception("Output is a view of too many inputs.")
...@@ -462,7 +384,6 @@ class PythonOp(Op): ...@@ -462,7 +384,6 @@ class PythonOp(Op):
def gen_outputs(self): def gen_outputs(self):
raise NotImplementedError() raise NotImplementedError()
#return [ResultValue() for i in xrange(self.nout)]
def view_map(self): return {} def view_map(self): return {}
...@@ -657,7 +578,7 @@ class PythonOpt(opt.Optimizer): ...@@ -657,7 +578,7 @@ class PythonOpt(opt.Optimizer):
class DummyOp(Op): class DummyOp(Op):
def __init__(self, input): def __init__(self, input):
Op.__init__(self, [input], [Result()]) Op.__init__(self, [input], [ResultBase()])
def thunk(self): def thunk(self):
return lambda:None return lambda:None
......
from copy import copy
import graph
from env import Env, EnvListener
class PrintListener(EnvListener):
def __init__(self, env, active = True):
self.env = env
self.active = active
if active:
print "-- initializing"
def on_import(self, op):
if self.active:
print "-- importing: %s" % graph.as_string(self.env.inputs, op.outputs)
def on_prune(self, op):
if self.active:
print "-- pruning: %s" % graph.as_string(self.env.inputs, op.outputs)
def on_rewire(self, clients, v, new_v):
if self.active:
if v.owner is None:
vg = v.name
else:
vg = graph.as_string(self.env.inputs, v.owner.outputs)
if new_v.owner is None:
new_vg = new_v.name
else:
new_vg = graph.as_string(self.env.inputs, new_v.owner.outputs)
print "-- moving from %s to %s" % (vg, new_vg)
class ChangeListener(EnvListener):
def __init__(self, env):
self.change = False
def on_import(self, op):
self.change = True
def on_prune(self, op):
self.change = True
def on_replace(self, v, new_v):
self.change = True
def __call__(self, value = "get"):
if value == "get":
return self.change
else:
self.change = value
class InstanceFinder(EnvListener, dict):
def __init__(self, env):
self.env = env
def all_bases(self, cls):
rval = set(cls)
for base in cls.__bases__:
rval.add(self.all_bases(base))
return [cls for cls in rval if issubclass(cls, Op)]
def on_import(self, op):
for base in self.all_bases(op.__class__):
self.setdefault(base, set()).add(op)
def on_prune(self, op):
for base in self.all_bases(op.__class__):
self[base].remove(op)
if not self[base]:
del self[base]
def __query__(self, cls):
all = [x for x in self.get(cls, [])]
while all:
next = all.pop()
if next in self.env.ops():
yield next
def query(self, cls):
return self.__query__(cls)
# class GraphOrder(EnvListener, dict):
# def init(self, graph):
# self.graph = graph
# def __adjust__(self, op, minimum):
# if not op or self[op] >= minimum:
# return
# self[op] = minimum
# for output in op.outputs:
# for client, i in output.clients:
# self.__adjust__(client, minimum + 1)
# def on_import(self, op):
# order = 0
# for input in op.inputs:
# if input not in self.graph.inputs:
# order = max(order, self[input.owner] + 1)
# self[op] = order
# def on_prune(self, op):
# del self[op]
# def on_replace(self, v, new_v):
# self.__adjust__(new_v.owner, self.get(v.owner, 0))
class SuperFinder(EnvListener, dict):
def __init__(self, env, props):
self.env = env
self.props = props
def on_import(self, op):
for prop, value in self.props(op).items():
self.setdefault(prop, {}).setdefault(value, set()).add(op)
def on_prune(self, op):
for prop, value in self.props(op).items():
self[prop][value].remove(op)
if len(self[prop][value]) == 0:
del self[prop][value]
if len(self[prop]) == 0:
del self[prop]
def __query__(self, order, template):
all = []
for prop, value in template.items():
all += [x for x in self.get(prop, {}).get(value, set())]
# If not None, the order option requires the order listener to be included in the env under the name 'order'
if order == 'o->i':
all.sort(lambda op1, op2: self.env.order[op1].__cmp__(self.env.order[op2]))
elif order == 'i->o':
all.sort(lambda op1, op2: self.env.order[op2].__cmp__(self.env.order[op1]))
while all:
next = all.pop()
if next in self.env.ops():
yield next
def query(self, **template):
return self.__query__(None, template)
def query_downstream(self, **template):
return self.__query__('i->o', template)
def query_upstream(self, **template):
return self.__query__('o->i', template)
# class DupListener(EnvListener):
# def __init__(self, env):
# self.to_cid = {}
# self.to_obj = {}
# self.env = env
# for i, input in enumerate(env.inputs):
# self.to_cid[input] = i
# self.to_obj[i] = input
# def init(self, env):
# self.env = env
# for i, input in enumerate(env.inputs):
# self.to_cid[input] = i
# self.to_obj[i] = input
# def on_import(self, op):
# cid = (op.__class__, tuple([self.to_cid[input] for input in op.inputs]))
# self.to_cid[op] = cid
# self.to_obj.setdefault(cid, op)
# for i, output in enumerate(op.outputs):
# ocid = (i, cid)
# self.to_cid[output] = ocid
# self.to_obj.setdefault(ocid, output)
# def on_prune(self, op):
# # we don't delete anything
# return
# def apply(self, env):
# if env is not self.env:
# raise Exception("The DupListener merge optimization can only apply to the env it is listening to.")
# def fn(op):
# op2 = self.to_obj[self.to_cid[op]]
# if op is not op2:
# return [(o, o2) for o, o2 in zip(op.outputs, op2.outputs)]
# env.walk_from_outputs(fn)
# def __call__(self):
# self.apply(self.env)
class DestroyHandler(EnvListener):
def __init__(self, env):
self.parent = {}
self.children = {}
self.destroyers = {}
self.paths = {}
self.dups = set()
self.cycles = set()
self.env = env
for input in env.inputs:
self.parent[input] = None
self.children[input] = set()
def __path__(self, r):
path = self.paths.get(r, None)
if path:
return path
rval = [r]
r = self.parent[r]
while r:
rval.append(r)
r = self.parent[r]
rval.reverse()
for i, x in enumerate(rval):
self.paths[x] = rval[0:i+1]
return rval
def __views__(self, r):
children = self.children[r]
if not children:
return set([r])
else:
rval = set([r])
for child in children:
rval.update(self.__views__(child))
return rval
def __users__(self, r):
views = self.__views__(r)
rval = set()
for view in views:
for op, i in self.env.clients(view):
rval.update(op.outputs)
return rval
def __pre__(self, op):
rval = set()
if op is None:
return rval
keep_going = False
for input in op.inputs:
foundation = self.__path__(input)[0]
destroyers = self.destroyers.get(foundation, set())
if destroyers:
keep_going = True
if op in destroyers:
users = self.__users__(foundation)
rval.update(users)
if not keep_going:
return set()
rval.update(op.inputs)
rval.difference_update(op.outputs)
return rval
def __detect_cycles_helper__(self, r, seq):
if r in seq:
self.cycles.add(tuple(seq[seq.index(r):]))
return
pre = self.__pre__(r.owner)
for r2 in pre:
self.__detect_cycles_helper__(r2, seq + [r])
def __detect_cycles__(self, start, just_remove=False):
users = self.__users__(start)
users.add(start)
for user in users:
for cycle in copy(self.cycles):
if user in cycle:
self.cycles.remove(cycle)
if just_remove:
return
for user in users:
self.__detect_cycles_helper__(user, [])
def get_maps(self, op):
return getattr(op, 'view_map', lambda x:{})(), \
getattr(op, 'destroy_map', lambda x:{})()
def on_import(self, op):
view_map, destroy_map = self.get_maps(op)
for input in op.inputs:
self.parent.setdefault(input, None)
for i, output in enumerate(op.outputs):
views = view_map.get(output, None)
destroyed = destroy_map.get(output, None)
if destroyed:
self.parent[output] = None
for input in destroyed:
path = self.__path__(input)
self.__add_destroyer__(path + [output])
elif views:
if len(inputs) > 1:
raise Exception("Output is a view of too many inputs.")
self.parent[output] = inputs[0]
for input in views:
self.children[input].add(output)
else:
self.parent[output] = None
self.children[output] = set()
for output in op.outputs:
self.__detect_cycles__(output)
# if destroy_map:
# print "op: ", op
# print "ord: ", [str(x) for x in self.orderings()[op]]
# print
def on_prune(self, op):
view_map, destroy_map = self.get_maps(op)
if destroy_map:
destroyers = []
for i, input in enumerate(op.inputs):
destroyers.append(self.destroyers.get(self.__path__(input)[0], {}))
for destroyer in destroyers:
path = destroyer.get(op, [])
if path:
self.__remove_destroyer__(path)
if view_map:
for i, input in enumerate(op.inputs):
self.children[input].difference_update(op.outputs)
for output in op.outputs:
try:
del self.paths[output]
except:
pass
self.__detect_cycles__(output, True)
for i, output in enumerate(op.outputs):
del self.parent[output]
del self.children[output]
def __add_destroyer__(self, path):
foundation = path[0]
target = path[-1]
op = target.owner
destroyers = self.destroyers.setdefault(foundation, {})
path = destroyers.setdefault(op, path)
if len(destroyers) > 1:
self.dups.add(foundation)
def __remove_destroyer__(self, path):
foundation = path[0]
target = path[-1]
op = target.owner
destroyers = self.destroyers[foundation]
del destroyers[op]
if not destroyers:
del self.destroyers[foundation]
elif len(destroyers) == 1 and foundation in self.dups:
self.dups.remove(foundation)
def on_rewire(self, clients, r_1, r_2):
path_1 = self.__path__(r_1)
path_2 = self.__path__(r_2)
prev = set()
for op, i in clients:
prev.update(op.outputs)
foundation = path_1[0]
destroyers = self.destroyers.get(foundation, {}).items()
for op, path in destroyers:
if r_1 in path:
idx = path.index(r_1)
self.__remove_destroyer__(path)
if not (idx > 0 and path[idx - 1] in prev):
continue
index = path.index(r_1)
new_path = path_2 + path[index+1:]
self.__add_destroyer__(new_path)
for op, i in clients:
view_map, _ = self.get_maps(op)
for output, inputs in view_map.items():
if r_1 in inputs:
assert self.parent[output] == r_1
self.parent[output] = r_2
self.children[r_1].remove(output)
self.children[r_2].add(output)
for view in self.__views__(r_1):
try:
del self.paths[view]
except:
pass
for view in self.__views__(r_2):
try:
del self.paths[view]
except:
pass
self.__detect_cycles__(r_1)
self.__detect_cycles__(r_2)
def validate(self):
if self.dups:
raise InconsistencyError("The following values are destroyed more than once: %s" % self.dups)
elif self.cycles:
raise InconsistencyError("There are cycles: %s" % self.cycles)
else:
return True
def orderings(self):
ords = {}
for foundation, destroyers in self.destroyers.items():
for op in destroyers.keys():
ords.setdefault(op, set()).update([user.owner for user in self.__users__(foundation) if user not in op.outputs])
return ords
...@@ -4,7 +4,7 @@ Contains the Op class, which is the base interface for all operations ...@@ -4,7 +4,7 @@ Contains the Op class, which is the base interface for all operations
compatible with gof's graph manipulation routines. compatible with gof's graph manipulation routines.
""" """
from result import Result, BrokenLink from result import BrokenLink
from utils import ClsInit, all_bases, all_bases_collect from utils import ClsInit, all_bases, all_bases_collect
from copy import copy from copy import copy
......
...@@ -5,12 +5,13 @@ value that is the input or the output of an Op. ...@@ -5,12 +5,13 @@ value that is the input or the output of an Op.
""" """
import unittest
from err import GofError from err import GofError
from utils import AbstractFunctionError from utils import AbstractFunctionError
__all__ = ['is_result', 'Result', 'BrokenLink', 'BrokenLinkError'] __all__ = ['is_result', 'ResultBase', 'BrokenLink', 'BrokenLinkError']
class BrokenLink: class BrokenLink:
...@@ -45,24 +46,27 @@ def is_result(obj): ...@@ -45,24 +46,27 @@ def is_result(obj):
attr_list = 'owner', attr_list = 'owner',
return all([hasattr(obj, attr) for attr in attr_list]) return all([hasattr(obj, attr) for attr in attr_list])
class Result(object): class ResultBase(object):
"""Storage node for data in a graph of Op instances. """Base class for storing Op inputs and outputs
Attributes: Attributes:
owner - represents the Op which computes this Result. Contains either None _role - None or (owner, index) or BrokenLink
or an instance of Op. _data - anything
index - the index of this Result in owner.outputs. constant - Boolean
Methods: Properties:
- role - (rw)
owner - (ro)
index - (ro)
data - (rw)
replaced - (rw) : True iff _role is BrokenLink
computed - (ro) : True iff contents of data are fresh
Notes: Abstract Methods:
data_filter
Result has no __init__ or __new__ routine. It is the Op's
responsibility to set the owner field of its results.
The Result class is abstract. It must be subclassed to support the Notes:
types of data needed for computation.
A Result instance should be immutable: indeed, if some aspect of a A Result instance should be immutable: indeed, if some aspect of a
Result is changed, operations that use it might suddenly become Result is changed, operations that use it might suddenly become
...@@ -71,140 +75,120 @@ class Result(object): ...@@ -71,140 +75,120 @@ class Result(object):
called on the Result which is replaced (this will make its owner a called on the Result which is replaced (this will make its owner a
BrokenLink instance, which behaves like False in conditional BrokenLink instance, which behaves like False in conditional
expressions). expressions).
"""
__slots__ = ['_owner', '_index'] """
class BrokenLink:
def get_owner(self): """The owner of a Result that was replaced by another Result"""
if not hasattr(self, '_owner'): __slots__ = ['old_role']
self._owner = None def __init__(self, role): self.old_role = role
return self._owner def __nonzero__(self): return False
owner = property(get_owner, class BrokenLinkError(Exception):
doc = "The Op of which this Result is an output or None if there is no such Op.") """Exception thrown when an owner is a BrokenLink"""
def set_owner(self, owner, index): class AbstractFunction(Exception):
if self.owner is not None: """Exception thrown when an abstract function is called"""
if self.owner is not owner:
__slots__ = ['_role', '_data', 'constant']
def __init__(self, role=None, data=None, constant=False):
self._role = role
self.constant = constant
if data is None: #None is not filtered
self._data = None
else:
try:
self._data = self.data_filter(data)
except ResultBase.AbstractFunction:
self._data = data
#role is pair: (owner, outputs_position)
def __get_role(self):
return self._role
def __set_role(self, role):
owner, index = role
if self._role is not None:
# this is either an error or a no-op
_owner, _index = self._role
if _owner is not owner:
raise ValueError("Result %s already has an owner." % self) raise ValueError("Result %s already has an owner." % self)
elif self.index != index: if _index != index:
raise ValueError("Result %s was already mapped to a different index." % self) raise ValueError("Result %s was already mapped to a different index." % self)
self._owner = owner return # because _owner is owner and _index == index
self._index = index self._role = role
role = property(__get_role, __set_role)
def invalidate(self):
if self.owner is None: #owner is role[0]
raise Exception("Cannot invalidate a Result instance with no owner.") def __get_owner(self):
elif not isinstance(self.owner, BrokenLink): if self._role is None: return None
self._owner = BrokenLink(self._owner, self._index) if self.replaced: raise ResultBase.BrokenLinkError()
del self._index return self._role[0]
owner = property(__get_owner,
def revalidate(self): doc = "Op of which this Result is an output, or None if role is None")
if isinstance(self.owner, BrokenLink):
owner, index = self._owner.owner, self._owner.index #index is role[1]
self._owner = owner def __get_index(self):
self._index = index if self._role is None: return None
if self.replaced: raise ResultBase.BrokenLinkError()
def perform(self): return self._role[1]
"""Calls self.owner.perform() if self.owner exists. index = property(__get_index,
doc = "position of self in owner's outputs, or None if role is None")
This is a mutually recursive function with gof.op.Op
""" # assigning to self.data will invoke self.data_filter(value) if that
if self.owner: # function is defined
self.owner.perform() def __get_data(self):
return self._data
def __set_data(self, data):
# def extract(self): if self.replaced: raise ResultBase.BrokenLinkError()
# """ if self.constant: raise Exception('cannot set constant ResultBase')
# Returns a representation of this datum for use in Op.impl. try:
# Successive calls to extract should always return the same object. self._data = self.data_filter(data)
# """ except ResultBase.AbstractFunction: #use default behaviour
# raise NotImplementedError self._data = data
data = property(__get_data, __set_data,
# def sync(self): doc = "The storage associated with this result")
# """
# After calling Op.impl, synchronizes the Result instance with the def data_filter(self, data):
# new contents of the storage. This might usually not be necessary. """(abstract) Return an appropriate _data based on data."""
# """ raise ResultBase.AbstractFunction()
# raise NotImplementedError
# def c_libs(self): # replaced
# """ def __get_replaced(self): return isinstance(self._role, ResultBase.BrokenLink)
# Returns a list of libraries that must be included to work with def __set_replaced(self, replace):
# this Result. if replace == self.replaced: return
# """ if replace:
# raise NotImplementedError self._role = ResultBase.BrokenLink(self._role)
else:
# def c_imports(self): self._role = self._role.old_role
# """ replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?")
# Returns a list of strings representing headers to import when
# building a C interface that uses this Result. # computed
# """ #TODO: think about how to handle this more correctly
# raise NotImplementedError computed = property(lambda self: self._data is not None)
# def c_declare(self):
# """ #################
# Returns code which declares and initializes a C variable in # NumpyR Compatibility
# which this Result can be held. #
# """ up_to_date = property(lambda self: True)
# raise NotImplementedError def refresh(self): pass
def set_owner(self, owner, idx):
# def pyo_to_c(self): self.role = (owner, idx)
# raise NotImplementedError def set_value(self, value):
self.data = value #may raise exception
# def c_to_pyo(self):
# raise NotImplementedError class _test_ResultBase(unittest.TestCase):
def setUp(self):
build_eval_mode()
numpy.random.seed(44)
def tearDown(self):
############################ pop_mode()
# Utilities def test_0(self):
############################ r = ResultBase()
# class SelfContainedResult(Result):
# """ if __name__ == '__main__':
# This represents a Result which acts as its own data container. It unittest.main()
# is recommended to subclass this if you wish to be able to use the
# Result in normal computations as well as working with a graph
# representation.
# """
# # def extract(self):
# # """Returns self."""
# # return self
# # def sync(self):
# # """Does nothing."""
# # pass
# class HolderResult(Result):
# """
# HolderResult adds a 'data' slot which is meant to contain the
# object used by the Op implementation. It is recommended to subclass
# this if you want to be able to use the exact same object at
# different points in a computation.
# """
# __slots__ = ['data']
# # def extract(self):
# # """Returns self.data."""
# # return self.data
# # def sync(self):
# # """
# # Does nothing. Override if you have additional fields or
# # functionality in your subclass which need to be computed from
# # the data.
# # """
# # pass
...@@ -2,6 +2,9 @@ import gof ...@@ -2,6 +2,9 @@ import gof
from gof.lib import compute_from, is_result from gof.lib import compute_from, is_result
import core import core
class Undefined:
"""A special class representing a gradient of 0"""
class Grad(object): class Grad(object):
"""A dictionary-like class, into which derivative expressions may be added. """A dictionary-like class, into which derivative expressions may be added.
...@@ -17,7 +20,6 @@ class Grad(object): ...@@ -17,7 +20,6 @@ class Grad(object):
__call__() __call__()
__getitem__() __getitem__()
""" """
class Undefined: pass
def __init__(self, dct={}): def __init__(self, dct={}):
self.map = {} self.map = {}
...@@ -36,7 +38,7 @@ class Grad(object): ...@@ -36,7 +38,7 @@ class Grad(object):
try: try:
return self.map[key] return self.map[key]
except KeyError: except KeyError:
return Grad.Undefined return Undefined
def __setitem__(self, item, val): def __setitem__(self, item, val):
"""Map item to its id and store internally.""" """Map item to its id and store internally."""
...@@ -59,7 +61,7 @@ class Grad(object): ...@@ -59,7 +61,7 @@ class Grad(object):
r may be uncomputed or NumpyR r may be uncomputed or NumpyR
""" """
if dr is Grad.Undefined: if dr is Undefined:
# nothing to do # nothing to do
return return
...@@ -124,7 +126,7 @@ class Grad(object): ...@@ -124,7 +126,7 @@ class Grad(object):
if not self.did_bprop: if not self.did_bprop:
raise Exception('Grad.__call__ only makes sense after a bprop') raise Exception('Grad.__call__ only makes sense after a bprop')
rval = self[item] rval = self[item]
if rval is not Grad.Undefined \ if rval is not Undefined \
and core.current_mode() == 'build_eval': and core.current_mode() == 'build_eval':
compute_from([rval], self._compute_history) compute_from([rval], self._compute_history)
return rval return rval
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论