提交 34a69880 authored 作者: Frederic Bastien's avatar Frederic Bastien

allow mixed recursive list,tuple,dict of Component and result to be set on module.

added partial test for this.
上级 bbf519ec
...@@ -787,6 +787,11 @@ def wrap(x): ...@@ -787,6 +787,11 @@ def wrap(x):
return wrapper(x) return wrapper(x)
return x return x
def dict_wrap(d):
for k,v in d.iteritems():
d[k]=wrap(v)
return d
# Result -> Member # Result -> Member
register_wrapper(lambda x: isinstance(x, gof.Result) and not x.owner, register_wrapper(lambda x: isinstance(x, gof.Result) and not x.owner,
lambda x: Member(x)) lambda x: Member(x))
...@@ -795,27 +800,16 @@ register_wrapper(lambda x: isinstance(x, gof.Result) and not x.owner, ...@@ -795,27 +800,16 @@ register_wrapper(lambda x: isinstance(x, gof.Result) and not x.owner,
register_wrapper(lambda x: isinstance(x, gof.Result) and x.owner, register_wrapper(lambda x: isinstance(x, gof.Result) and x.owner,
lambda x: External(x)) lambda x: External(x))
# [Component1, Component2, ...] -> ComponentList(Component1, Component2, ...) # [[Result1], {Result2}, Result3...] -> ComponentList(Member(Result1), Member(Result2), ...)
register_wrapper(lambda x: isinstance(x, (list, tuple)) and all(isinstance(r, Component) for r in x),
lambda x: ComponentList(*x))
# [Result1, Result2, ...] -> ComponentList(Member(Result1), Member(Result2), ...)
register_wrapper(lambda x: isinstance(x, (list, tuple)) \ register_wrapper(lambda x: isinstance(x, (list, tuple)) \
and all(isinstance(r, gof.Result) and not r.owner for r in x), and all(isinstance(r, (gof.Result,Component,list,
lambda x: ComponentList(*map(Member, x))) tuple, dict)) for r in x),
#{ "name1":Result1,...} -> ComponentDict(Member(Result1),...) lambda x: ComponentList(*map(wrap, x)))
def dict_member(d):
nd={} #{ "name1":{Component,Result,list,tuple,dict},...} -> ComponentDict({Component,Result,list,tuple,dict},...)
for k,v in d.iteritems():
nd[k]=Member(v)
return nd
register_wrapper(lambda x: isinstance(x, dict) \
and all(isinstance(r,gof.Result) \
and not r.owner for r in x.itervalues()),
lambda x: ComponentDict(dict_member(x)))
register_wrapper(lambda x: isinstance(x, dict) \ register_wrapper(lambda x: isinstance(x, dict) \
and all(isinstance(r,Component) for r in x.itervalues()), and all(isinstance(r,(Component,gof.Result,list,tuple,dict)) for r in x.itervalues()),
lambda x: ComponentDict(x)) lambda x: ComponentDict(dict_wrap(x)))
class Curry: class Curry:
def __init__(self, obj, name, arg): def __init__(self, obj, name, arg):
......
#!/usr/bin/env python
import unittest import unittest
from theano.compile.module import * from theano.compile.module import *
import theano.tensor as T import theano.tensor as T
...@@ -174,16 +175,16 @@ class T_test_module(unittest.TestCase): ...@@ -174,16 +175,16 @@ class T_test_module(unittest.TestCase):
assert inst2.tx[0]==2 assert inst2.tx[0]==2
assert inst1.dx['x']==1 assert inst1.dx['x']==1
assert inst2.dx['x']==2 assert inst2.dx['x']==2
# assert inst1.ltx[0][0]==1#BUG: list of tuple don't work assert inst1.ltx[0][0]==1#BUG: list of tuple don't work
# assert inst2.ltx[0][0]==2#BUG: list of tuple don't work assert inst2.ltx[0][0]==2#BUG: list of tuple don't work
# assert inst1.llx[0][0]==1#BUG: list of list don't work assert inst1.llx[0][0]==1#BUG: list of list don't work
# assert inst2.llx[0][0]==2#BUG: list of list don't work assert inst2.llx[0][0]==2#BUG: list of list don't work
# assert inst1.llx[1][0]==1#BUG: list of list don't work assert inst1.llx[1][0]==1#BUG: list of list don't work
# assert inst2.llx[1][0]==2#BUG: list of list don't work assert inst2.llx[1][0]==2#BUG: list of list don't work
# assert inst1.ttx[0][0]==1#BUG: tuple of list don't work assert inst1.ttx[0][0]==1#BUG: tuple of list don't work
# assert inst2.ttx[0][0]==2#BUG: tuple of list don't work assert inst2.ttx[0][0]==2#BUG: tuple of list don't work
# assert inst1.tlx[0][0]==1#BUG: tuple of list don't work assert inst1.tlx[0][0]==1#BUG: tuple of list don't work
# assert inst2.tlx[0][0]==2#BUG: tuple of list don't work assert inst2.tlx[0][0]==2#BUG: tuple of list don't work
#m1.x and m2.x should be shared as their is a hierarchi link between them. #m1.x and m2.x should be shared as their is a hierarchi link between them.
m1.m2=m2 m1.m2=m2
...@@ -197,16 +198,16 @@ class T_test_module(unittest.TestCase): ...@@ -197,16 +198,16 @@ class T_test_module(unittest.TestCase):
assert inst.m2.tx[0]==1 assert inst.m2.tx[0]==1
assert inst.dx['x']==1 assert inst.dx['x']==1
assert inst.m2.dx['x']==1 assert inst.m2.dx['x']==1
# assert inst.llx[0][0]==1#BUG: list of list don't work assert inst.llx[0][0]==1#BUG: list of list don't work
# assert inst.m2.llx[0][0]==1#BUG: list of list don't work assert inst.m2.llx[0][0]==1#BUG: list of list don't work
# assert inst.llx[1][0]==1#BUG: list of list don't work assert inst.llx[1][0]==1#BUG: list of list don't work
# assert inst.m2.llx[1][0]==1#BUG: list of list don't work assert inst.m2.llx[1][0]==1#BUG: list of list don't work
# assert inst.ltx[0][0]==1#BUG: list of list don't work assert inst.ltx[0][0]==1#BUG: list of list don't work
# assert inst.m2.ltx[0][0]==1#BUG: list of list don't work assert inst.m2.ltx[0][0]==1#BUG: list of list don't work
# assert inst.ttx[0][0]==1#BUG: list of list don't work assert inst.ttx[0][0]==1#BUG: list of list don't work
# assert inst.m2.ttx[0][0]==1#BUG: list of list don't work assert inst.m2.ttx[0][0]==1#BUG: list of list don't work
# assert inst.tlx[0][0]==1#BUG: list of list don't work assert inst.tlx[0][0]==1#BUG: list of list don't work
# assert inst.m2.tlx[0][0]==1#BUG: list of list don't work assert inst.m2.tlx[0][0]==1#BUG: list of list don't work
inst.m2.x=2 inst.m2.x=2
assert inst.x==2 assert inst.x==2
assert inst.m2.x==2 assert inst.m2.x==2
...@@ -272,22 +273,22 @@ class T_test_module(unittest.TestCase): ...@@ -272,22 +273,22 @@ class T_test_module(unittest.TestCase):
assert inst.tz[0]()==2 assert inst.tz[0]()==2
assert inst.dy['y'](2)==4 assert inst.dy['y'](2)==4
assert inst.dz['z']()==2 assert inst.dz['z']()==2
# assert inst.lly[0][0](2)==4#BUG: we don't support list of list of Method... assert inst.lly[0][0](2)==4#BUG: we don't support list of list of Method...
# assert inst.llz[0][0]()==2 assert inst.llz[0][0]()==2
# assert inst.tty[0][0](2)==4 assert inst.tty[0][0](2)==4
# assert inst.ttz[0][0]()==2 assert inst.ttz[0][0]()==2
assert isinstance(inst.z,theano.compile.function_module.Function) assert isinstance(inst.z,theano.compile.function_module.Function)
assert isinstance(inst.lz[0],theano.compile.function_module.Function) assert isinstance(inst.lz[0],theano.compile.function_module.Function)
# assert isinstance(inst.llz[0][0],theano.compile.function_module.Function) assert isinstance(inst.llz[0][0],theano.compile.function_module.Function)
assert isinstance(inst.tz[0],theano.compile.function_module.Function) assert isinstance(inst.tz[0],theano.compile.function_module.Function)
assert isinstance(inst.dz['z'],theano.compile.function_module.Function) assert isinstance(inst.dz['z'],theano.compile.function_module.Function)
# assert isinstance(inst.ttz[0][0],theano.compile.function_module.Function) assert isinstance(inst.ttz[0][0],theano.compile.function_module.Function)
assert isinstance(inst.y,theano.compile.function_module.Function) assert isinstance(inst.y,theano.compile.function_module.Function)
assert isinstance(inst.ly[0],theano.compile.function_module.Function) assert isinstance(inst.ly[0],theano.compile.function_module.Function)
# assert isinstance(inst.lly[0][0],theano.compile.function_module.Function) assert isinstance(inst.lly[0][0],theano.compile.function_module.Function)
assert isinstance(inst.ty[0],theano.compile.function_module.Function) assert isinstance(inst.ty[0],theano.compile.function_module.Function)
assert isinstance(inst.dy['y'],theano.compile.function_module.Function) assert isinstance(inst.dy['y'],theano.compile.function_module.Function)
# assert isinstance(inst.tty[0][0],theano.compile.function_module.Function) assert isinstance(inst.tty[0][0],theano.compile.function_module.Function)
print >> sys.stderr, "WARNING MODULE TEST NOT IMPLEMENTED" print >> sys.stderr, "WARNING MODULE TEST NOT IMPLEMENTED"
#put them in subModules, sub-sub-Modules, shared between a list and a dict, shared between #put them in subModules, sub-sub-Modules, shared between a list and a dict, shared between
...@@ -330,3 +331,11 @@ class T_test_module(unittest.TestCase): ...@@ -330,3 +331,11 @@ class T_test_module(unittest.TestCase):
As Result with more optimization?""" As Result with more optimization?"""
print >> sys.stderr, "WARNING MODULE TEST NOT IMPLEMENTED" print >> sys.stderr, "WARNING MODULE TEST NOT IMPLEMENTED"
if __name__ == '__main__':
from theano.tests import main
# main(__file__[:-3])
main("test_module")
# t=T_test_module()
# t.test_shared_members()
# tests = unittest.TestLoader().loadTestsFromModule("T_test_module")
# tests.debug()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论