提交 c3106ad7 authored 作者: Frederic Bastien's avatar Frederic Bastien

imported patch module2

上级 6d8b732e
from .. import gof from theano import gof
from ..printing import pprint from theano.printing import pprint
from collections import defaultdict from collections import defaultdict
from itertools import chain from itertools import chain
from functools import partial from functools import partial
...@@ -791,31 +791,29 @@ def wrap(x): ...@@ -791,31 +791,29 @@ def wrap(x):
return wrapper(x) return wrapper(x)
return x return x
def dict_wrap(d):
for k,v in d.iteritems():
d[k]=wrap(v)
return d
# Result -> Member
register_wrapper(lambda x: isinstance(x, gof.Result) and not x.owner,
lambda x: Member(x))
# Result -> External # Result -> External
register_wrapper(lambda x: isinstance(x, gof.Result), register_wrapper(lambda x: isinstance(x, gof.Result) and x.owner,
lambda x: External(x)) lambda x: External(x))
# [Component1, Component2, ...] -> ComponentList(Component1, Component2, ...) # [[Result1], {Result2}, Result3...] -> ComponentList(Member(Result1), Member(Result2), ...)
register_wrapper(lambda x: isinstance(x, (list, tuple)) and all(isinstance(r, Component) for r in x),
lambda x: ComponentList(*x))
# [Result1, Result2, ...] -> ComponentList(Member(Result1), Member(Result2), ...)
register_wrapper(lambda x: isinstance(x, (list, tuple)) \ register_wrapper(lambda x: isinstance(x, (list, tuple)) \
and all(isinstance(r, gof.Result) and not r.owner for r in x), and all(isinstance(r, (gof.Result,Component,list,
lambda x: ComponentList(*map(Member, x))) tuple, dict)) for r in x),
#{ "name1":Result1,...} -> ComponentDict(Member(Result1),...) lambda x: ComponentList(*map(wrap, x)))
def dict_member(d):
nd={} #{ "name1":{Component,Result,list,tuple,dict},...} -> ComponentDict({Component,Result,list,tuple,dict},...)
for k,v in d.iteritems():
nd[k]=Member(v)
return nd
register_wrapper(lambda x: isinstance(x, dict) \
and all(isinstance(r,gof.Result) \
and not r.owner for r in x.itervalues()),
lambda x: ComponentDict(dict_member(x)))
register_wrapper(lambda x: isinstance(x, dict) \ register_wrapper(lambda x: isinstance(x, dict) \
and all(isinstance(r,Component) for r in x.itervalues()), and all(isinstance(r,(Component,gof.Result,list,tuple,dict)) for r in x.itervalues()),
lambda x: ComponentDict(x)) lambda x: ComponentDict(dict_wrap(x)))
class Curry: class Curry:
def __init__(self, obj, name, arg): def __init__(self, obj, name, arg):
...@@ -869,13 +867,9 @@ class Module(ComponentDict): ...@@ -869,13 +867,9 @@ class Module(ComponentDict):
if attr == '_components' and '_components' not in self.__dict__: if attr == '_components' and '_components' not in self.__dict__:
self.__dict__['_components'] = {} self.__dict__['_components'] = {}
try: try:
rval = self[attr] rval = self.__dict__["local_attr"][attr]
except KeyError: except KeyError:
raise AttributeError('%s has no %s attribute.' % (self.__class__, attr)) raise AttributeError('%s has no %s attribute.' % (self.__class__, attr))
if isinstance(rval, (External, Member)):
# Special treatment for External and Member, so that
# the user may use them to build graphs more easily.
return rval.r
return rval return rval
def __setattr__(self, attr, value): def __setattr__(self, attr, value):
...@@ -885,17 +879,35 @@ class Module(ComponentDict): ...@@ -885,17 +879,35 @@ class Module(ComponentDict):
elif attr == 'name': elif attr == 'name':
self.__set_name__(value) self.__set_name__(value)
return return
value = self.__wrapper__(value) def remove_member(v):
try: if isinstance(v, (Member, External)):
self[attr] = value return v.r
except: elif isinstance(v, (gof.Result,Method,Module)):
if isinstance(value, Component): return v
raise elif isinstance(v,(int,bool)):
return v
elif isinstance(v, (list,tuple)):
return map(remove_member,v)
elif isinstance(v,dict):
for k,vv in v.iteritems():
v[k]=remove_member(vv)
return v
else: else:
self.__dict__[attr] = value # raise NotImplementedError
# print "WARNING: unknow:",v
return v
value=remove_member(value)
if not hasattr(self,"local_attr"):
self.__dict__["local_attr"]={}
self.__dict__["local_attr"][attr]=value
def build(self, mode, memo): def build(self, mode, memo):
for k,v in self.local_attr.iteritems():
self.__setattr__(k,v)
inst = super(Module, self).build(mode, memo) inst = super(Module, self).build(mode, memo)
for method in dir(self): for method in dir(self):
# Any method with a name like '_instance_XXX' is added to # Any method with a name like '_instance_XXX' is added to
...@@ -911,6 +923,64 @@ class Module(ComponentDict): ...@@ -911,6 +923,64 @@ class Module(ComponentDict):
for name, value in chain(init.iteritems(), kwinit.iteritems()): for name, value in chain(init.iteritems(), kwinit.iteritems()):
inst[name] = value inst[name] = value
def make_mi(self, *args, **kwargs):
meth=[]#we put the method after the member to be sure of the ordering.
for k,v in self.local_attr.iteritems():
if isinstance(v,Module):
v=v.make_mi(args,kwargs)
if isinstance(v,Method):
meth.append((k,v))
else:
v = self.__wrapper__(v)
try:
self[k] = v
except:
if isinstance(v, Component):
raise
else:
self.__dict__[k] = v
# self.__setitem__(k,v)
for k,v in meth:
self.__setitem__(k,v)
return self
def make(self, *args, **kwargs):
"""
Allocates the necessary containers using allocate() and uses
build() to make an instance which will be returned. The
initialize() method of the instance will be called with the
arguments and the keyword arguments. If 'mode' is in the
keyword arguments it will be passed to build().
"""
self.make_mi(args,kwargs)
mode = kwargs.pop('mode', 'FAST_COMPILE')
rval = self.make_no_init(mode)
if hasattr(rval, 'initialize'):
rval.initialize(*args, **kwargs)
return rval
def __str__(self):
return self.__class__.__name__+"(%s)" % ', '.join(x for x in sorted(map(str, self.local_attr)) if x[0] != '_')
def __get_name__(self):
"""
Getter for self.name
"""
return self._name
def __set_name__(self, name):
"""
Setter for self.name
"""
self._name = name
name = property(lambda self: self.__get_name__(),
lambda self, value: self.__set_name__(value),
"Contains the name of this Component")
FancyModule = Module FancyModule = Module
FancyModuleInstance = ModuleInstance FancyModuleInstance = ModuleInstance
......
...@@ -10,7 +10,7 @@ class T_test_module(unittest.TestCase): ...@@ -10,7 +10,7 @@ class T_test_module(unittest.TestCase):
def test_whats_up_with_submembers(self): def test_whats_up_with_submembers(self):
class Blah(FancyModule): class Blah(FancyModule):
def __init__(self, stepsize): def __init__(self, stepsize):
super(Blah, self).__init__(self) super(Blah, self).__init__()
self.stepsize = Member(T.value(stepsize)) self.stepsize = Member(T.value(stepsize))
x = T.dscalar() x = T.dscalar()
......
...@@ -245,7 +245,7 @@ class RModule(compile.Module): ...@@ -245,7 +245,7 @@ class RModule(compile.Module):
def __init__(self, components = {}, **kwcomponents): def __init__(self, components = {}, **kwcomponents):
super(RModule, self).__init__(components, **kwcomponents) super(RModule, self).__init__(components, **kwcomponents)
self.random = RandomKit('rkit') self.random = RandomKit('rkit')
self._components['_rkit'] = compile.KitComponent(self.random) self._rkit = compile.KitComponent(self.random)
def __wrapper__(self, x): def __wrapper__(self, x):
x = compile.module.wrap(x) x = compile.module.wrap(x)
......
...@@ -32,7 +32,7 @@ class T_test_module(unittest.TestCase): ...@@ -32,7 +32,7 @@ class T_test_module(unittest.TestCase):
""" """
class B(RModule): class B(RModule):
def __init__(self): def __init__(self):
super(B, self).__init__(self) super(B, self).__init__()
self.x = compile.Member(tensor.dvector()) self.x = compile.Member(tensor.dvector())
self.r = self.random.uniform(tensor.shape(self.x)) self.r = self.random.uniform(tensor.shape(self.x))
...@@ -40,7 +40,7 @@ class T_test_module(unittest.TestCase): ...@@ -40,7 +40,7 @@ class T_test_module(unittest.TestCase):
self.f = compile.Method([self.x], self.r) self.f = compile.Method([self.x], self.r)
class E(RModule): class E(RModule):
def __init__(self): def __init__(self):
super(E, self).__init__(self) super(E, self).__init__()
self.b = B() self.b = B()
self.f = compile.Method([self.b.x], self.b.r) self.f = compile.Method([self.b.x], self.b.r)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论