提交 d18eacd8 authored 作者: James Bergstra's avatar James Bergstra

merged

......@@ -156,7 +156,7 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper):
self.debug_all_apps.remove(app)
#UPDATE self.clients
for i, input in enumerate(app.inputs):
for i, input in enumerate(set(app.inputs)):
del self.clients[input][app]
if getattr(app.op, 'destroy_map', {}):
......
......@@ -227,6 +227,7 @@ class Env(utils.object2):
For each feature that has a 'on_change_input' method, calls:
feature.on_change_input(env, node, i, old_r, new_r)
"""
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if node == 'output':
r = self.outputs[i]
if not r.type == new_r.type:
......
......@@ -112,6 +112,8 @@ class Linker(object):
class Container(object):
def __init__(self, r, storage, readonly = False, strict = False, name = None):
if not isinstance(storage, list) or not len(storage) >= 1:
raise TypeError("storage must be a list of length at least one")
#self.r = r
if isinstance(r, Type):
self.type = r
......@@ -127,6 +129,9 @@ class Container(object):
if self.readonly:
raise Exception("Cannot set readonly storage: %s" % self.name)
try:
if value is None:
self.storage[0] = None
return
if self.strict:
self.storage[0] = self.type.filter(value, strict = True)
else:
......
from functools import partial
import graph
import sys
class AlreadyThere(Exception):
......@@ -97,7 +98,12 @@ class ReplaceValidate(History, Validator):
def replace_all_validate(self, env, replacements):
chk = env.checkpoint()
for r, new_r in replacements:
env.replace(r, new_r)
try:
env.replace(r, new_r)
except Exception, e:
print >>sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e
env.revert(chk) # this might fail; env.replace should never raise an exception (it kinda needs better internal error handling)
raise
try:
env.validate()
except:
......
import theano
from theano import gof
from collections import defaultdict
from itertools import chain
from theano.gof.utils import scratchpad
from copy import copy
def join(*args):
return ".".join(arg for arg in args if arg)
def split(sym, n=-1):
return sym.split('.', n)
class KlassComponent(object):
_name = ""
def bind(self, klass, name):
if self.bound():
raise Exception("%s is already bound to %s as %s" % (self, self.klass, self.name))
self.klass = klass
self.name = join(klass.name, name)
def bound(self):
return hasattr(self, 'klass')
def __repr__(self):
return str(self)
def __str__(self):
return self.__class__.__name__
def __get_name__(self):
return self._name
def __set_name__(self, name):
self._name = name
name = property(lambda self: self.__get_name__(),
lambda self, value: self.__set_name__(value))
class KlassResult(KlassComponent):
def __init__(self, r):
self.r = r
def __set_name__(self, name):
super(KlassResult, self).__set_name__(name)
self.r.name = name
def __str__(self):
return "%s(%s)" % (self.__class__.__name__, self.r)
class KlassMember(KlassResult):
def __init__(self, r):
if r.owner:
raise ValueError("A KlassMember must not be the result of a previous computation.")
super(KlassMember, self).__init__(r)
class KlassMethod(KlassComponent):
def __init__(self, inputs, outputs, updates = {}, **kwupdates):
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
self.inputs = inputs
self.outputs = outputs
self.updates = dict(updates, **kwupdates)
def bind(self, klass, name):
super(KlassMethod, self).bind(klass, name)
self.inputs = [klass.resolve(i, KlassResult).r for i in self.inputs]
self.outputs = [klass.resolve(o, KlassResult).r for o in self.outputs] \
if isinstance(self.outputs, (list, tuple)) \
else klass.resolve(self.outputs, KlassResult).r
updates = self.updates
self.updates = {}
self.extend(updates)
def extend(self, updates = {}, **kwupdates):
if not hasattr(self, 'klass'):
self.updates.update(updates)
self.updates.update(kwupdates)
else:
for k, v in chain(updates.iteritems(), kwupdates.iteritems()):
k, v = self.klass.resolve(k, KlassMember), self.klass.resolve(v, KlassResult)
self.updates[k.r] = v.r
def __str__(self):
return "KlassMethod(%s -> %s%s%s)" % \
(self.inputs,
self.outputs,
"; " if self.updates else "",
", ".join("%s <= %s" % (old, new) for old, new in self.updates.iteritems()))
class Klass(KlassComponent):
def __new__(cls, *args, **kwargs):
self = object.__new__(cls)
self.__dict__['__components__'] = {}
self.__dict__['_name'] = ""
self.__dict__['__components_list__'] = []
self.__dict__['__component_names__'] = []
return self
###
### Access to the klass members and methods
###
def resolve(self, symbol, filter = None):
if isinstance(symbol, gof.Result):
if not filter or filter is KlassResult:
return KlassResult(symbol)
for component in self.__components_list__:
if isinstance(component, Klass):
try:
return component.resolve(symbol, filter)
except:
continue
if isinstance(component, KlassResult) and component.r is symbol:
if filter and not isinstance(component, filter):
raise TypeError('Did not find a %s instance for symbol %s in klass %s (found %s)'
% (filter.__name__, symbol, self, type(component).__name__))
return KlassResult(symbol)
raise ValueError('%s is not part of this klass or any of its inner klasses. Please add it to the structure before you use it.' % symbol)
elif isinstance(symbol, str):
sp = split(symbol, 1)
if len(sp) == 1:
try:
result = self.__components__[symbol]
except KeyError:
raise AttributeError('Could not resolve symbol %s in klass %s' % (symbol, self))
if filter and not isinstance(result, filter):
raise TypeError('Did not find a %s instance for symbol %s in klass %s (found %s)'
% (filter.__name__, symbol, self, type(result).__name__))
return result
else:
sp0, spr = sp
klass = self.__components__[sp0]
if not isinstance(klass, Klass):
raise TypeError('Could not get subattribute %s of %s' % (spr, klass))
return klass.resolve(spr, filter)
else:
raise TypeError('resolve takes a string or Result argument, not %s' % symbol)
def members(self, as_results = False):
filtered = [x for x in self.__components_list__ if isinstance(x, KlassMember)]
if as_results:
return [x.r for x in filtered]
else:
return filtered
def methods(self):
filtered = [x for x in self.__components_list__ if isinstance(x, KlassMethod)]
return filtered
def member_klasses(self):
filtered = [x for x in self.__components_list__ if isinstance(x, Klass)]
return filtered
###
### Make
###
def __make__(self, mode, stor = None):
if stor is None:
stor = scratchpad()
self.initialize_storage(stor)
members = []
methods = []
rval = KlassInstance()
for component, name in zip(self.__components_list__, self.__component_names__):
if isinstance(component, KlassMember):
container = getattr(stor, name)
members.append((component, container))
rval.__finder__[name] = container
elif isinstance(component, Klass):
inner, inner_members = component.__make__(mode, getattr(stor, name))
rval.__dict__[name] = inner
members += inner_members
elif isinstance(component, KlassMethod):
methods.append(component)
for method in methods:
inputs = list(method.inputs)
for (component, container) in members:
r = component.r
update = method.updates.get(component.r, component.r)
inputs.append(theano.In(result = r,
update = update,
value = container,
name = r.name and split(r.name)[-1],
mutable = True,
strict = True))
fn = theano.function(inputs,
method.outputs,
mode = mode)
rval.__dict__[split(method.name)[-1]] = fn
return rval, members
def make(self, mode = 'FAST_RUN', **init):
rval = self.__make__(mode)[0]
self.initialize(rval, **init)
return rval
###
### Instance setup and initialization
###
def initialize_storage(self, stor):
if not hasattr(stor, '__mapping__'):
stor.__mapping__ = {}
mapping = stor.__mapping__
for name, component in self.__components__.iteritems():
if isinstance(component, Klass):
sp = scratchpad()
setattr(stor, name, sp)
sp.__mapping__ = mapping
component.initialize_storage(sp)
elif isinstance(component, KlassMember):
r = component.r
if r in mapping:
container = mapping[r]
else:
container = gof.Container(r.type,
name = name,
storage = [None])
mapping[r] = container
setattr(stor, name, container)
def initialize(self, inst, **init):
for k, v in init.iteritems():
inst[k] = v
###
### Magic methods and witchcraft
###
def __setattr__(self, attr, value):
if attr == 'name':
self.__set_name__(value)
return
elif attr in ['_name', 'klass']:
self.__dict__[attr] = value
return
if isinstance(value, gof.Result):
value = KlassResult(value)
if isinstance(value, KlassComponent):
value.bind(self, attr)
else:
self.__dict__[attr] = value
return
self.__components__[attr] = value
self.__components_list__.append(value)
self.__component_names__.append(attr)
if isinstance(value, KlassResult):
value = value.r
self.__dict__[attr] = value
def __set_name__(self, name):
orig = self.name
super(Klass, self).__set_name__(name)
for component in self.__components__.itervalues():
if orig:
component.name = join(name, component.name[len(orig):])
else:
component.name = join(name, component.name)
def __str__(self):
n = len(self.name)
if n: n += 1
member_names = ", ".join(x.name[n:] for x in self.members())
if member_names: member_names = "members: " + member_names
method_names = ", ".join(x.name[n:] for x in self.methods())
if method_names: method_names = "methods: " + method_names
klass_names = ", ".join(x.name[n:] for x in self.member_klasses())
if klass_names: klass_names = "inner: " + klass_names
return "Klass(%s)" % "; ".join(x for x in [self.name, member_names, method_names, klass_names] if x)
class KlassInstance(object):
def __init__(self):
self.__dict__['__finder__'] = {}
def __getitem__(self, attr):
if isinstance(attr, str):
attr = split(attr, 1)
if len(attr) == 1:
return self.__finder__[attr[0]].value
else:
return getattr(self, attr[0])[attr[1]]
else:
raise TypeError('Can only get an item via string format: %s' % attr)
def __setitem__(self, attr, value):
if isinstance(attr, str):
attr = split(attr, 1)
if len(attr) == 1:
self.__finder__[attr[0]].value = value
else:
getattr(self, attr[0])[attr[1]] = value
else:
raise TypeError('Can only set an item via string format: %s' % attr)
def __getattr__(self, attr):
return self[attr]
def __setattr__(self, attr, value):
self[attr] = value
......@@ -1212,7 +1212,7 @@ class Subtensor(Op):
def __init__(self, idx_list):
def convert(entry, slice_ok=True):
scal_types =[scal.int64, scal.int32, scal.int16, scal.int8]
scal_types = [scal.int64, scal.int32, scal.int16, scal.int8]
tensor_types = [bscalar, iscalar, lscalar]
if isinstance(entry, gof.Result) and entry.type in scal_types:
return entry.type
......@@ -1358,6 +1358,10 @@ class SetSubtensor(Subtensor):
x.__setitem__(cdata, y)
out[0] = x
def split(x, splits_size, n_splits, axis=0):
the_split = Split(n_splits)
return the_split(x, axis, splits_size)
class Split(Op):
"""Partition a `TensorResult` along some axis.
......@@ -1366,9 +1370,9 @@ class Split(Op):
x = vector()
splits = lvector()
# you have to declare right away how many split_points there will be.
ra, rb, rc = split(x, axis=0, points=splits, n_splits=3)
ra, rb, rc = split(x, splits, n_splits = 3, axis = 0)
f = compile([x, splits], [ra, rb, rc])
f = function([x, splits], [ra, rb, rc])
a, b, c = f([0,1,2,3,4,5,6], [3, 2, 1])
......@@ -2055,7 +2059,7 @@ def grad(cost, wrt, g_cost=None):
Tensor(dtype = p.type.dtype, broadcastable = []),
numpy.asarray(0, dtype=p.type.dtype))
if isinstance(wrt, list):
if isinstance(wrt, (list, tuple)):
return [gmap.get(p, zero(p)) for p in wrt]
else:
return gmap.get(wrt, zero(wrt))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论