提交 c3106ad7 authored 作者: Frederic Bastien's avatar Frederic Bastien

imported patch module2

上级 6d8b732e
from .. import gof
from ..printing import pprint
from theano import gof
from theano.printing import pprint
from collections import defaultdict
from itertools import chain
from functools import partial
......@@ -791,31 +791,29 @@ def wrap(x):
return wrapper(x)
return x
def dict_wrap(d):
for k,v in d.iteritems():
d[k]=wrap(v)
return d
# Result -> Member
register_wrapper(lambda x: isinstance(x, gof.Result) and not x.owner,
lambda x: Member(x))
# Result -> External
register_wrapper(lambda x: isinstance(x, gof.Result),
register_wrapper(lambda x: isinstance(x, gof.Result) and x.owner,
lambda x: External(x))
# [Component1, Component2, ...] -> ComponentList(Component1, Component2, ...)
register_wrapper(lambda x: isinstance(x, (list, tuple)) and all(isinstance(r, Component) for r in x),
lambda x: ComponentList(*x))
# [Result1, Result2, ...] -> ComponentList(Member(Result1), Member(Result2), ...)
# [[Result1], {Result2}, Result3...] -> ComponentList(Member(Result1), Member(Result2), ...)
register_wrapper(lambda x: isinstance(x, (list, tuple)) \
and all(isinstance(r, gof.Result) and not r.owner for r in x),
lambda x: ComponentList(*map(Member, x)))
#{ "name1":Result1,...} -> ComponentDict(Member(Result1),...)
def dict_member(d):
nd={}
for k,v in d.iteritems():
nd[k]=Member(v)
return nd
register_wrapper(lambda x: isinstance(x, dict) \
and all(isinstance(r,gof.Result) \
and not r.owner for r in x.itervalues()),
lambda x: ComponentDict(dict_member(x)))
and all(isinstance(r, (gof.Result,Component,list,
tuple, dict)) for r in x),
lambda x: ComponentList(*map(wrap, x)))
#{ "name1":{Component,Result,list,tuple,dict},...} -> ComponentDict({Component,Result,list,tuple,dict},...)
register_wrapper(lambda x: isinstance(x, dict) \
and all(isinstance(r,Component) for r in x.itervalues()),
lambda x: ComponentDict(x))
and all(isinstance(r,(Component,gof.Result,list,tuple,dict)) for r in x.itervalues()),
lambda x: ComponentDict(dict_wrap(x)))
class Curry:
def __init__(self, obj, name, arg):
......@@ -869,13 +867,9 @@ class Module(ComponentDict):
if attr == '_components' and '_components' not in self.__dict__:
self.__dict__['_components'] = {}
try:
rval = self[attr]
rval = self.__dict__["local_attr"][attr]
except KeyError:
raise AttributeError('%s has no %s attribute.' % (self.__class__, attr))
if isinstance(rval, (External, Member)):
# Special treatment for External and Member, so that
# the user may use them to build graphs more easily.
return rval.r
return rval
def __setattr__(self, attr, value):
......@@ -885,17 +879,35 @@ class Module(ComponentDict):
elif attr == 'name':
self.__set_name__(value)
return
value = self.__wrapper__(value)
try:
self[attr] = value
except:
if isinstance(value, Component):
raise
def remove_member(v):
if isinstance(v, (Member, External)):
return v.r
elif isinstance(v, (gof.Result,Method,Module)):
return v
elif isinstance(v,(int,bool)):
return v
elif isinstance(v, (list,tuple)):
return map(remove_member,v)
elif isinstance(v,dict):
for k,vv in v.iteritems():
v[k]=remove_member(vv)
return v
else:
self.__dict__[attr] = value
# raise NotImplementedError
# print "WARNING: unknow:",v
return v
value=remove_member(value)
if not hasattr(self,"local_attr"):
self.__dict__["local_attr"]={}
self.__dict__["local_attr"][attr]=value
def build(self, mode, memo):
for k,v in self.local_attr.iteritems():
self.__setattr__(k,v)
inst = super(Module, self).build(mode, memo)
for method in dir(self):
# Any method with a name like '_instance_XXX' is added to
......@@ -911,6 +923,64 @@ class Module(ComponentDict):
for name, value in chain(init.iteritems(), kwinit.iteritems()):
inst[name] = value
def make_mi(self, *args, **kwargs):
meth=[]#we put the method after the member to be sure of the ordering.
for k,v in self.local_attr.iteritems():
if isinstance(v,Module):
v=v.make_mi(args,kwargs)
if isinstance(v,Method):
meth.append((k,v))
else:
v = self.__wrapper__(v)
try:
self[k] = v
except:
if isinstance(v, Component):
raise
else:
self.__dict__[k] = v
# self.__setitem__(k,v)
for k,v in meth:
self.__setitem__(k,v)
return self
def make(self, *args, **kwargs):
"""
Allocates the necessary containers using allocate() and uses
build() to make an instance which will be returned. The
initialize() method of the instance will be called with the
arguments and the keyword arguments. If 'mode' is in the
keyword arguments it will be passed to build().
"""
self.make_mi(args,kwargs)
mode = kwargs.pop('mode', 'FAST_COMPILE')
rval = self.make_no_init(mode)
if hasattr(rval, 'initialize'):
rval.initialize(*args, **kwargs)
return rval
def __str__(self):
return self.__class__.__name__+"(%s)" % ', '.join(x for x in sorted(map(str, self.local_attr)) if x[0] != '_')
def __get_name__(self):
"""
Getter for self.name
"""
return self._name
def __set_name__(self, name):
"""
Setter for self.name
"""
self._name = name
name = property(lambda self: self.__get_name__(),
lambda self, value: self.__set_name__(value),
"Contains the name of this Component")
FancyModule = Module
FancyModuleInstance = ModuleInstance
......
......@@ -10,7 +10,7 @@ class T_test_module(unittest.TestCase):
def test_whats_up_with_submembers(self):
class Blah(FancyModule):
def __init__(self, stepsize):
super(Blah, self).__init__(self)
super(Blah, self).__init__()
self.stepsize = Member(T.value(stepsize))
x = T.dscalar()
......
......@@ -245,7 +245,7 @@ class RModule(compile.Module):
def __init__(self, components = {}, **kwcomponents):
super(RModule, self).__init__(components, **kwcomponents)
self.random = RandomKit('rkit')
self._components['_rkit'] = compile.KitComponent(self.random)
self._rkit = compile.KitComponent(self.random)
def __wrapper__(self, x):
x = compile.module.wrap(x)
......
......@@ -32,7 +32,7 @@ class T_test_module(unittest.TestCase):
"""
class B(RModule):
def __init__(self):
super(B, self).__init__(self)
super(B, self).__init__()
self.x = compile.Member(tensor.dvector())
self.r = self.random.uniform(tensor.shape(self.x))
......@@ -40,7 +40,7 @@ class T_test_module(unittest.TestCase):
self.f = compile.Method([self.x], self.r)
class E(RModule):
def __init__(self):
super(E, self).__init__(self)
super(E, self).__init__()
self.b = B()
self.f = compile.Method([self.b.x], self.b.r)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论