* Fixes bug with modules having empty lists (introduced previously)

* Cleaned up code and made make_mi (now make_module_instance) method a lot more generic. It will now find and "make" (i.e transfer from local_attr dictionary to properly wrapped module members) any submodule, whether its in a list, list of lists, dictionary, etc
上级 24f4ad19
...@@ -1143,13 +1143,11 @@ class Module(ComponentDict): ...@@ -1143,13 +1143,11 @@ class Module(ComponentDict):
value=unpack_member_and_external(value) value=unpack_member_and_external(value)
if not hasattr(self,"local_attr"): if not hasattr(self,"local_attr"):
self.__dict__["local_attr"]={} self.__dict__["local_attr"]={}
self.__dict__["local_attr_order"]=[]
self.__dict__["local_attr"][attr] = value
self.__dict__["local_attr"][attr]=value
self.__dict__["local_attr_order"].append((attr, value))
def build(self, mode, memo): def build(self, mode, memo):
for k,v in list(self.local_attr_order): #.iteritems(): for k,v in self.local_attr.iteritems():
self.__setattr__(k,v) self.__setattr__(k,v)
inst = super(Module, self).build(mode, memo) inst = super(Module, self).build(mode, memo)
if not isinstance(inst, ModuleInstance): if not isinstance(inst, ModuleInstance):
...@@ -1181,41 +1179,44 @@ class Module(ComponentDict): ...@@ -1181,41 +1179,44 @@ class Module(ComponentDict):
for name, value in chain(init.iteritems(), kwinit.iteritems()): for name, value in chain(init.iteritems(), kwinit.iteritems()):
inst[name] = value inst[name] = value
def make_mi(self, *args, **kwargs): def make_module_instance(self, *args, **kwargs):
mods=[] """
meth=[]#we put the method after the member to be sure of the ordering. Module's __setattr__ method hides all members under local_attr. This
method iterates over those elements and wraps them so they can be used
in a computation graph. The "wrapped" members are then set as object
attributes accessible through the dotted notation syntax (<module_name>
<dot> <member_name>). Submodules are handled recursively.
"""
# Function to go through member lists and dictionaries recursively,
# to look for submodules on which make_module_instance needs to be called
def recurse(v):
iter = enumerate(v) if isinstance(v,list) else v.iteritems()
for sk,sv in iter:
if isinstance(sv,(list,dict)):
sv = recurse(sv)
elif isinstance(sv,Module):
sv = sv.make_module_instance(args,kwargs)
v[sk] = sv
return v
for k,v in self.local_attr.iteritems(): for k,v in self.local_attr.iteritems():
if isinstance(v,Module): if isinstance(v,Module):
mods.append((k, v)) v = v.make_module_instance(args,kwargs)
self[k] = self.__wrapper__(v)
elif isinstance(v,Method): elif isinstance(v,Method):
meth.append((k,v)) self.__setitem__(k,v)
elif isinstance(v, list) and isinstance(v[0],Module):
temp = []
for m in v:
m=m.make_mi(args,kwargs)
m = self.__wrapper__(m)
temp.append(m)
self[k] = self.__wrapper__(temp)
else: else:
v = self.__wrapper__(v) # iterate through lists and dictionaries to wrap submodules
if isinstance(v,(list,dict)):
self[k] = self.__wrapper__(recurse(v))
try: try:
self[k] = v self[k] = self.__wrapper__(v)
except: except:
if isinstance(v, Component): if isinstance(v, Component):
raise raise
else: else:
self.__dict__[k] = v self.__dict__[k] = v
# self.__setitem__(k,v)
for k,v in mods:
v=v.make_mi(args,kwargs)
v = self.__wrapper__(v)
self[k] = v
for k,v in meth:
self.__setitem__(k,v)
return self return self
def make(self, *args, **kwargs): def make(self, *args, **kwargs):
...@@ -1226,7 +1227,7 @@ class Module(ComponentDict): ...@@ -1226,7 +1227,7 @@ class Module(ComponentDict):
arguments and the keyword arguments. If 'mode' is in the arguments and the keyword arguments. If 'mode' is in the
keyword arguments it will be passed to build(). keyword arguments it will be passed to build().
""" """
self.make_mi(args,kwargs) self.make_module_instance(args,kwargs)
mode = kwargs.pop('mode', default_mode) mode = kwargs.pop('mode', default_mode)
rval = self.make_no_init(mode) rval = self.make_no_init(mode)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论