提交 8a965eb4 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2727 from carriepl/scan_connection_pattern

Stop confusing inner and outer outputs in connection_pattern()
......@@ -1471,35 +1471,54 @@ class Scan(PureOp):
# input to `z_t` then `x` is an input to `z_t`.
n_outs = len(node.outputs)
outer_iidx_from_inner_iidx = self.get_outer_iidx_from_inner_iidx_seq()
outer_iidx_from_outer_oidx = self.get_outer_iidx_from_outer_oidx_seq()
for steps in xrange(n_outs):
for iidx in xrange(n_outs):
for jidx in xrange(n_outs):
# Get the idx of the first inner input corresponding to
# that inner output
j_inp_idx = self.get_input_pos(jidx)
# Get the idx of the outer input corresponding to that
# outer output
j_inp_idx = outer_iidx_from_outer_oidx[jidx]
if j_inp_idx == -1:
# No corresponding inner input : default to what scan
# was doing in the previous version in those cases
# which *seems* to be a hack designed to avoid passing
# the condition below but it's not certain.
j_inp_idx = 0
else:
# Get the idx of the outer input corresponding to that
# inner input
j_inp_idx = outer_iidx_from_inner_iidx[j_inp_idx]
if connection_pattern[j_inp_idx][iidx] == True:
for k in xrange(len(connection_pattern)):
if connection_pattern[k][jidx]:
connection_pattern[k][iidx] = True
if j_inp_idx != -1:
if connection_pattern[j_inp_idx][iidx] == True:
for k in xrange(len(connection_pattern)):
if connection_pattern[k][jidx]:
connection_pattern[k][iidx] = True
node.tag.connection_pattern = connection_pattern
return connection_pattern
def get_outer_iidx_from_outer_oidx_seq(self):
""" Return a sequence where the value at the i-th position is the
index of the outer input corresponding to the i-th outer output
NOTE: mitmots, mitsots, sitsots and shared outputs have corresponding
outer inputs but not nitsots.
"""
nb_outer_outputs = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot +
self.n_nit_sot + self.n_shared_outs)
result = [-1] * nb_outer_outputs
# Process mitmots, mitsots and sitsots
input_offset = 1 + self.n_seqs
output_offset = 0
for i in range(len(self.tap_array)):
result[output_offset] = input_offset
input_offset += 1
output_offset += 1
# Process shared inputs/outputs
output_offset += self.n_nit_sot
for i in range(self.n_shared_outs):
result[output_offset] = input_offset
input_offset += 1
output_offset += 1
return result
def get_outer_iidx_from_inner_iidx_seq(self):
""" Return a sequence where the value at the i-th position is the
index of the outer input corresponding to the i-th inner input
......
......@@ -836,8 +836,22 @@ class T_Scan(unittest.TestCase):
outputs_info=[{'initial': a0, 'taps': [-2, -1]},
{'initial': b0, 'taps': [-2, -1]}],
n_steps=2)
tensor.grad(a[-1], a0)
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq
# and get_outer_iidx_from_inner_iidx_seq produce the correct results
scan_node = a.owner.inputs[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq()
expected_result = [1, 2]
assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq()
expected_result = [1, 1, 2, 2]
assert(result == expected_result)
def test_connection_pattern2(self):
# This tests for a crash in connection_pattern() when a scan node
# has more than one mitmot (multiple input taps as well as
......@@ -858,6 +872,18 @@ class T_Scan(unittest.TestCase):
scan_node = g_out[0].owner.inputs[1].owner.inputs[1].owner.inputs[0].owner
connection_pattern = scan_node.op.connection_pattern(scan_node)
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq
# and get_outer_iidx_from_inner_iidx_seq produce the correct results
scan_node = out.owner.inputs[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq()
expected_result = [2]
assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq()
expected_result = [1, 2, 2]
assert(result == expected_result)
def test_grad_two_scans(self):
# data input & output
......@@ -1870,6 +1896,18 @@ class T_Scan(unittest.TestCase):
analytic_grad[max_err_pos],
num_grad.gx[max_err_pos]))
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq
# and get_outer_iidx_from_inner_iidx_seq produce the correct results
scan_node = updates.values()[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq()
expected_result = [3, -1, 4]
assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq()
expected_result = [1, 2, 3, 4, 6]
assert(result == expected_result)
def test_grad_multiple_outs_some_truncate(self):
rng = numpy.random.RandomState(utt.fetch_seed())
vW_in = asarrayX(rng.uniform(size=(2, 2), low=-.1, high=.1))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论