提交 c104defb authored 作者: james@X40's avatar james@X40

more changes to Module, more tests

上级 c31223aa
差异被折叠。
...@@ -267,6 +267,7 @@ class T_module(unittest.TestCase): ...@@ -267,6 +267,7 @@ class T_module(unittest.TestCase):
def get_element(i): def get_element(i):
return [i.x,i.lx[0],i.tx[0],i.dx['x'],i.llx[0][0], i.llx[1][0], i.ltx[0][0], i.ldx[0]['x'], i.tlx[0][0], i.tlx[0][0], i.tdx[0]['x'], i.dlx['x'][0], i.dtx['x'][0], i.ddx['x']['x']] return [i.x,i.lx[0],i.tx[0],i.dx['x'],i.llx[0][0], i.llx[1][0], i.ltx[0][0], i.ldx[0]['x'], i.tlx[0][0], i.tlx[0][0], i.tdx[0]['x'], i.dlx['x'][0], i.dtx['x'][0], i.ddx['x']['x']]
m1=Module() m1=Module()
m2=Module() m2=Module()
x=T.dscalar() x=T.dscalar()
...@@ -393,7 +394,13 @@ class T_module(unittest.TestCase): ...@@ -393,7 +394,13 @@ class T_module(unittest.TestCase):
assert isinstance(inst.dy['y'],theano.compile.function_module.Function) assert isinstance(inst.dy['y'],theano.compile.function_module.Function)
assert isinstance(inst.tty[0][0],theano.compile.function_module.Function) assert isinstance(inst.tty[0][0],theano.compile.function_module.Function)
print >> sys.stderr, "MODULE TEST IMPLEMENTED BUT WE DON'T KNOW WHAT WE WANT AS A RESULT"
assert m1.y is m1.ly[0]
assert inst.y is inst.ly[0]
assert inst.y is inst.lly[0][0]
assert inst.y is inst.ty[0]
assert inst.y is inst.tty[0][0]
assert inst.y is inst.dy['y']
def test_member_method_inputs(self): def test_member_method_inputs(self):
"""Test that module Members can be named as Method inputs, in which case the function will """Test that module Members can be named as Method inputs, in which case the function will
...@@ -416,7 +423,6 @@ class T_module(unittest.TestCase): ...@@ -416,7 +423,6 @@ class T_module(unittest.TestCase):
assert m.y == 77 assert m.y == 77
assert m.x == 1000 assert m.x == 1000
def test_member_input_flags(self): def test_member_input_flags(self):
"""Test that we can manipulate the mutable, strict, etc. flags (see SymbolicInput) of """Test that we can manipulate the mutable, strict, etc. flags (see SymbolicInput) of
Method inputs""" Method inputs"""
...@@ -448,20 +454,30 @@ class T_module(unittest.TestCase): ...@@ -448,20 +454,30 @@ class T_module(unittest.TestCase):
m.f([3, 2]) m.f([3, 2])
assert numpy.all(v0 != v0_copy) assert numpy.all(v0 != v0_copy)
def test_sanity_check_mode(self):
"""Test that Module.make(self) can take the same list of Modes that function can, so we can
debug modules"""
print >> sys.stderr, "WARNING MODULE TEST NOT IMPLEMENTED"
def test_member_value(self): def test_member_value(self):
"""Test that module Members of Value work correctly. As Result?""" """Test that module Members of Value work correctly. As Result?"""
print >> sys.stderr, "WARNING MODULE TEST NOT IMPLEMENTED" M = Module()
x = T.dscalar()
M.y = T.value(40)
M.f = Method([x], x + 2 * M.y)
m = M.make()
m.y = 80
assert m.f(20) == 180
def test_member_constant(self): def test_member_constant(self):
"""Test that module Members of Constant work correctly. """Test that module Members of Constant work correctly.
As Result with more optimization?""" As Result with more optimization?"""
print >> sys.stderr, "WARNING MODULE TEST NOT IMPLEMENTED" M = Module()
x = T.dscalar()
M.y = T.constant(40)
M.f = Method([x], x + 2 * M.y)
m = M.make()
try:
m.y = 77 #fail?
assert 0 #assign to constant should not have worked
except:
pass
assert m.f(20) == 100
def test_raise_NotImplemented(self): def test_raise_NotImplemented(self):
c=Component() c=Component()
...@@ -476,7 +492,7 @@ class T_module(unittest.TestCase): ...@@ -476,7 +492,7 @@ class T_module(unittest.TestCase):
self.assertRaises(NotImplementedError, c.get,"n") self.assertRaises(NotImplementedError, c.get,"n")
self.assertRaises(NotImplementedError, c.set,"n",1) self.assertRaises(NotImplementedError, c.set,"n",1)
def test_tuple_members(self): def test_tuple_members():
M = Module() M = Module()
M.a = (1,1) M.a = (1,1)
...@@ -489,6 +505,60 @@ class T_module(unittest.TestCase): ...@@ -489,6 +505,60 @@ class T_module(unittest.TestCase):
assert isinstance(M.a, tuple) assert isinstance(M.a, tuple)
def test_method_updates():
# updates work
M = Module()
M.x = T.dvector()
x = T.dvector()
xval= numpy.asarray([0, 0.5])
M.f = Method([x], M.x*4, updates={M.x:M.x * 2}, mode='FAST_COMPILE')
m = M.make(mode='FAST_RUN')
m.x = xval
m.f([9,9])
assert numpy.all(m.x == [0, 1])
assert numpy.all(xval == [0, 0.5])
# In(update) works
M = Module()
M.x = T.dvector()
x = T.dvector()
M.f = Method([x, io.In(M.x, value=xval, update=M.x*2)], M.x*4)
m = M.make()
m.f([9,9])
assert m.x is None
assert numpy.all(xval == [0, 1])
# when a result is listed explicitly and in an update, then there's a problem.
M = Module()
M.x = T.dvector()
x = T.dvector()
M.f = Method([x, io.In(M.x, value=xval, update=M.x*2)], M.x*4,
updates={M.x:M.x * 7})
try:
m = M.make()
assert False
except ValueError, e:
if str(e[0]).startswith('Result listed in both inputs and up'):
pass
else:
raise
def test_method_mode():
"""Test that Methods can override the module build mode"""
M = Module()
M.x = T.dvector()
M.f = Method([M.x], M.x*4, mode='FAST_COMPILE')
M.g = Method([M.x], M.x*4)
M.h = Method([M.x], M.x*4)
m = M.make(mode='FAST_RUN')
assert m.f.maker.mode != m.g.maker.mode
assert m.h.maker.mode == m.g.maker.mode
assert numpy.all(m.f([1,2]) == m.g([1,2]))
def test_pickle(): def test_pickle():
"""Test that a module can be pickled""" """Test that a module can be pickled"""
M = Module() M = Module()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论