提交 2e7c2590 authored 作者: ChienliMa's avatar ChienliMa

small modification

上级 6400064f
...@@ -582,6 +582,7 @@ class Function(object): ...@@ -582,6 +582,7 @@ class Function(object):
# But to be safe for now as it isn't documented and we aren't sure # But to be safe for now as it isn't documented and we aren't sure
# it is well tested, we don't share the part of the storage_map. # it is well tested, we don't share the part of the storage_map.
if share_memory: if share_memory:
i_o_vars = maker.fgraph.inputs + maker.fgraph.outputs
for key in storage_map.keys(): for key in storage_map.keys():
if key not in i_o_vars: if key not in i_o_vars:
new_storage_map[memo[key]] = storage_map[key] new_storage_map[memo[key]] = storage_map[key]
...@@ -635,7 +636,7 @@ class Function(object): ...@@ -635,7 +636,7 @@ class Function(object):
""" """
Assert two SharedVariable follow some restirctions: Assert two SharedVariable follow some restirctions:
1. same type 1. same type
2. same shape???? 2. same shape or dim?
""" """
assert sv_ori.type == sv_rpl.type, ( assert sv_ori.type == sv_rpl.type, (
"Type of given SharedVariable conflicts with origianl one", "Type of given SharedVariable conflicts with origianl one",
......
...@@ -280,17 +280,28 @@ class T_function(unittest.TestCase): ...@@ -280,17 +280,28 @@ class T_function(unittest.TestCase):
else: else:
self.assertFalse(here.data is there.data) self.assertFalse(here.data is there.data)
# def test_swap_SharedVariable(self): def test_swap_SharedVariable(self):
# x = T.fscalar('x') x = T.fscalar('x')
# # SharedVariable for tests, one of them has update # SharedVariable for tests, one of them has update
# y = theano.shared(value=1) y = theano.shared(value=1, name='y')
# z = theano.shared(value=2) z = theano.shared(value=2, name='z')
# out = T.tanh((x+y+2)/(x+z-0.2)**2)
# SharedVariable to replace
# # Test for different linkers y_rpl = theano.shared(value=3)
# for mode in ["FAST_RUN","FAST_COMPILE"]: z_rpl = theano.shared(value=4)
# ori = theano.function([x], [out], mode=mode,updates={z:z+1})
# cpy = ori.copy(share_memory=True) out = T.tanh((x+y+2)/(x+z-0.2)**2)
# Test for different linkers
for mode in ["FAST_RUN","FAST_COMPILE"]:
ori = theano.function([x], [out], mode=mode,updates={z:z+1})
cpy = ori.copy(swap={'y':y_rpl, 'z':z_rpl})
# test what:
# 1. is swapped( value equal, use = or is? )
# 2. is updatable( run several time and check value)
# 3. is/isn't separated in two function
# 4. consistence in In and Variable
def test_shared_state0(self): def test_shared_state0(self):
a = T.scalar() # the a is for 'anonymous' (un-named). a = T.scalar() # the a is for 'anonymous' (un-named).
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论