提交 6f01cf86 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2841 from carriepl/scan_validation

[CRASH] Improve shared variable management in Scan
......@@ -870,7 +870,8 @@ def scan(fn,
tensor.unbroadcast(
tensor.shape_padleft(input.variable), 0),
actual_n_steps))
sit_sot_inner_outputs.append(input.update)
tensor_update = tensor.as_tensor_variable(input.update)
sit_sot_inner_outputs.append(tensor_update)
# Not that pos is not a negative index. The sign of pos is used
# as a flag to indicate if this output should be part of the
# update rules or part of the standard outputs of scan.
......
......@@ -1363,6 +1363,27 @@ class T_Scan(unittest.TestCase):
assert_raises(TypeError, cPickle.load, open(path, "r"))
def test_consistent_inner_fct(self):
# Test that scan does not falsely detect inconsistencies in a valid
# inner graph
# The pickled scan op used in this test requires the use of a gpu
from theano.sandbox import cuda
if not cuda.cuda_available:
raise SkipTest('Optional package cuda disabled')
rs = theano.sandbox.rng_mrg.MRG_RandomStreams(use_cuda=True)
output, _ = theano.scan(lambda : rs.uniform((3,)), n_steps=3)
cPickle.loads(cPickle.dumps(output))
# Also ensure that, after compilation, the Scan has been moved
# on the gpu
fct = theano.function([], output, mode=mode_with_gpu)
scan_nodes = self.scan_nodes_from_fct(fct)
assert len(scan_nodes) == 1
assert (scan_nodes[0].op.info.get('gpu', False) or
scan_nodes[0].op.info.get('gpua', False))
def test_cuda_gibbs_chain(self):
from theano.sandbox import cuda
if not cuda.cuda_available:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论