提交 6400064f authored 作者: ChienliMa's avatar ChienliMa

Start writting a test

上级 dc36c99b
...@@ -293,7 +293,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -293,7 +293,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
mode=mode, mode=mode,
accept_inplace=accept_inplace, name=name) accept_inplace=accept_inplace, name=name)
else: else:
# note: pfunc will also call orig_function-- orig_function is # note: pfunc will also call orig_function -- orig_function is
# a choke point that all compilation must pass through # a choke point that all compilation must pass through
fn = pfunc(params=inputs, fn = pfunc(params=inputs,
outputs=outputs, outputs=outputs,
......
...@@ -588,7 +588,6 @@ class Function(object): ...@@ -588,7 +588,6 @@ class Function(object):
input_storage = [] input_storage = []
assert len(ins) == len(fg_cpy.inputs) assert len(ins) == len(fg_cpy.inputs)
for in_ori, in_cpy, in_v in zip(maker.inputs, ins, fg_cpy.inputs): for in_ori, in_cpy, in_v in zip(maker.inputs, ins, fg_cpy.inputs):
# Since we reuse original Out instances, copied In instances # Since we reuse original Out instances, copied In instances
# should use the original variabls as their variables and # should use the original variabls as their variables and
......
...@@ -280,6 +280,18 @@ class T_function(unittest.TestCase): ...@@ -280,6 +280,18 @@ class T_function(unittest.TestCase):
else: else:
self.assertFalse(here.data is there.data) self.assertFalse(here.data is there.data)
# def test_swap_SharedVariable(self):
# x = T.fscalar('x')
# # SharedVariable for tests, one of them has update
# y = theano.shared(value=1)
# z = theano.shared(value=2)
# out = T.tanh((x+y+2)/(x+z-0.2)**2)
# # Test for different linkers
# for mode in ["FAST_RUN","FAST_COMPILE"]:
# ori = theano.function([x], [out], mode=mode,updates={z:z+1})
# cpy = ori.copy(share_memory=True)
def test_shared_state0(self): def test_shared_state0(self):
a = T.scalar() # the a is for 'anonymous' (un-named). a = T.scalar() # the a is for 'anonymous' (un-named).
x, s = T.scalars('xs') x, s = T.scalars('xs')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论