提交 fb3e7e66 authored 作者: James Bergstra's avatar James Bergstra

pfunc - fixing tests to new shared variables policy

上级 5960a4ea
...@@ -8,6 +8,12 @@ from theano import tensor ...@@ -8,6 +8,12 @@ from theano import tensor
from theano.compile.sharedvalue import * from theano.compile.sharedvalue import *
from theano.compile.pfunc import * from theano.compile.pfunc import *
def data_of(s):
"""Return the raw value of a shared variable"""
return s.container.storage[0]
class Test_pfunc(unittest.TestCase): class Test_pfunc(unittest.TestCase):
def test_doc(self): def test_doc(self):
...@@ -135,9 +141,11 @@ class Test_pfunc(unittest.TestCase): ...@@ -135,9 +141,11 @@ class Test_pfunc(unittest.TestCase):
def test_shared_mutable(self): def test_shared_mutable(self):
bval = numpy.arange(5) bval = numpy.arange(5)
b = shared(bval) b = shared(bval)
assert b.value is bval
b_out = b * 2 b_out = b * 2
assert b.value is not bval # shared vars copy args.
bval = data_of(b) # so we do this to get at the underlying data
# by default, shared are not mutable unless doing an explicit update # by default, shared are not mutable unless doing an explicit update
f = pfunc([], [b_out], mode='FAST_RUN') f = pfunc([], [b_out], mode='FAST_RUN')
assert (f() == numpy.arange(5) * 2).all() assert (f() == numpy.arange(5) * 2).all()
...@@ -152,6 +160,7 @@ class Test_pfunc(unittest.TestCase): ...@@ -152,6 +160,7 @@ class Test_pfunc(unittest.TestCase):
# do not depend on updates being in-place though! # do not depend on updates being in-place though!
bval = numpy.arange(5) bval = numpy.arange(5)
b.value = bval b.value = bval
bval = data_of(b)
f = pfunc([], [b_out], updates=[(b, b_out+3)], mode='FAST_RUN') f = pfunc([], [b_out], updates=[(b, b_out+3)], mode='FAST_RUN')
assert ( f() == numpy.arange(5)*2 ).all() assert ( f() == numpy.arange(5)*2 ).all()
assert (b.value == ((numpy.arange(5)*2)+3)).all() # because of the update assert (b.value == ((numpy.arange(5)*2)+3)).all() # because of the update
...@@ -479,10 +488,6 @@ class Test_pfunc(unittest.TestCase): ...@@ -479,10 +488,6 @@ class Test_pfunc(unittest.TestCase):
def data_of(s):
"""Return the raw value of a shared variable"""
return s.container.storage[0]
class Test_aliasing_rules(unittest.TestCase): class Test_aliasing_rules(unittest.TestCase):
""" """
1. Theano manages its own memory space, which typically does not overlap with the memory of 1. Theano manages its own memory space, which typically does not overlap with the memory of
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论