提交 56830cb5 authored 作者: James Bergstra's avatar James Bergstra

added a pickling test

上级 49fc30bd
...@@ -2,8 +2,8 @@ import unittest ...@@ -2,8 +2,8 @@ import unittest
from theano import gof from theano import gof
from theano import compile from theano import compile
from theano.compile.function_module import *
from theano.scalar import * from theano.scalar import *
from theano.compile.function_module import *
from theano import tensor from theano import tensor
from theano import tensor as T from theano import tensor as T
...@@ -313,7 +313,7 @@ class T_function(unittest.TestCase): ...@@ -313,7 +313,7 @@ class T_function(unittest.TestCase):
f = function([x, In(a, value=1.0,name='a'), In(s, value=0.0, update=s+a*x, mutable=True)], s+a*x) f = function([x, In(a, value=1.0,name='a'), In(s, value=0.0, update=s+a*x, mutable=True)], s+a*x)
g = copy(f) g = copy.copy(f)
#if they both return, assume that they return equivalent things. #if they both return, assume that they return equivalent things.
self.failIf(g.container[x].storage is f.container[x].storage) self.failIf(g.container[x].storage is f.container[x].storage)
...@@ -374,6 +374,31 @@ class T_function(unittest.TestCase): ...@@ -374,6 +374,31 @@ class T_function(unittest.TestCase):
self.failUnless(f[s] == 4) self.failUnless(f[s] == 4)
self.failUnless(g[s] == 4) self.failUnless(g[s] == 4)
class T_picklefunction(unittest.TestCase):
def test_deepcopy(self):
a = T.scalar() # the a is for 'anonymous' (un-named).
x,s = T.scalars('xs')
f = function([x, In(a, value=1.0,name='a'), In(s, value=0.0, update=s+a*x, mutable=True)], s+a*x)
print f.maker.function_builder
g = copy.deepcopy(f)
#if they both return, assume that they return equivalent things.
self.failIf(g.container[x].storage is f.container[x].storage)
self.failIf(g.container[a].storage is f.container[a].storage)
self.failIf(g.container[s].storage is f.container[s].storage)
self.failIf(g.value[a] is f.value[a]) # should not have been copied
self.failIf(g.value[s] is f.value[s]) # should have been copied because it is mutable.
self.failIf((g.value[s] != f.value[s]).any()) # its contents should be identical
self.failUnless(f(2, 1) == g(2)) #they should be in sync, default value should be copied.
self.failUnless(f(2, 1) == g(2)) #they should be in sync, default value should be copied.
f(1,2) # put them out of sync
self.failIf(f(1, 2) == g(1, 2)) #they should not be equal anymore.
# class T_function_examples(unittest.TestCase): # class T_function_examples(unittest.TestCase):
# def test_accumulator(self): # def test_accumulator(self):
# """Test low-level interface with state.""" # """Test low-level interface with state."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论