提交 a37785c0 authored 作者: --global's avatar --global

Add util function to navigate between inner inputs and inner outputs

上级 944e36dd
......@@ -1750,6 +1750,47 @@ class Scan(PureOp):
return output
def get_inner_oidx_from_inner_iidx_seq(self):
""" Return a sequence where the value at the i-th position is the
sequence containing the indices of all the inner outputs associated
with the same state as the i-th inner input
"""
output = []
inner_out_idx = 0
# Handle sequence inputs
for i in range(self.info['n_seqs']):
# Inner sequences inputs are not associated with any state
output.append([])
# Handle mitmots, mitsots and sitsots states
for state_idx in range(len(self.info['tap_array'])):
nb_in_taps = len(self.info['tap_array'][state_idx])
if state_idx < self.n_mit_mot:
nb_out_taps = len(self.mit_mot_out_slices[state_idx])
else:
nb_out_taps = 1
for i in range(nb_in_taps):
output.append(range(inner_out_idx,
inner_out_idx + nb_out_taps))
inner_out_idx += nb_out_taps
# Handle shared inputs
for i in range(self.info['n_shared_outs']):
output.append([inner_out_idx])
inner_out_idx += 1
# Handle non-sequence inputs
nb_nonseqs_inputs = len(self.inputs) - len(output)
for i in range(nb_nonseqs_inputs):
# Non sequences are not associated with any state
output.append([])
return output
# GRAD FUNCTION
def grad(self, inputs, dC_douts):
outs = self(*inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论