提交 6c3d8e26 authored 作者: abergeron's avatar abergeron

Merge pull request #3002 from carriepl/scan_index_error

[CRASH] Scan index error
...@@ -229,6 +229,11 @@ class Scan(PureOp): ...@@ -229,6 +229,11 @@ class Scan(PureOp):
self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, []) self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, [])
self._hash_inner_graph = hash(self._cmodule_key) self._hash_inner_graph = hash(self._cmodule_key)
# Compute mappings between outer inputs, outer outputs, inner
# inputs and inner outputs to determine with variables are associated
# with the same states.
self.var_mappings = self.get_oinp_iinp_iout_oout_mappings()
def validate_inner_graph(self): def validate_inner_graph(self):
""" Perform some elementary validations on the inner graph to ensure """ Perform some elementary validations on the inner graph to ensure
that it is coherent. that it is coherent.
...@@ -237,14 +242,11 @@ class Scan(PureOp): ...@@ -237,14 +242,11 @@ class Scan(PureOp):
# For every recurrent output, iterate over the associated inner # For every recurrent output, iterate over the associated inner
# inputs and output and ensure that they have the same dtype # inputs and output and ensure that they have the same dtype
nb_recurr_outputs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot nb_recurr_outputs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
outer_iidx_from_outer_oidx = self.get_outer_iidx_from_outer_oidx_seq()
for outer_oidx in range(nb_recurr_outputs): for outer_oidx in range(nb_recurr_outputs):
outer_iidx = outer_iidx_from_outer_oidx[outer_oidx] inner_iidxs = self.var_mappings['inner_inp_from_outer_out'][outer_oidx]
inner_oidxs = self.var_mappings['inner_out_from_outer_out'][outer_oidx]
inner_iidxs = self.get_inner_iidx_from_outer_iidx(outer_iidx)
inner_oidxs = self.get_inner_oidx_from_outer_oidx(outer_oidx)
for (inner_iidx, inner_oidx) in itertools.product(inner_iidxs, for (inner_iidx, inner_oidx) in itertools.product(inner_iidxs,
inner_oidxs): inner_oidxs):
...@@ -303,13 +305,19 @@ class Scan(PureOp): ...@@ -303,13 +305,19 @@ class Scan(PureOp):
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
self.validate_inner_graph()
if "allow_gc" not in self.__dict__: if "allow_gc" not in self.__dict__:
self.allow_gc = True self.allow_gc = True
self.info['allow_gc'] = True self.info['allow_gc'] = True
if not hasattr(self, 'gpua'): if not hasattr(self, 'gpua'):
self.gpua = False self.gpua = False
self.info['gpua'] = False self.info['gpua'] = False
if not hasattr(self, 'var_mappings'):
# Generate the mappings between inner and outer inputs and outputs
# if they haven't already been generated.
self.var_mappings = self.get_oinp_iinp_iout_oout_mappings()
# Ensure that the graph associated with the inner function is valid.
self.validate_inner_graph()
def make_node(self, *inputs): def make_node(self, *inputs):
""" """
...@@ -1470,66 +1478,6 @@ class Scan(PureOp): ...@@ -1470,66 +1478,6 @@ class Scan(PureOp):
scan_outs.append((Shape_i(0)(o),) + x[1:]) scan_outs.append((Shape_i(0)(o),) + x[1:])
return scan_outs return scan_outs
def get_input_pos(self, output_index):
""" For a given ``output_index``, an index in the inner outputs of
scan, find a corresponding first index in the inner inputs of scan
"""
ipos = self.n_seqs
opos = output_index
for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()):
if len(otaps) > opos:
return ipos
else:
opos = opos - len(otaps)
ipos += len(itaps)
for dx, taps in enumerate(self.mitsot_taps()):
if opos == 0:
return ipos
else:
opos = opos - 1
ipos += len(taps)
if opos < self.info['n_sit_sot']:
return ipos + opos
else:
return -1
def get_output_pos(self, input_index):
""" For a given ``input_index``, an index in the inner inputs of
scan, find a corresponding first index in the inner outputs of scan
"""
ipos = input_index
opos = 0
for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()):
if len(itaps) > ipos:
return opos
else:
opos += len(otaps)
ipos -= len(itaps)
for dx, taps in enumerate(self.mitsot_taps()):
if len(taps) > ipos:
return opos
else:
opos += 1
ipos -= len(taps)
if ipos < self.info['n_sit_sot']:
return ipos + opos
else:
return -1
def get_output_slice_idx(self, output_index):
""" For an ``output_index``, an index in the outter ouputs of scan,
find a corresponding index in the inner outputs of scan.
"""
ipos = 0
opos = output_index
for otaps in zip(self.mitmot_out_taps()):
if len(otaps) > 0:
return ipos
else:
opos = opos - 1
ipos += len(otaps)
return ipos + opos
def inner_connection_pattern(self): def inner_connection_pattern(self):
""" Returns the connection pattern of scan's inner function """ Returns the connection pattern of scan's inner function
""" """
...@@ -1616,10 +1564,10 @@ class Scan(PureOp): ...@@ -1616,10 +1564,10 @@ class Scan(PureOp):
# and inner outputs and, if one such pair of inner variables is # and inner outputs and, if one such pair of inner variables is
# connected than the pair of outer variables is connected. # connected than the pair of outer variables is connected.
for outer_oidx in range(len(node.outputs)): for outer_oidx in range(len(node.outputs)):
inner_oidxs = self.get_inner_oidx_from_outer_oidx(outer_oidx) inner_oidxs = self.var_mappings['inner_out_from_outer_out'][outer_oidx]
for outer_iidx in range(len(node.inputs)): for outer_iidx in range(len(node.inputs)):
inner_iidxs = self.get_inner_iidx_from_outer_iidx(outer_iidx) inner_iidxs = self.var_mappings['inner_inp_from_outer_inp'][outer_iidx]
for inner_oidx in inner_oidxs: for inner_oidx in inner_oidxs:
for inner_iidx in inner_iidxs: for inner_iidx in inner_iidxs:
...@@ -1636,7 +1584,6 @@ class Scan(PureOp): ...@@ -1636,7 +1584,6 @@ class Scan(PureOp):
# input to `z_t` then `x` is an input to `z_t`. # input to `z_t` then `x` is an input to `z_t`.
n_outs = len(node.outputs) n_outs = len(node.outputs)
outer_iidx_from_outer_oidx = self.get_outer_iidx_from_outer_oidx_seq()
for steps in xrange(n_outs): for steps in xrange(n_outs):
for iidx in xrange(n_outs): for iidx in xrange(n_outs):
...@@ -1644,7 +1591,7 @@ class Scan(PureOp): ...@@ -1644,7 +1591,7 @@ class Scan(PureOp):
# Get the idx of the outer input corresponding to that # Get the idx of the outer input corresponding to that
# outer output # outer output
j_inp_idx = outer_iidx_from_outer_oidx[jidx] j_inp_idx = self.var_mappings["outer_inp_from_outer_out"][jidx]
if j_inp_idx != -1: if j_inp_idx != -1:
if connection_pattern[j_inp_idx][iidx] == True: if connection_pattern[j_inp_idx][iidx] == True:
...@@ -1655,100 +1602,160 @@ class Scan(PureOp): ...@@ -1655,100 +1602,160 @@ class Scan(PureOp):
node.tag.connection_pattern = connection_pattern node.tag.connection_pattern = connection_pattern
return connection_pattern return connection_pattern
def get_inner_oidx_from_outer_oidx(self, outer_oidx): def get_oinp_iinp_iout_oout_mappings(self):
"""Given the index of an outer output, return the indices of the """ Compute and return dictionary mappings between the inputs and
corresponding inner output(s) in a sequence. outputs of the inner function and the inputs and outputs of the Scan
""" node in the outer graph.
s = 0
e = 0 The return value is a dictionary in which the keys are the names of
for p in xrange(outer_oidx + 1): the individual mappings and the values are the mapping dictionaries
s = e themselves. In dictionaries representing mappings to outer variables,
if p < self.n_mit_mot: the values are individual integer indices. In dictionaries
e += len(self.mitmot_out_taps()[p]) representing mappings to inner variables, the values are sequences of
else: indices because multiple inner variables can be associated with the
e += 1 same state
return range(s, e)
def get_inner_iidx_from_outer_iidx(self, outer_oidx):
"""Given the index of an outer input, return the indices of the
corresponding inner input(s) in a sequence.
"""
outer_iidx_from_inner_iidx = self.get_outer_iidx_from_inner_iidx_seq()
# For every inner input, if the corresponding outer input is the
# desired one, store the index
inner_iidxs = []
for i in xrange(len(outer_iidx_from_inner_iidx)):
if outer_iidx_from_inner_iidx[i] == outer_oidx:
inner_iidxs.append(i)
return inner_iidxs
def get_outer_iidx_from_outer_oidx_seq(self):
""" Return a sequence where the value at the i-th position is the
index of the outer input corresponding to the i-th outer output
NOTE: mitmots, mitsots, sitsots and shared outputs have corresponding
outer inputs but not nitsots.
"""
nb_outer_outputs = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot +
self.n_nit_sot + self.n_shared_outs)
result = [-1] * nb_outer_outputs
# Process mitmots, mitsots and sitsots
input_offset = 1 + self.n_seqs
output_offset = 0
for i in range(len(self.tap_array)):
result[output_offset] = input_offset
input_offset += 1
output_offset += 1
# Process shared inputs/outputs
output_offset += self.n_nit_sot
for i in range(self.n_shared_outs):
result[output_offset] = input_offset
input_offset += 1
output_offset += 1
return result
def get_outer_iidx_from_inner_iidx_seq(self):
""" Return a sequence where the value at the i-th position is the
index of the outer input corresponding to the i-th inner input
""" """
output = [] # Lists for outer variables contain individual indices, lists for
outer_inp_idx = 1 # First outer input is timestep index, skip it # inner variables contain sequences of indices because many inner
# variables can be associated with the same outer variable. The list
# and indices are initialized already containing the data associated
# with the timestep index, the first outer input.
outer_input_indices = [0]
inner_input_indices = [[]]
inner_output_indices = [[]]
outer_output_indices = [-1]
outer_iidx = 1
inner_iidx = 0
inner_oidx = 0
outer_oidx = 0
# Handle sequences inputs # Handle sequences inputs
for i in range(self.info['n_seqs']): for i in range(self.info['n_seqs']):
output.append(outer_inp_idx) outer_input_indices.append(outer_iidx)
outer_inp_idx += 1 inner_input_indices.append([inner_iidx])
inner_output_indices.append([])
# Handle mitmots, mitsots and sitsots inputs outer_output_indices.append(-1)
for input_taps in self.info['tap_array']:
for tap in input_taps: outer_iidx += 1
output.append(outer_inp_idx) inner_iidx += 1
outer_inp_idx += 1 inner_oidx += 0
outer_oidx += 0
# Handle shared inputs
# Handle mitmots, mitsots and sitsots variables
for i in range(len(self.info['tap_array'])):
nb_input_taps = len(self.info['tap_array'][i])
if i < self.n_mit_mot:
nb_output_taps = len(self.mit_mot_out_slices[i])
else:
nb_output_taps = 1
outer_input_indices.append(outer_iidx)
inner_input_indices.append(range(inner_iidx,
inner_iidx + nb_input_taps))
inner_output_indices.append(range(inner_oidx,
inner_oidx + nb_output_taps))
outer_output_indices.append(outer_oidx)
outer_iidx += 1
inner_iidx += nb_input_taps
inner_oidx += nb_output_taps
outer_oidx += 1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx += self.info['n_shared_outs']
# Handle nitsots variables
for i in range(self.n_nit_sot):
outer_input_indices.append(outer_iidx)
inner_input_indices.append([])
inner_output_indices.append([inner_oidx])
outer_output_indices.append(outer_oidx)
outer_iidx += 1
inner_iidx += 0
inner_oidx += 1
outer_oidx += 1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx -= (self.info['n_shared_outs'] + self.n_nit_sot)
# Handle shared states
for i in range(self.info['n_shared_outs']): for i in range(self.info['n_shared_outs']):
output.append(outer_inp_idx) outer_input_indices.append(outer_iidx)
outer_inp_idx += 1 inner_input_indices.append([inner_iidx])
inner_output_indices.append([inner_oidx])
# No inner input corresponds to the outer nitsot inputs but they still outer_output_indices.append(outer_oidx)
# need to be counted
outer_inp_idx += self.info['n_nit_sot'] outer_iidx += 1
inner_iidx += 1
# Handle non-sequences inputs inner_oidx += 1
nb_nonseqs_inputs = len(self.inputs) - len(output) outer_oidx += 1
for i in range(nb_nonseqs_inputs):
output.append(outer_inp_idx) # This is needed because, for outer inputs (and for outer inputs only)
outer_inp_idx += 1 # nitsots come *after* shared variables.
outer_iidx += self.n_nit_sot
return output
# Handle non-sequence inputs
# Note : the number of non-sequence inputs is not stored in self.info
# so it has to be inferred from the number of inner inputs that remain
# to be handled
for i in range(len(self.inputs) - inner_iidx):
outer_input_indices.append(outer_iidx)
inner_input_indices.append([inner_iidx])
inner_output_indices.append([])
outer_output_indices.append(-1)
outer_iidx += 1
inner_iidx += 1
inner_oidx += 0
outer_oidx += 0
# With the global mapping inferred, the individual mappings
# can be produced
mappings = {"outer_inp_from_outer_out" : {},
"inner_inp_from_outer_out" : {},
"inner_out_from_outer_out" : {},
"inner_inp_from_outer_inp" : {},
"inner_out_from_outer_inp" : {},
"outer_out_from_outer_inp" : {},
"outer_inp_from_inner_inp" : {},
"inner_out_from_inner_inp" : {},
"outer_out_from_inner_inp" : {},
"outer_inp_from_inner_out" : {},
"inner_inp_from_inner_out" : {},
"outer_out_from_inner_out" : {}}
for (oinp, iinp, iout, oout) in zip(outer_input_indices,
inner_input_indices,
inner_output_indices,
outer_output_indices):
if oout != -1:
mappings["outer_inp_from_outer_out"][oout] = oinp
mappings["inner_inp_from_outer_out"][oout] = iinp
mappings["inner_out_from_outer_out"][oout] = iout
if oinp != -1:
mappings["inner_inp_from_outer_inp"][oinp] = iinp
mappings["inner_out_from_outer_inp"][oinp] = iout
mappings["outer_out_from_outer_inp"][oinp] = oout
for idx in iinp:
mappings["outer_inp_from_inner_inp"][idx] = oinp
mappings["inner_out_from_inner_inp"][idx] = iout
mappings["outer_out_from_inner_inp"][idx] = oout
for idx in iout:
mappings["outer_inp_from_inner_out"][idx] = oinp
mappings["inner_inp_from_inner_out"][idx] = iinp
mappings["outer_out_from_inner_out"][idx] = oout
return mappings
# GRAD FUNCTION # GRAD FUNCTION
def grad(self, inputs, dC_douts): def grad(self, inputs, dC_douts):
...@@ -1896,10 +1903,14 @@ class Scan(PureOp): ...@@ -1896,10 +1903,14 @@ class Scan(PureOp):
for pos, inp in enumerate(states): for pos, inp in enumerate(states):
if inp in theano.gof.graph.inputs([Xt]): if inp in theano.gof.graph.inputs([Xt]):
oidx = self.get_output_pos(pos) # Get the index of the outer output that to which
if not isinstance(dC_douts[oidx].type, # the state variable 'inp' corresponds.
outer_oidx = self.var_mappings['outer_out_from_inner_inp'][self.n_seqs +
pos]
if not isinstance(dC_douts[outer_oidx].type,
DisconnectedType): DisconnectedType):
dtypes.append(dC_douts[oidx].dtype) dtypes.append(dC_douts[outer_oidx].dtype)
if dtypes: if dtypes:
new_dtype = theano.scalar.upcast(*dtypes) new_dtype = theano.scalar.upcast(*dtypes)
else: else:
...@@ -1943,14 +1954,25 @@ class Scan(PureOp): ...@@ -1943,14 +1954,25 @@ class Scan(PureOp):
# construct dX_dtm1 # construct dX_dtm1
dC_dXtm1s = [] dC_dXtm1s = []
for pos, x in enumerate(dC_dinps_t[self.n_seqs:]): for pos, x in enumerate(dC_dinps_t[self.n_seqs:]):
opos = self.get_output_pos(pos)
if opos >= 0: # Get the index of the first inner input corresponding to the
# pos-ieth inner input state
idxs = self.var_mappings['inner_out_from_inner_inp'][self.n_seqs +
pos]
# Check if the pos-th input is associated with one of the
# recurrent states
x_is_state = pos < sum([len(t) for t in self.tap_array])
if x_is_state and len(idxs) > 0:
opos = idxs[0]
dC_dXtm1s.append(safe_new(dC_dXts[opos])) dC_dXtm1s.append(safe_new(dC_dXts[opos]))
if hasattr(x, 'dtype') and x.dtype != dC_dXts[opos].dtype: if hasattr(x, 'dtype') and x.dtype != dC_dXts[opos].dtype:
dC_dinps_t[pos + self.n_seqs] = \ dC_dinps_t[pos + self.n_seqs] = \
x.astype(dC_dXts[opos].dtype) x.astype(dC_dXts[opos].dtype)
else: else:
dC_dXtm1s.append(safe_new(x)) dC_dXtm1s.append(safe_new(x))
for dx, dC_dXtm1 in enumerate(dC_dXtm1s): for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
if isinstance(dC_dinps_t[dx + self.n_seqs].type, NullType): if isinstance(dC_dinps_t[dx + self.n_seqs].type, NullType):
# The accumulated gradient is undefined # The accumulated gradient is undefined
......
...@@ -657,19 +657,18 @@ class T_Scan(unittest.TestCase): ...@@ -657,19 +657,18 @@ class T_Scan(unittest.TestCase):
tensor.grad(a[-1], a0) tensor.grad(a[-1], a0)
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq # Also validate that the mappings outer_inp_from_outer_out and
# and get_outer_iidx_from_inner_iidx_seq produce the correct results # outer_inp_from_inner_inp produce the correct results
scan_node = a.owner.inputs[0].owner scan_node = a.owner.inputs[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq() result = scan_node.op.var_mappings['outer_inp_from_outer_out']
expected_result = [1, 2] expected_result = {0: 1, 1: 2}
assert(result == expected_result) assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq() result = scan_node.op.var_mappings['outer_inp_from_inner_inp']
expected_result = [1, 1, 2, 2] expected_result = {0: 1, 1: 1, 2: 2, 3: 2}
assert(result == expected_result) assert(result == expected_result)
def test_connection_pattern2(self): def test_connection_pattern2(self):
# This tests for a crash in connection_pattern() when a scan node # This tests for a crash in connection_pattern() when a scan node
# has more than one mitmot (multiple input taps as well as # has more than one mitmot (multiple input taps as well as
...@@ -690,18 +689,42 @@ class T_Scan(unittest.TestCase): ...@@ -690,18 +689,42 @@ class T_Scan(unittest.TestCase):
scan_node = g_out[0].owner.inputs[1].owner.inputs[1].owner.inputs[0].owner scan_node = g_out[0].owner.inputs[1].owner.inputs[1].owner.inputs[0].owner
connection_pattern = scan_node.op.connection_pattern(scan_node) connection_pattern = scan_node.op.connection_pattern(scan_node)
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq # Also validate that the mappings outer_inp_from_outer_out and
# and get_outer_iidx_from_inner_iidx_seq produce the correct results # outer_inp_from_inner_inp produce the correct results
scan_node = out.owner.inputs[0].owner scan_node = out.owner.inputs[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq() result = scan_node.op.var_mappings['outer_inp_from_outer_out']
expected_result = [2] expected_result = {0: 2}
assert(result == expected_result) assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq() result = scan_node.op.var_mappings['outer_inp_from_inner_inp']
expected_result = [1, 2, 2] expected_result = {0: 1, 1: 2, 2: 2}
assert(result == expected_result) assert(result == expected_result)
def test_grad_grad_mitsot_sitsot(self):
# Test for an index error when taking the second derivative
# through a Scan node with one sitsot and one mitsot.
def inner_fct(mitsot_m2, mitsot_m1, sitsot):
total = mitsot_m2 + mitsot_m1 + sitsot
output = total ** 2
return output, output
inputs = [tensor.matrix(), tensor.vector()]
outputs_info = [dict(initial=inputs[0], taps=[-2, -1]), inputs[1]]
scan_outputs, updates = theano.scan(fn=inner_fct,
outputs_info=outputs_info,
n_steps=5)
# Take the gradient of each output wrt its corresponding initial state
gradients = [theano.grad(scan_outputs[0].sum(), inputs[0]),
theano.grad(scan_outputs[1].sum(), inputs[1])]
# Take the gradient of the sum of gradients wrt the inputs
sum_of_grads = sum([g.sum() for g in gradients])
second_gradients = theano.grad(sum_of_grads, inputs[0])
def test_grad_two_scans(self): def test_grad_two_scans(self):
# data input & output # data input & output
...@@ -1680,16 +1703,16 @@ class T_Scan(unittest.TestCase): ...@@ -1680,16 +1703,16 @@ class T_Scan(unittest.TestCase):
analytic_grad[max_err_pos], analytic_grad[max_err_pos],
num_grad.gx[max_err_pos])) num_grad.gx[max_err_pos]))
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq # Also validate that the mappings outer_inp_from_outer_out and
# and get_outer_iidx_from_inner_iidx_seq produce the correct results # outer_inp_from_inner_inp produce the correct results
scan_node = updates.values()[0].owner scan_node = updates.values()[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq() result = scan_node.op.var_mappings['outer_inp_from_outer_out']
expected_result = [3, -1, 4] expected_result = {0: 3, 1: 5, 2: 4}
assert(result == expected_result) assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq() result = scan_node.op.var_mappings['outer_inp_from_inner_inp']
expected_result = [1, 2, 3, 4, 6] expected_result = {0: 1, 1: 2, 2: 3, 3: 4, 4: 6}
assert(result == expected_result) assert(result == expected_result)
def test_grad_multiple_outs_some_truncate(self): def test_grad_multiple_outs_some_truncate(self):
...@@ -3299,6 +3322,69 @@ class T_Scan(unittest.TestCase): ...@@ -3299,6 +3322,69 @@ class T_Scan(unittest.TestCase):
if isinstance(x.op, theano.scan_module.scan_op.Scan)] if isinstance(x.op, theano.scan_module.scan_op.Scan)]
assert len(lssc) == 0 assert len(lssc) == 0
def test_oinp_iinp_iout_oout_mappings(self):
# Test the mapping produces by
# ScanOp.get_oinp_iinp_iout_oout_mappings()
rng = theano.tensor.shared_randomstreams.RandomStreams(123)
def inner_fct(seq, mitsot, sitsot, nitsot, nseq):
random_scalar = rng.uniform((1,))[0]
total = seq + mitsot + sitsot + nitsot + nseq + random_scalar
return total, total, total
# Assemble a scan with one sequence, one mitsot, one sitsot, one nitsot
# a non-sequence and a random state to test the mappings.
seq = [tensor.vector()]
non_seq = [tensor.scalar()]
outputs_info = [dict(initial=tensor.vector(), taps=[-3, -1]),
tensor.scalar(), None]
scan_outputs, _ = theano.scan(fn=inner_fct, sequences=seq,
outputs_info=outputs_info,
non_sequences=non_seq)
# Compare the mappings with the expected values
scan_node = scan_outputs[0].owner.inputs[0].owner
mappings = scan_node.op.var_mappings
assert mappings['inner_inp_from_outer_inp'] == {0 : [], 1 : [0],
2 : [1, 2], 3 : [3],
4 : [4], 5 : [],
6 : [5]}
assert mappings['inner_out_from_outer_inp'] == {0 : [], 1 : [],
2 : [0], 3 : [1],
4 : [3], 5 : [2],
6 : []}
assert mappings['outer_out_from_outer_inp'] == {0 : -1, 1 : -1,
2 : 0, 3 : 1,
4 : 3, 5 : 2,
6 : -1}
assert mappings['outer_inp_from_inner_inp'] == {0 : 1, 1 : 2,
2 : 2, 3 : 3,
4 : 4, 5 : 6}
assert mappings['inner_out_from_inner_inp'] == {0 : [], 1 : [0],
2 : [0], 3 : [1],
4 : [3], 5 : []}
assert mappings['outer_out_from_inner_inp'] == {0 : -1, 1 : 0,
2 : 0, 3 : 1,
4 : 3, 5 : -1}
assert mappings['outer_inp_from_inner_out'] == {0 : 2, 1 : 3,
2 : 5, 3 : 4}
assert mappings['inner_inp_from_inner_out'] == {0 : [1, 2], 1 : [3],
2 : [], 3 : [4]}
assert mappings['outer_out_from_inner_out'] == {0 : 0, 1 : 1,
2 : 2, 3 : 3}
assert mappings['outer_inp_from_outer_out'] == {0 : 2, 1 : 3,
2 : 5, 3 : 4}
assert mappings['inner_inp_from_outer_out'] == {0 : [1, 2], 1 : [3],
2 : [], 3 : [4]}
assert mappings['inner_out_from_outer_out'] == {0 : [0], 1 : [1],
2 : [2], 3 : [3]}
def test_grad_duplicate_outputs(self): def test_grad_duplicate_outputs(self):
# This test validates that taking the gradient of a scan, in which # This test validates that taking the gradient of a scan, in which
# multiple outputs are the same theano variable, works. # multiple outputs are the same theano variable, works.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论