* Fixes bug with modules having empty lists (introduced previously)

* Cleaned up code and made make_mi (now make_module_instance) method a lot more generic. It will now find and "make" (i.e transfer from local_attr dictionary to properly wrapped module members) any submodule, whether its in a list, list of lists, dictionary, etc
上级 24f4ad19
......@@ -1143,13 +1143,11 @@ class Module(ComponentDict):
value=unpack_member_and_external(value)
if not hasattr(self,"local_attr"):
self.__dict__["local_attr"]={}
self.__dict__["local_attr_order"]=[]
self.__dict__["local_attr"][attr]=value
self.__dict__["local_attr_order"].append((attr, value))
self.__dict__["local_attr"][attr] = value
def build(self, mode, memo):
for k,v in list(self.local_attr_order): #.iteritems():
for k,v in self.local_attr.iteritems():
self.__setattr__(k,v)
inst = super(Module, self).build(mode, memo)
if not isinstance(inst, ModuleInstance):
......@@ -1181,41 +1179,44 @@ class Module(ComponentDict):
for name, value in chain(init.iteritems(), kwinit.iteritems()):
inst[name] = value
def make_mi(self, *args, **kwargs):
mods=[]
meth=[]#we put the method after the member to be sure of the ordering.
def make_module_instance(self, *args, **kwargs):
"""
Module's __setattr__ method hides all members under local_attr. This
method iterates over those elements and wraps them so they can be used
in a computation graph. The "wrapped" members are then set as object
attributes accessible through the dotted notation syntax (<module_name>
<dot> <member_name>). Submodules are handled recursively.
"""
# Function to go through member lists and dictionaries recursively,
# to look for submodules on which make_module_instance needs to be called
def recurse(v):
iter = enumerate(v) if isinstance(v,list) else v.iteritems()
for sk,sv in iter:
if isinstance(sv,(list,dict)):
sv = recurse(sv)
elif isinstance(sv,Module):
sv = sv.make_module_instance(args,kwargs)
v[sk] = sv
return v
for k,v in self.local_attr.iteritems():
if isinstance(v,Module):
mods.append((k, v))
v = v.make_module_instance(args,kwargs)
self[k] = self.__wrapper__(v)
elif isinstance(v,Method):
meth.append((k,v))
elif isinstance(v, list) and isinstance(v[0],Module):
temp = []
for m in v:
m=m.make_mi(args,kwargs)
m = self.__wrapper__(m)
temp.append(m)
self[k] = self.__wrapper__(temp)
self.__setitem__(k,v)
else:
v = self.__wrapper__(v)
# iterate through lists and dictionaries to wrap submodules
if isinstance(v,(list,dict)):
self[k] = self.__wrapper__(recurse(v))
try:
self[k] = v
self[k] = self.__wrapper__(v)
except:
if isinstance(v, Component):
raise
else:
self.__dict__[k] = v
# self.__setitem__(k,v)
for k,v in mods:
v=v.make_mi(args,kwargs)
v = self.__wrapper__(v)
self[k] = v
for k,v in meth:
self.__setitem__(k,v)
return self
def make(self, *args, **kwargs):
......@@ -1226,7 +1227,7 @@ class Module(ComponentDict):
arguments and the keyword arguments. If 'mode' is in the
keyword arguments it will be passed to build().
"""
self.make_mi(args,kwargs)
self.make_module_instance(args,kwargs)
mode = kwargs.pop('mode', default_mode)
rval = self.make_no_init(mode)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论