提交 42372d05 authored 作者: --global's avatar --global

Make sure that the updates of sitsot shared variables are initially tensor variables

上级 edbf47e0
...@@ -870,7 +870,8 @@ def scan(fn, ...@@ -870,7 +870,8 @@ def scan(fn,
tensor.unbroadcast( tensor.unbroadcast(
tensor.shape_padleft(input.variable), 0), tensor.shape_padleft(input.variable), 0),
actual_n_steps)) actual_n_steps))
sit_sot_inner_outputs.append(input.update) tensor_update = tensor.as_tensor_variable(input.update)
sit_sot_inner_outputs.append(tensor_update)
# Not that pos is not a negative index. The sign of pos is used # Not that pos is not a negative index. The sign of pos is used
# as a flag to indicate if this output should be part of the # as a flag to indicate if this output should be part of the
# update rules or part of the standard outputs of scan. # update rules or part of the standard outputs of scan.
......
...@@ -1363,6 +1363,19 @@ class T_Scan(unittest.TestCase): ...@@ -1363,6 +1363,19 @@ class T_Scan(unittest.TestCase):
assert_raises(TypeError, cPickle.load, open(path, "r")) assert_raises(TypeError, cPickle.load, open(path, "r"))
def test_consistent_inner_fct(self):
# Test that scan does not falsely detect inconsistencies in a valid
# inner graph
# The pickled scan op used in this test requires the use of a gpu
from theano.sandbox import cuda
if not cuda.cuda_available:
raise SkipTest('Optional package cuda disabled')
rs = theano.sandbox.rng_mrg.MRG_RandomStreams()
output, _ = theano.scan(lambda : rs.uniform((3,)), n_steps=3)
cPickle.loads(cPickle.dumps(output))
def test_cuda_gibbs_chain(self): def test_cuda_gibbs_chain(self):
from theano.sandbox import cuda from theano.sandbox import cuda
if not cuda.cuda_available: if not cuda.cuda_available:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论