Commit 6c3d8e26 authored by abergeron

Merge pull request #3002 from carriepl/scan_index_error

[CRASH] Scan index error
......@@ -229,6 +229,11 @@ class Scan(PureOp):
self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, [])
self._hash_inner_graph = hash(self._cmodule_key)
# Compute mappings between outer inputs, outer outputs, inner
# inputs and inner outputs to determine which variables are associated
# with the same states.
self.var_mappings = self.get_oinp_iinp_iout_oout_mappings()
def validate_inner_graph(self):
""" Perform some elementary validations on the inner graph to ensure
that it is coherent.
......@@ -237,14 +242,11 @@ class Scan(PureOp):
# For every recurrent output, iterate over the associated inner
# inputs and output and ensure that they have the same dtype
nb_recurr_outputs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
outer_iidx_from_outer_oidx = self.get_outer_iidx_from_outer_oidx_seq()
for outer_oidx in range(nb_recurr_outputs):
outer_iidx = outer_iidx_from_outer_oidx[outer_oidx]
inner_iidxs = self.get_inner_iidx_from_outer_iidx(outer_iidx)
inner_oidxs = self.get_inner_oidx_from_outer_oidx(outer_oidx)
inner_iidxs = self.var_mappings['inner_inp_from_outer_out'][outer_oidx]
inner_oidxs = self.var_mappings['inner_out_from_outer_out'][outer_oidx]
for (inner_iidx, inner_oidx) in itertools.product(inner_iidxs,
inner_oidxs):
......@@ -303,13 +305,19 @@ class Scan(PureOp):
def __setstate__(self, d):
self.__dict__.update(d)
self.validate_inner_graph()
if "allow_gc" not in self.__dict__:
self.allow_gc = True
self.info['allow_gc'] = True
if not hasattr(self, 'gpua'):
self.gpua = False
self.info['gpua'] = False
if not hasattr(self, 'var_mappings'):
# Generate the mappings between inner and outer inputs and outputs
# if they haven't already been generated.
self.var_mappings = self.get_oinp_iinp_iout_oout_mappings()
# Ensure that the graph associated with the inner function is valid.
self.validate_inner_graph()
def make_node(self, *inputs):
"""
......@@ -1470,66 +1478,6 @@ class Scan(PureOp):
scan_outs.append((Shape_i(0)(o),) + x[1:])
return scan_outs
def get_input_pos(self, output_index):
    """Map ``output_index``, an index in the inner outputs of scan, to
    the first corresponding index in the inner inputs of scan.

    Returns -1 when ``output_index`` does not belong to a recurrent
    state (mitmot, mitsot or sitsot).
    """
    # Inner inputs start with the sequences; the recurrent states'
    # inputs follow immediately after them.
    inp_idx = self.n_seqs
    out_idx = output_index

    # Each mitmot state owns len(out_taps) inner outputs and
    # len(in_taps) inner inputs.
    for out_taps, in_taps in zip(self.mitmot_out_taps(),
                                 self.mitmot_taps()):
        if out_idx < len(out_taps):
            return inp_idx
        out_idx -= len(out_taps)
        inp_idx += len(in_taps)

    # Each mitsot state owns a single inner output and len(taps)
    # inner inputs.
    for taps in self.mitsot_taps():
        if out_idx == 0:
            return inp_idx
        out_idx -= 1
        inp_idx += len(taps)

    # Sitsot states map one-to-one between inner outputs and inputs.
    if out_idx < self.info['n_sit_sot']:
        return inp_idx + out_idx
    return -1
def get_output_pos(self, input_index):
    """Map ``input_index``, an index in the inner inputs of scan, to
    the first corresponding index in the inner outputs of scan.

    Returns -1 when ``input_index`` does not belong to a recurrent
    state (mitmot, mitsot or sitsot).
    """
    remaining = input_index
    out_idx = 0

    # Each mitmot state consumes len(in_taps) inner inputs and
    # produces len(out_taps) inner outputs.
    for out_taps, in_taps in zip(self.mitmot_out_taps(),
                                 self.mitmot_taps()):
        if remaining < len(in_taps):
            return out_idx
        out_idx += len(out_taps)
        remaining -= len(in_taps)

    # Each mitsot state consumes len(taps) inner inputs and produces
    # a single inner output.
    for taps in self.mitsot_taps():
        if remaining < len(taps):
            return out_idx
        out_idx += 1
        remaining -= len(taps)

    # Sitsot states map one-to-one between inner inputs and outputs.
    if remaining < self.info['n_sit_sot']:
        return remaining + out_idx
    return -1
