提交 ea4d7e69 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed a couple bugs

上级 5961805c
...@@ -32,11 +32,9 @@ class AllocationError(Exception): ...@@ -32,11 +32,9 @@ class AllocationError(Exception):
class Component(object): class Component(object):
def __new__(cls, *args, **kwargs): def __init__(self):
self = object.__new__(cls) self.__dict__['_name'] = ''
self._name = '' self.__dict__['parent'] = None
self.parent = None
return self
def bind(self, parent, name): def bind(self, parent, name):
if self.bound(): if self.bound():
...@@ -85,21 +83,15 @@ class Component(object): ...@@ -85,21 +83,15 @@ class Component(object):
class External(Component): class _RComponent(Component):
def __init__(self, r): def __init__(self, r):
super(_RComponent, self).__init__()
self.r = r self.r = r
self.owns_name = r.name is None self.owns_name = r.name is None
def allocate(self, memo):
# nothing to allocate
return None
def build(self, mode, memo):
return None
def __set_name__(self, name): def __set_name__(self, name):
super(External, self).__set_name__(name) super(_RComponent, self).__set_name__(name)
if self.owns_name: if self.owns_name:
self.r.name = name self.r.name = name
...@@ -107,15 +99,29 @@ class External(Component): ...@@ -107,15 +99,29 @@ class External(Component):
return "%s(%s)" % (self.__class__.__name__, self.r) return "%s(%s)" % (self.__class__.__name__, self.r)
def pretty(self): def pretty(self):
rval = 'External :: %s' % self.r.type rval = '%s :: %s' % (self.__class__.__name__, self.r.type)
return rval return rval
class Member(Component): class External(_RComponent):
def __init__(self, r): def allocate(self, memo):
self.r = r # nothing to allocate
return None
def build(self, mode, memo):
return None
def pretty(self):
rval = super(External, self).pretty()
if self.r.owner:
rval += '\n= %s' % (pprint.pp2.process(self.r, dict(target = self.r)))
return rval
class Member(_RComponent):
def allocate(self, memo): def allocate(self, memo):
r = self.r r = self.r
...@@ -128,31 +134,13 @@ class Member(Component): ...@@ -128,31 +134,13 @@ class Member(Component):
def build(self, mode, memo): def build(self, mode, memo):
return memo[self.r] return memo[self.r]
def __set_name__(self, name):
super(Member, self).__set_name__(name)
self.r.name = name
def __str__(self):
return "%s(%s)" % (self.__class__.__name__, self.r)
def pretty(self):
rval = 'Member :: %s' % self.r.type
return rval
# def pretty(self, header = False, **kwargs):
# cr = '\n ' if header else '\n'
# rval = ''
# if header:
# rval += 'Member:%s' % cr
# rval += '%s :: %s' % ((self.r.name if self.r.name else '<unnamed>'), self.r.type)
# return rval
from theano.sandbox import pprint from theano.sandbox import pprint
class Method(Component): class Method(Component):
def __init__(self, inputs, outputs, updates = {}, **kwupdates): def __init__(self, inputs, outputs, updates = {}, **kwupdates):
super(Method, self).__init__()
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
self.updates = dict(updates, **kwupdates) self.updates = dict(updates, **kwupdates)
...@@ -264,8 +252,11 @@ class CompositeInstance(object): ...@@ -264,8 +252,11 @@ class CompositeInstance(object):
x = self.__items__[item] x = self.__items__[item]
if isinstance(x, gof.Container): if isinstance(x, gof.Container):
x.value = value x.value = value
else: elif hasattr(x, 'initialize'):
x.initialize(value) x.initialize(value)
else:
##self.__items__[item] = value
raise TypeError('Cannot set item %s' % item)
def initialize(self, init): def initialize(self, init):
for i, initv in enumerate(init): for i, initv in enumerate(init):
...@@ -328,11 +319,19 @@ class Composite(Component): ...@@ -328,11 +319,19 @@ class Composite(Component):
def __setitem__(self, item, value): def __setitem__(self, item, value):
self.set(item, value) self.set(item, value)
def __iter__(self):
return (c.r if isinstance(c, (External, Member)) else c for c in self.components())
class ComponentListInstance(CompositeInstance):
def __str__(self):
return '[%s]' % ', '.join(map(str, self.__items__))
class ComponentList(Composite): class ComponentList(Composite):
def __init__(self, *_components): def __init__(self, *_components):
super(ComponentList, self).__init__()
if len(_components) == 1 and isinstance(_components[0], (list, tuple)): if len(_components) == 1 and isinstance(_components[0], (list, tuple)):
_components = _components[0] _components = _components[0]
self._components = [] self._components = []
...@@ -360,7 +359,7 @@ class ComponentList(Composite): ...@@ -360,7 +359,7 @@ class ComponentList(Composite):
def build(self, mode, memo): def build(self, mode, memo):
builds = [c.build(mode, memo) for c in self._components] builds = [c.build(mode, memo) for c in self._components]
return CompositeInstance(self, builds) return ComponentListInstance(self, builds)
def get(self, item): def get(self, item):
return self._components[item] return self._components[item]
...@@ -401,6 +400,7 @@ class ComponentList(Composite): ...@@ -401,6 +400,7 @@ class ComponentList(Composite):
class ModuleInstance(CompositeInstance): class ModuleInstance(CompositeInstance):
__hide__ = []
def __setitem__(self, item, value): def __setitem__(self, item, value):
if item not in self.__items__: if item not in self.__items__:
...@@ -412,10 +412,22 @@ class ModuleInstance(CompositeInstance): ...@@ -412,10 +412,22 @@ class ModuleInstance(CompositeInstance):
for name, value in chain(init.iteritems(), kwinit.iteritems()): for name, value in chain(init.iteritems(), kwinit.iteritems()):
self[name] = value self[name] = value
def __str__(self):
strings = []
for k, v in sorted(self.__items__.iteritems()):
if isinstance(v, gof.Container):
v = v.value
if not k.startswith('_') and not callable(v) and not k in self.__hide__:
pre = '%s: ' % k
strings.append('%s%s' % (pre, str(v).replace('\n', '\n' + ' '*len(pre))))
return '{%s}' % '\n'.join(strings).replace('\n', '\n ')
class Module(Composite): class Module(Composite):
__instance_type__ = ModuleInstance __instance_type__ = ModuleInstance
def __init__(self, components = {}, **kwcomponents): def __init__(self, components = {}, **kwcomponents):
super(Module, self).__init__()
components = dict(components, **kwcomponents) components = dict(components, **kwcomponents)
self._components = components self._components = components
...@@ -489,20 +501,41 @@ def wrap(x): ...@@ -489,20 +501,41 @@ def wrap(x):
register_wrapper(lambda x: isinstance(x, gof.Result), register_wrapper(lambda x: isinstance(x, gof.Result),
lambda x: External(x)) lambda x: External(x))
register_wrapper(lambda x: isinstance(x, list) and all(isinstance(r, Component) for r in x), register_wrapper(lambda x: isinstance(x, (list, tuple)) and all(isinstance(r, Component) for r in x),
lambda x: ComponentList(*x)) lambda x: ComponentList(*x))
register_wrapper(lambda x: isinstance(x, list) \ register_wrapper(lambda x: isinstance(x, (list, tuple)) \
and all(isinstance(r, gof.Result) and not r.owner for r in x), and all(isinstance(r, gof.Result) and not r.owner for r in x),
lambda x: ComponentList(*map(Member, x))) lambda x: ComponentList(*map(Member, x)))
class Curry:
def __init__(self, obj, name, arg):
self.obj = obj
self.name = name
self.meth = getattr(self.obj, self.name)
self.arg = arg
def __call__(self, *args, **kwargs):
self.meth(self.arg, *args, **kwargs)
def __getstate__(self):
return [self.obj, self.name, self.arg]
def __setstate__(self, state):
self.obj, self.name, self.arg = state
self.meth = getattr(self.obj, self.name)
class FancyModuleInstance(ModuleInstance): class FancyModuleInstance(ModuleInstance):
def __getattr__(self, attr): def __getattr__(self, attr):
try:
return self[attr] return self[attr]
except KeyError:
raise AttributeError('%s has no %s attribute.' % (self.__class__, attr))
def __setattr__(self, attr, value): def __setattr__(self, attr, value):
if attr in dir(self) or attr in dir(self.__class__):
# man this sucks
self.__dict__[attr] = value
try: try:
self[attr] = value self[attr] = value
except: except:
...@@ -515,7 +548,10 @@ class FancyModule(Module): ...@@ -515,7 +548,10 @@ class FancyModule(Module):
return wrap(x) return wrap(x)
def __getattr__(self, attr): def __getattr__(self, attr):
try:
rval = self[attr] rval = self[attr]
except KeyError:
raise AttributeError('%s has no %s attribute.' % (self.__class__, attr))
if isinstance(rval, (External, Member)): if isinstance(rval, (External, Member)):
return rval.r return rval.r
return rval return rval
...@@ -540,7 +576,7 @@ class FancyModule(Module): ...@@ -540,7 +576,7 @@ class FancyModule(Module):
inst = super(FancyModule, self).build(mode, memo) inst = super(FancyModule, self).build(mode, memo)
for method in dir(self): for method in dir(self):
if method.startswith('_instance_'): if method.startswith('_instance_'):
setattr(inst, method[10:], partial(getattr(self, method), inst)) setattr(inst, method[10:], Curry(self, method, inst))
return inst return inst
def _instance_initialize(self, inst, init = {}, **kwinit): def _instance_initialize(self, inst, init = {}, **kwinit):
...@@ -552,6 +588,7 @@ class FancyModule(Module): ...@@ -552,6 +588,7 @@ class FancyModule(Module):
class KitComponent(Component): class KitComponent(Component):
def __init__(self, kit): def __init__(self, kit):
super(KitComponent, self).__init__()
self.kit = kit self.kit = kit
def allocate(self, memo): def allocate(self, memo):
...@@ -572,7 +609,7 @@ class KitComponent(Component): ...@@ -572,7 +609,7 @@ class KitComponent(Component):
return memo[self.kit] return memo[self.kit]
from theano import tensor as T from .. import tensor as T
class RModule(FancyModule): class RModule(FancyModule):
def __init__(self, components = {}, **kwcomponents): def __init__(self, components = {}, **kwcomponents):
...@@ -603,81 +640,3 @@ class RModule(FancyModule): ...@@ -603,81 +640,3 @@ class RModule(FancyModule):
if __name__ == '__main__':
from theano import tensor as T
x, y = T.scalars('xy')
s = T.scalar()
s1, s2, s3 = T.scalars('s1', 's2', 's3')
#rterm = T.random.random_integers(T.shape(s), 100, 1000)
# print T.random.sinputs
# f = compile.function([x,
# ((s, s + x + rterm), 10),
# (T.random, 10)],
# s + x)
# print f[s]
# print f(10)
# print f[s]
mod = RModule()
mod.s = Member(s)
#mod.list = ComponentList(Member(s1), Member(s2))
#mod.list = [Member(s1), Member(s2)]
mod.list = [s1, s2]
mod.inc = Method(x, s + x,
s = mod.s + x + mod.random.random_integers((), 100, 1000))
mod.dec = Method(x, s - x,
s = s - x)
mod.sadd = Method([], s1 + mod.list[1])
m = mod.random.normal([], 1., 1.)
mod.test1 = Method([], m)
mod.test2 = Method([], m)
mod.whatever = 123
mod2 = RModule()
mod2.submodule = mod
#print mod._components
#print mod
#print mod.inc.pretty()
print mod2.pretty()
inst = mod.make(s = 2, list = [900, 9000])
print '---'
print inst.test1()
print '---'
inst.seed(10)
print inst.test1()
print inst.test1()
print inst.test2()
inst.seed(10)
print inst.test1()
print inst.test2()
print inst.s
inst.seed(10)
inst.inc(3)
print inst.s
inst.dec(4)
print inst.s
print inst.list[0]
print inst.list[1]
inst.list = [1, 2]
print inst.sadd()
inst.initialize(list = [10, -17])
print inst.sadd()
...@@ -213,6 +213,8 @@ class PPrinter: ...@@ -213,6 +213,8 @@ class PPrinter:
def process(self, r, pstate = None): def process(self, r, pstate = None):
if pstate is None: if pstate is None:
pstate = PrinterState(pprinter = self) pstate = PrinterState(pprinter = self)
elif isinstance(pstate, dict):
pstate = PrinterState(pprinter = self, **pstate)
for condition, printer in self.printers: for condition, printer in self.printers:
if condition(pstate, r): if condition(pstate, r):
return printer.process(r, pstate) return printer.process(r, pstate)
...@@ -314,3 +316,7 @@ def pprinter(): ...@@ -314,3 +316,7 @@ def pprinter():
pp = pprinter() pp = pprinter()
pp2 = pprinter()
pp2.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is not r and r.name is not None,
LeafPrinter())
...@@ -1994,7 +1994,7 @@ def grad(cost, wrt, g_cost=None): ...@@ -1994,7 +1994,7 @@ def grad(cost, wrt, g_cost=None):
Tensor(dtype = p.type.dtype, broadcastable = []), Tensor(dtype = p.type.dtype, broadcastable = []),
numpy.asarray(0, dtype=p.type.dtype)) numpy.asarray(0, dtype=p.type.dtype))
if isinstance(wrt, (list, tuple)): if hasattr(wrt, '__iter__'): # isinstance(wrt, (list, tuple)):
return [gmap.get(p, zero(p)) for p in wrt] return [gmap.get(p, zero(p)) for p in wrt]
else: else:
return gmap.get(wrt, zero(wrt)) return gmap.get(wrt, zero(wrt))
......
...@@ -7,6 +7,9 @@ import functools ...@@ -7,6 +7,9 @@ import functools
from ..compile import SymbolicInputKit, SymbolicInput from ..compile import SymbolicInputKit, SymbolicInput
from copy import copy from copy import copy
RS = numpy.random.RandomState
class RandomFunction(gof.Op): class RandomFunction(gof.Op):
def __init__(self, fn, outtype, *args, **kwargs): def __init__(self, fn, outtype, *args, **kwargs):
...@@ -16,12 +19,7 @@ class RandomFunction(gof.Op): ...@@ -16,12 +19,7 @@ class RandomFunction(gof.Op):
args: a list of default arguments for the function args: a list of default arguments for the function
kwargs: if the 'inplace' key is there, its value will be used to determine if the op operates inplace or not kwargs: if the 'inplace' key is there, its value will be used to determine if the op operates inplace or not
""" """
self.fn = fn self.__setstate__([fn, outtype, args, kwargs])
self.outtype = outtype
self.args = tuple(tensor.as_tensor(arg) for arg in args)
self.inplace = kwargs.pop('inplace', False)
if self.inplace:
self.destroy_map = {0: [0]}
def make_node(self, r, shape, *args): def make_node(self, r, shape, *args):
""" """
...@@ -54,7 +52,14 @@ class RandomFunction(gof.Op): ...@@ -54,7 +52,14 @@ class RandomFunction(gof.Op):
if not self.inplace: if not self.inplace:
r = copy(r) r = copy(r)
rout[0] = r rout[0] = r
out[0] = self.fn(r, *(args + [shape])) rval = self.fn(r, *(args + [shape]))
if not isinstance(rval, numpy.ndarray):
out[0] = numpy.asarray(rval)
else:
out[0] = rval
def grad(self, inputs, outputs):
return [None] * len(inputs)
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) \ return type(self) == type(other) \
...@@ -66,6 +71,22 @@ class RandomFunction(gof.Op): ...@@ -66,6 +71,22 @@ class RandomFunction(gof.Op):
def __hash__(self): def __hash__(self):
return hash(self.fn) ^ hash(self.outtype) ^ hash(self.args) ^ hash(self.inplace) return hash(self.fn) ^ hash(self.outtype) ^ hash(self.args) ^ hash(self.inplace)
def __getstate__(self):
print self.state
return self.state
def __setstate__(self, state):
self.state = state
fn, outtype, args, kwargs = state
self.fn = getattr(RS, fn) if isinstance(fn, str) else fn
self.outtype = outtype
self.args = tuple(tensor.as_tensor(arg) for arg in args)
self.inplace = kwargs.pop('inplace', False)
if self.inplace:
self.destroy_map = {0: [0]}
__oplist_constructor_list = [] __oplist_constructor_list = []
"""List of functions to be listed as op constructors in the oplist (`gen_oplist`, doc/oplist.txt).""" """List of functions to be listed as op constructors in the oplist (`gen_oplist`, doc/oplist.txt)."""
...@@ -112,11 +133,9 @@ def random_function(fn, dtype, *rfargs, **rfkwargs): ...@@ -112,11 +133,9 @@ def random_function(fn, dtype, *rfargs, **rfkwargs):
return f return f
RS = numpy.random.RandomState
# we need to provide defaults for all the functions in order to infer the argument types... # we need to provide defaults for all the functions in order to infer the argument types...
uniform = random_function(RS.uniform, 'float64', 0.0, 1.0) uniform = random_function('uniform', 'float64', 0.0, 1.0)
uniform.__doc__ = """ uniform.__doc__ = """
Usage: uniform(random_state, size, low=0.0, high=1.0) Usage: uniform(random_state, size, low=0.0, high=1.0)
Sample from a uniform distribution between low and high. Sample from a uniform distribution between low and high.
...@@ -126,7 +145,7 @@ dimensions, the first argument may be a plain integer ...@@ -126,7 +145,7 @@ dimensions, the first argument may be a plain integer
to supplement the missing information. to supplement the missing information.
""" """
binomial = random_function(RS.binomial, 'int64', 1, 0.5) binomial = random_function('binomial', 'int64', 1, 0.5)
binomial.__doc__ = """ binomial.__doc__ = """
Usage: binomial(random_state, size, n=1, prob=0.5) Usage: binomial(random_state, size, n=1, prob=0.5)
Sample n times with probability of success prob for each trial, Sample n times with probability of success prob for each trial,
...@@ -137,7 +156,7 @@ dimensions, the first argument may be a plain integer ...@@ -137,7 +156,7 @@ dimensions, the first argument may be a plain integer
to supplement the missing information. to supplement the missing information.
""" """
normal = random_function(RS.normal, 'float64', 0.0, 1.0) normal = random_function('normal', 'float64', 0.0, 1.0)
normal.__doc__ = """ normal.__doc__ = """
Usage: normal(random_state, size, avg=0.0, std=1.0) Usage: normal(random_state, size, avg=0.0, std=1.0)
Sample from a normal distribution centered on avg with Sample from a normal distribution centered on avg with
...@@ -148,7 +167,7 @@ dimensions, the first argument may be a plain integer ...@@ -148,7 +167,7 @@ dimensions, the first argument may be a plain integer
to supplement the missing information. to supplement the missing information.
""" """
random_integers = random_function(RS.random_integers, 'int64', 0, 1) random_integers = random_function('random_integers', 'int64', 0, 1)
random_integers.__doc__ = """ random_integers.__doc__ = """
Usage: random_integers(random_state, size, low=0, high=1) Usage: random_integers(random_state, size, low=0, high=1)
Sample a random integer between low and high, both inclusive. Sample a random integer between low and high, both inclusive.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论