提交 9d3fe98e authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Fix connection_pattern and add test

上级 017ceb9d
...@@ -1400,15 +1400,33 @@ class Scan(PureOp): ...@@ -1400,15 +1400,33 @@ class Scan(PureOp):
tmp = ils tmp = ils
if any([x is not None for x in tmp]): if any([x is not None for x in tmp]):
connection_pattern[iidx + 1][oidx] = True connection_pattern[iidx + 1][oidx] = True
# Applying Floyd-Warshall to find all paths connecting inputs to # Applying Floyd-Warshall to find all paths connecting inputs to
# outputs. Note that if `x` is an input to `y_t` and `y_tm1` is an # outputs. Note that if `x` is an input to `y_t` and `y_tm1` is an
# input to `z_t` then `x` is an input to `z_t`. # input to `z_t` then `x` is an input to `z_t`.
n_outs = len(node.outputs) n_outs = len(node.outputs)
outer_iidx_from_inner_iidx = self.get_outer_iidx_from_inner_iidx_seq()
for steps in xrange(n_outs): for steps in xrange(n_outs):
for iidx in xrange(n_outs): for iidx in xrange(n_outs):
for jidx in xrange(n_outs): for jidx in xrange(n_outs):
j_inp_idx = self.get_input_pos(jidx) + 1
# Get the idx of the first inner input corresponding to
# that inner output
j_inp_idx = self.get_input_pos(jidx)
if j_inp_idx == -1:
# No corresponding inner input : default to what scan
# was doing in the previous version in those cases
# which *seems* to be a hack designed to avoid passing
# the condition below but it's not certain.
j_inp_idx = 0
else:
# Get the idx of the outer input corresponding to that
# inner input
j_inp_idx = outer_iidx_from_inner_iidx[j_inp_idx]
if connection_pattern[j_inp_idx][iidx] == True: if connection_pattern[j_inp_idx][iidx] == True:
for k in xrange(len(connection_pattern)): for k in xrange(len(connection_pattern)):
if connection_pattern[k][jidx]: if connection_pattern[k][jidx]:
...@@ -1417,6 +1435,42 @@ class Scan(PureOp): ...@@ -1417,6 +1435,42 @@ class Scan(PureOp):
node.tag.connection_pattern = connection_pattern node.tag.connection_pattern = connection_pattern
return connection_pattern return connection_pattern
def get_outer_iidx_from_inner_iidx_seq(self):
""" Return a sequence where the value of at the i-th position is the
index of the outer input corresponding to the i-th inner input
"""
output = []
outer_inp_idx = 1 # First outer input is timestep index, skip it
# Handle sequences inputs
for i in range(self.info['n_seqs']):
output.append(outer_inp_idx)
outer_inp_idx += 1
# Handle mitmots, mitsots and sitsots inputs
for input_taps in self.info['tap_array']:
for tap in input_taps:
output.append(outer_inp_idx)
outer_inp_idx += 1
# Handle shared inputs
for i in range(self.info['n_shared_outs']):
output.append(outer_inp_idx)
outer_inp_idx += 1
# No inner input corresponds to the outer nitsot inputs but they still
# need to be counted
outer_inp_idx += self.info['n_nit_sot']
# Handle non-sequences inputs
nb_nonseqs_inputs = len(self.inputs) - len(output)
for i in range(nb_nonseqs_inputs):
output.append(outer_inp_idx)
outer_inp_idx += 1
return output
### GRAD FUNCTION ### GRAD FUNCTION
def grad(self, inputs, dC_douts): def grad(self, inputs, dC_douts):
outs = self(*inputs) outs = self(*inputs)
......
...@@ -817,6 +817,25 @@ class T_Scan(unittest.TestCase): ...@@ -817,6 +817,25 @@ class T_Scan(unittest.TestCase):
rval = theano.function([x], y, updates=updates)(inp) rval = theano.function([x], y, updates=updates)(inp)
assert numpy.all(rval == inp[:-1]) assert numpy.all(rval == inp[:-1])
def test_connection_pattern(self):
"""Test connection_pattern() in the presence of recurrent outputs
with multiple taps.
This test refers to a bug signaled on the theano-users mailing list
on March 10 2015 by David Schneider-Joseph.
"""
def fn(a_m2, a_m1, b_m2, b_m1):
return a_m1, b_m1
a0 = theano.shared(numpy.arange(2))
b0 = theano.shared(numpy.arange(2))
(a, b), _ = theano.scan(fn,
outputs_info=[{'initial': a0, 'taps': [-2, -1]},
{'initial': b0, 'taps': [-2, -1]}],
n_steps=2)
tensor.grad(a[-1], a0)
# simple rnn, one input, one state, weights for each; input/state are # simple rnn, one input, one state, weights for each; input/state are
# vectors, weights are scalars; using shared variables and past # vectors, weights are scalars; using shared variables and past
# taps (sequences and outputs) # taps (sequences and outputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论