def get_output_slice_idx(self, output_index):
    """For ``output_index``, an index in the outer outputs of scan,
    find the corresponding first index in the inner outputs of scan.

    Only mitmot outputs own more than one inner output (one per output
    tap); every other outer output owns exactly one inner output.
    """
    ipos = 0
    opos = output_index
    # Bug fix: the previous code iterated over
    # ``zip(self.mitmot_out_taps())``, which wraps every tap list in a
    # 1-element tuple, so ``len(otaps)`` was always 1: the loop
    # returned 0 for every output index and the offset never grew by
    # the real number of output taps. Iterate the tap lists directly
    # and stop once the requested outer output is reached.
    for otaps in self.mitmot_out_taps():
        if opos == 0:
            return ipos
        opos -= 1
        ipos += len(otaps)
    # Past the mitmots every outer output owns a single inner output,
    # so the remaining distance is added directly.
    return ipos + opos
def inner_connection_pattern(self):
""" Returns the connection pattern of scan's inner function
"""
......@@ -1616,10 +1564,10 @@ class Scan(PureOp):
# and inner outputs and, if one such pair of inner variables is
# connected than the pair of outer variables is connected.
for outer_oidx in range(len(node.outputs)):
inner_oidxs = self.get_inner_oidx_from_outer_oidx(outer_oidx)
inner_oidxs = self.var_mappings['inner_out_from_outer_out'][outer_oidx]
for outer_iidx in range(len(node.inputs)):
inner_iidxs = self.get_inner_iidx_from_outer_iidx(outer_iidx)
inner_iidxs = self.var_mappings['inner_inp_from_outer_inp'][outer_iidx]
for inner_oidx in inner_oidxs:
for inner_iidx in inner_iidxs:
......@@ -1636,7 +1584,6 @@ class Scan(PureOp):
# input to `z_t` then `x` is an input to `z_t`.
n_outs = len(node.outputs)
outer_iidx_from_outer_oidx = self.get_outer_iidx_from_outer_oidx_seq()
for steps in xrange(n_outs):
for iidx in xrange(n_outs):
......@@ -1644,7 +1591,7 @@ class Scan(PureOp):
# Get the idx of the outer input corresponding to that
# outer output
j_inp_idx = outer_iidx_from_outer_oidx[jidx]
j_inp_idx = self.var_mappings["outer_inp_from_outer_out"][jidx]
if j_inp_idx != -1:
if connection_pattern[j_inp_idx][iidx] == True:
......@@ -1655,100 +1602,160 @@ class Scan(PureOp):
node.tag.connection_pattern = connection_pattern
return connection_pattern
def get_inner_oidx_from_outer_oidx(self, outer_oidx):
    """Given the index of an outer output, return the indices of the
    corresponding inner output(s) as a sequence.

    Mitmot outputs own one inner output per output tap; every other
    kind of output owns exactly one inner output.
    """
    start = 0
    end = 0
    # Accumulate the width (number of inner outputs) of every outer
    # output up to and including ``outer_oidx``; the final iteration
    # leaves [start, end) covering the requested output.
    # NOTE: the original loop used ``xrange``, which only exists in
    # Python 2; ``range`` iterates identically on both versions.
    for pos in range(outer_oidx + 1):
        start = end
        if pos < self.n_mit_mot:
            end += len(self.mitmot_out_taps()[pos])
        else:
            end += 1
    return range(start, end)
