提交 e9fee711 authored 作者: abergeron's avatar abergeron

Merge pull request #2548 from carriepl/scan_only_nonseq_inputs

[CRASH] Error when scan has only nonseq inputs
......@@ -1347,12 +1347,6 @@ class Scan(PureOp):
return self.outputs[s:e]
def _get_inner_inps(iidx):
s = 0
if self.n_seqs > 0:
e = 1
else:
e = len(self.tap_array[0])
p = iidx
if node.inputs[iidx + 1] in self.outer_nitsot(node):
return None
if node.inputs[iidx + 1] in self.outer_non_seqs(node):
......@@ -1360,6 +1354,11 @@ class Scan(PureOp):
node.inputs[iidx + 1])
return [self.inner_non_seqs(self.inputs)[loc_idx]]
s = 0
if self.n_seqs > 0:
e = 1
else:
e = len(self.tap_array[0])
for p in xrange(iidx):
s = e
if p < self.n_seqs:
......
......@@ -365,6 +365,27 @@ class T_Scan(unittest.TestCase):
4,
numpy.int64([2, 2, 3]))
@attr('slow')
def test_only_nonseq_inputs(self):
# Compile the Theano function
n_steps=2
inp = tensor.matrix()
broadcasted_inp, _ = theano.scan(lambda x:x,
non_sequences=[inp],
n_steps=n_steps)
out = broadcasted_inp.sum()
gr = tensor.grad(out, inp)
fun = theano.function([inp], [broadcasted_inp, gr])
# Execute the Theano function and compare outputs to the expected outputs
inputs = numpy.array([[1, 2], [3, 4]])
expected_out1 = numpy.repeat(inputs[None], n_steps, axis=0)
expected_out2 = numpy.ones(inputs.shape, dtype="int8") * n_steps
out1, out2 = fun(inputs)
utt.assert_allclose(out1, expected_out1)
utt.assert_allclose(out2, expected_out2)
# simple rnn, one input, one state, weights for each; input/state
# are vectors, weights are scalars
def test_one_sequence_one_output_weights(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论