def get_inner_iidx_from_outer_iidx(self, outer_oidx):
"""Given the index of an outer input, return the indices of the
corresponding inner input(s) in a sequence.
def get_oinp_iinp_iout_oout_mappings(self):
""" Compute and return dictionary mappings between the inputs and
outputs of the inner function and the inputs and outputs of the Scan
node in the outer graph.
The return value is a dictionary in which the keys are the names of
the individual mappings and the values are the mapping dictionaries
themselves. In dictionaries representing mappings to outer variables,
the values are individual integer indices. In dictionaries
representing mappings to inner variables, the values are sequences of
indices because multiple inner variables can be associated with the
same state
"""
outer_iidx_from_inner_iidx = self.get_outer_iidx_from_inner_iidx_seq()
# For every inner input, if the corresponding outer input is the
# desired one, store the index
inner_iidxs = []
for i in xrange(len(outer_iidx_from_inner_iidx)):
if outer_iidx_from_inner_iidx[i] == outer_oidx:
inner_iidxs.append(i)
return inner_iidxs
def get_outer_iidx_from_outer_oidx_seq(self):
    """Return a list whose i-th element is the index of the outer
    input corresponding to the i-th outer output, or -1 when no such
    input exists.

    Mitmots, mitsots, sitsots and shared outputs have corresponding
    outer inputs; nitsots do not.
    """
    n_outer_outputs = (self.n_mit_mot + self.n_mit_sot +
                       self.n_sit_sot + self.n_nit_sot +
                       self.n_shared_outs)
    mapping = [-1] * n_outer_outputs

    # Outer input 0 is the step count and inputs 1..n_seqs are the
    # sequences; the recurrent states' inputs start right after.
    inp_idx = 1 + self.n_seqs
    out_idx = 0

    # mitmot, mitsot and sitsot outputs (one per entry of tap_array)
    # pair up one-to-one with the outer inputs that follow.
    for _ in self.tap_array:
        mapping[out_idx] = inp_idx
        inp_idx += 1
        out_idx += 1

    # nitsot outputs have no corresponding outer input; skip them.
    out_idx += self.n_nit_sot

    # Shared outputs pair up with the shared outer inputs.
    for _ in range(self.n_shared_outs):
        mapping[out_idx] = inp_idx
        inp_idx += 1
        out_idx += 1

    return mapping
def get_outer_iidx_from_inner_iidx_seq(self):
""" Return a sequence where the value at the i-th position is the
index of the outer input corresponding to the i-th inner input
"""
output = []
outer_inp_idx = 1 # First outer input is timestep index, skip it
# Lists for outer variables contain individual indices, lists for
# inner variables contain sequences of indices because many inner
# variables can be associated with the same outer variable. The list
# and indices are initialized already containing the data associated
# with the timestep index, the first outer input.
outer_input_indices = [0]
inner_input_indices = [[]]
inner_output_indices = [[]]
outer_output_indices = [-1]
outer_iidx = 1
inner_iidx = 0
inner_oidx = 0
outer_oidx = 0
# Handle sequences inputs
for i in range(self.info['n_seqs']):
output.append(outer_inp_idx)
outer_inp_idx += 1
# Handle mitmots, mitsots and sitsots inputs
for input_taps in self.info['tap_array']:
for tap in input_taps:
output.append(outer_inp_idx)
outer_inp_idx += 1
# Handle shared inputs
outer_input_indices.append(outer_iidx)
inner_input_indices.append([inner_iidx])
inner_output_indices.append([])
outer_output_indices.append(-1)
outer_iidx += 1
inner_iidx += 1
inner_oidx += 0
outer_oidx += 0
# Handle mitmots, mitsots and sitsots variables
for i in range(len(self.info['tap_array'])):
nb_input_taps = len(self.info['tap_array'][i])
if i < self.n_mit_mot:
nb_output_taps = len(self.mit_mot_out_slices[i])
else:
nb_output_taps = 1
outer_input_indices.append(outer_iidx)
inner_input_indices.append(range(inner_iidx,
inner_iidx + nb_input_taps))
inner_output_indices.append(range(inner_oidx,
inner_oidx + nb_output_taps))
outer_output_indices.append(outer_oidx)
outer_iidx += 1
inner_iidx += nb_input_taps
inner_oidx += nb_output_taps
outer_oidx += 1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx += self.info['n_shared_outs']
# Handle nitsots variables
for i in range(self.n_nit_sot):
outer_input_indices.append(outer_iidx)
inner_input_indices.append([])
inner_output_indices.append([inner_oidx])
outer_output_indices.append(outer_oidx)
outer_iidx += 1
inner_iidx += 0
inner_oidx += 1
outer_oidx += 1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx -= (self.info['n_shared_outs'] + self.n_nit_sot)
# Handle shared states
for i in range(self.info['n_shared_outs']):
output.append(outer_inp_idx)
outer_inp_idx += 1
# No inner input corresponds to the outer nitsot inputs but they still
# need to be counted
outer_inp_idx += self.info['n_nit_sot']
# Handle non-sequences inputs
nb_nonseqs_inputs = len(self.inputs) - len(output)
for i in range(nb_nonseqs_inputs):
output.append(outer_inp_idx)
outer_inp_idx += 1
return output
outer_input_indices.append(outer_iidx)
inner_input_indices.append([inner_iidx])
inner_output_indices.append([inner_oidx])
outer_output_indices.append(outer_oidx)
outer_iidx += 1
inner_iidx += 1
inner_oidx += 1
outer_oidx += 1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx += self.n_nit_sot
# Handle non-sequence inputs
# Note : the number of non-sequence inputs is not stored in self.info
# so it has to be inferred from the number of inner inputs that remain
# to be handled
for i in range(len(self.inputs) - inner_iidx):
outer_input_indices.append(outer_iidx)
inner_input_indices.append([inner_iidx])
inner_output_indices.append([])
outer_output_indices.append(-1)
outer_iidx += 1
inner_iidx += 1
inner_oidx += 0
outer_oidx += 0
# With the global mapping inferred, the individual mappings
# can be produced
mappings = {"outer_inp_from_outer_out" : {},
"inner_inp_from_outer_out" : {},
"inner_out_from_outer_out" : {},
"inner_inp_from_outer_inp" : {},
"inner_out_from_outer_inp" : {},
"outer_out_from_outer_inp" : {},
"outer_inp_from_inner_inp" : {},
"inner_out_from_inner_inp" : {},
"outer_out_from_inner_inp" : {},
"outer_inp_from_inner_out" : {},
"inner_inp_from_inner_out" : {},
"outer_out_from_inner_out" : {}}
for (oinp, iinp, iout, oout) in zip(outer_input_indices,
inner_input_indices,
inner_output_indices,
outer_output_indices):
if oout != -1:
mappings["outer_inp_from_outer_out"][oout] = oinp
mappings["inner_inp_from_outer_out"][oout] = iinp
mappings["inner_out_from_outer_out"][oout] = iout
if oinp != -1:
mappings["inner_inp_from_outer_inp"][oinp] = iinp
mappings["inner_out_from_outer_inp"][oinp] = iout
mappings["outer_out_from_outer_inp"][oinp] = oout
for idx in iinp:
mappings["outer_inp_from_inner_inp"][idx] = oinp
mappings["inner_out_from_inner_inp"][idx] = iout
mappings["outer_out_from_inner_inp"][idx] = oout
for idx in iout:
mappings["outer_inp_from_inner_out"][idx] = oinp
mappings["inner_inp_from_inner_out"][idx] = iinp
mappings["outer_out_from_inner_out"][idx] = oout
return mappings
# GRAD FUNCTION
def grad(self, inputs, dC_douts):
......@@ -1896,10 +1903,14 @@ class Scan(PureOp):
for pos, inp in enumerate(states):
if inp in theano.gof.graph.inputs([Xt]):
oidx = self.get_output_pos(pos)
if not isinstance(dC_douts[oidx].type,
# Get the index of the outer output that to which
# the state variable 'inp' corresponds.
outer_oidx = self.var_mappings['outer_out_from_inner_inp'][self.n_seqs +
pos]
if not isinstance(dC_douts[outer_oidx].type,
DisconnectedType):
dtypes.append(dC_douts[oidx].dtype)
dtypes.append(dC_douts[outer_oidx].dtype)
if dtypes:
new_dtype = theano.scalar.upcast(*dtypes)
else:
......@@ -1943,14 +1954,25 @@ class Scan(PureOp):
# construct dX_dtm1
dC_dXtm1s = []
for pos, x in enumerate(dC_dinps_t[self.n_seqs:]):
opos = self.get_output_pos(pos)
if opos >= 0:
# Get the index of the first inner input corresponding to the
# pos-ieth inner input state
idxs = self.var_mappings['inner_out_from_inner_inp'][self.n_seqs +
pos]
# Check if the pos-th input is associated with one of the
# recurrent states
x_is_state = pos < sum([len(t) for t in self.tap_array])
if x_is_state and len(idxs) > 0:
opos = idxs[0]
dC_dXtm1s.append(safe_new(dC_dXts[opos]))
if hasattr(x, 'dtype') and x.dtype != dC_dXts[opos].dtype:
dC_dinps_t[pos + self.n_seqs] = \
x.astype(dC_dXts[opos].dtype)
else:
dC_dXtm1s.append(safe_new(x))
for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
if isinstance(dC_dinps_t[dx + self.n_seqs].type, NullType):
# The accumulated gradient is undefined
......
......@@ -657,19 +657,18 @@ class T_Scan(unittest.TestCase):
tensor.grad(a[-1], a0)
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq
# and get_outer_iidx_from_inner_iidx_seq produce the correct results
# Also validate that the mappings outer_inp_from_outer_out and
# outer_inp_from_inner_inp produce the correct results
scan_node = a.owner.inputs[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq()
expected_result = [1, 2]
result = scan_node.op.var_mappings['outer_inp_from_outer_out']
expected_result = {0: 1, 1: 2}
assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq()
expected_result = [1, 1, 2, 2]
result = scan_node.op.var_mappings['outer_inp_from_inner_inp']
expected_result = {0: 1, 1: 1, 2: 2, 3: 2}
assert(result == expected_result)
def test_connection_pattern2(self):
# This tests for a crash in connection_pattern() when a scan node
# has more than one mitmot (multiple input taps as well as
......@@ -690,18 +689,42 @@ class T_Scan(unittest.TestCase):
scan_node = g_out[0].owner.inputs[1].owner.inputs[1].owner.inputs[0].owner
connection_pattern = scan_node.op.connection_pattern(scan_node)
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq
# and get_outer_iidx_from_inner_iidx_seq produce the correct results
# Also validate that the mappings outer_inp_from_outer_out and
# outer_inp_from_inner_inp produce the correct results
scan_node = out.owner.inputs[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq()
expected_result = [2]
result = scan_node.op.var_mappings['outer_inp_from_outer_out']
expected_result = {0: 2}
assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq()
expected_result = [1, 2, 2]
result = scan_node.op.var_mappings['outer_inp_from_inner_inp']
expected_result = {0: 1, 1: 2, 2: 2}
assert(result == expected_result)
def test_grad_grad_mitsot_sitsot(self):
    """Regression test: taking the second derivative through a Scan
    node with one sitsot and one mitsot used to raise an index error.
    """
    def step(mitsot_m2, mitsot_m1, sitsot):
        # One mitsot output (taps -2, -1) and one sitsot output.
        out = (mitsot_m2 + mitsot_m1 + sitsot) ** 2
        return out, out

    init_mitsot = tensor.matrix()
    init_sitsot = tensor.vector()
    outputs_info = [dict(initial=init_mitsot, taps=[-2, -1]),
                    init_sitsot]

    scan_outputs, updates = theano.scan(fn=step,
                                        outputs_info=outputs_info,
                                        n_steps=5)

    # First derivative of each output wrt its own initial state.
    gradients = [theano.grad(scan_outputs[0].sum(), init_mitsot),
                 theano.grad(scan_outputs[1].sum(), init_sitsot)]

    # Second derivative: grad of the summed gradients wrt the first
    # input. Merely building this graph exercised the old crash; no
    # compilation or numerical check is required.
    sum_of_grads = sum([g.sum() for g in gradients])
    theano.grad(sum_of_grads, init_mitsot)
def test_grad_two_scans(self):
# data input & output
......@@ -1680,16 +1703,16 @@ class T_Scan(unittest.TestCase):
analytic_grad[max_err_pos],
num_grad.gx[max_err_pos]))
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq
# and get_outer_iidx_from_inner_iidx_seq produce the correct results
# Also validate that the mappings outer_inp_from_outer_out and
# outer_inp_from_inner_inp produce the correct results
scan_node = updates.values()[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq()
expected_result = [3, -1, 4]
result = scan_node.op.var_mappings['outer_inp_from_outer_out']
expected_result = {0: 3, 1: 5, 2: 4}
assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq()
expected_result = [1, 2, 3, 4, 6]
result = scan_node.op.var_mappings['outer_inp_from_inner_inp']
expected_result = {0: 1, 1: 2, 2: 3, 3: 4, 4: 6}
assert(result == expected_result)
def test_grad_multiple_outs_some_truncate(self):
......@@ -3299,6 +3322,69 @@ class T_Scan(unittest.TestCase):
if isinstance(x.op, theano.scan_module.scan_op.Scan)]
assert len(lssc) == 0
def test_oinp_iinp_iout_oout_mappings(self):
    # Test the mappings produced by
    # ScanOp.get_oinp_iinp_iout_oout_mappings()
    rng = theano.tensor.shared_randomstreams.RandomStreams(123)

    def inner_fct(seq, mitsot, sitsot, nitsot, nseq):
        random_scalar = rng.uniform((1,))[0]
        total = seq + mitsot + sitsot + nitsot + nseq + random_scalar
        return total, total, total

    # Assemble a scan with one sequence, one mitsot, one sitsot, one nitsot
    # a non-sequence and a random state to test the mappings.
    seq = [tensor.vector()]
    non_seq = [tensor.scalar()]
    outputs_info = [dict(initial=tensor.vector(), taps=[-3, -1]),
                    tensor.scalar(), None]

    scan_outputs, _ = theano.scan(fn=inner_fct, sequences=seq,
                                  outputs_info=outputs_info,
                                  non_sequences=non_seq)

    # Compare the mappings with the expected values.
    # In the expected dicts below: mappings *to* inner variables have
    # list values (several inner variables can share one state) while
    # mappings *to* outer variables have int values, with -1 marking
    # the absence of a corresponding outer variable.
    scan_node = scan_outputs[0].owner.inputs[0].owner
    mappings = scan_node.op.var_mappings

    assert mappings['inner_inp_from_outer_inp'] == {0 : [], 1 : [0],
                                                    2 : [1, 2], 3 : [3],
                                                    4 : [4], 5 : [],
                                                    6 : [5]}
    assert mappings['inner_out_from_outer_inp'] == {0 : [], 1 : [],
                                                    2 : [0], 3 : [1],
                                                    4 : [3], 5 : [2],
                                                    6 : []}
    assert mappings['outer_out_from_outer_inp'] == {0 : -1, 1 : -1,
                                                    2 : 0, 3 : 1,
                                                    4 : 3, 5 : 2,
                                                    6 : -1}

    assert mappings['outer_inp_from_inner_inp'] == {0 : 1, 1 : 2,
                                                    2 : 2, 3 : 3,
                                                    4 : 4, 5 : 6}
    assert mappings['inner_out_from_inner_inp'] == {0 : [], 1 : [0],
                                                    2 : [0], 3 : [1],
                                                    4 : [3], 5 : []}
    assert mappings['outer_out_from_inner_inp'] == {0 : -1, 1 : 0,
                                                    2 : 0, 3 : 1,
                                                    4 : 3, 5 : -1}

    assert mappings['outer_inp_from_inner_out'] == {0 : 2, 1 : 3,
                                                    2 : 5, 3 : 4}
    assert mappings['inner_inp_from_inner_out'] == {0 : [1, 2], 1 : [3],
                                                    2 : [], 3 : [4]}
    assert mappings['outer_out_from_inner_out'] == {0 : 0, 1 : 1,
                                                    2 : 2, 3 : 3}

    assert mappings['outer_inp_from_outer_out'] == {0 : 2, 1 : 3,
                                                    2 : 5, 3 : 4}
    assert mappings['inner_inp_from_outer_out'] == {0 : [1, 2], 1 : [3],
                                                    2 : [], 3 : [4]}
    assert mappings['inner_out_from_outer_out'] == {0 : [0], 1 : [1],
                                                    2 : [2], 3 : [3]}
def test_grad_duplicate_outputs(self):
# This test validates that taking the gradient of a scan, in which
# multiple outputs are the same theano variable, works.
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment