提交 6ae48c67 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Stop confusing inner and outer outputs in connection_pattern()

上级 4be55f63
...@@ -1471,35 +1471,54 @@ class Scan(PureOp): ...@@ -1471,35 +1471,54 @@ class Scan(PureOp):
# input to `z_t` then `x` is an input to `z_t`. # input to `z_t` then `x` is an input to `z_t`.
n_outs = len(node.outputs) n_outs = len(node.outputs)
outer_iidx_from_inner_iidx = self.get_outer_iidx_from_inner_iidx_seq() outer_iidx_from_outer_oidx = self.get_outer_iidx_from_outer_oidx_seq()
for steps in xrange(n_outs): for steps in xrange(n_outs):
for iidx in xrange(n_outs): for iidx in xrange(n_outs):
for jidx in xrange(n_outs): for jidx in xrange(n_outs):
# Get the idx of the first inner input corresponding to # Get the idx of the outer input corresponding to that
# that inner output # outer output
j_inp_idx = self.get_input_pos(jidx) j_inp_idx = outer_iidx_from_outer_oidx[jidx]
if j_inp_idx == -1: if j_inp_idx != -1:
# No corresponding inner input : default to what scan if connection_pattern[j_inp_idx][iidx] == True:
# was doing in the previous version in those cases for k in xrange(len(connection_pattern)):
# which *seems* to be a hack designed to avoid passing if connection_pattern[k][jidx]:
# the condition below but it's not certain. connection_pattern[k][iidx] = True
j_inp_idx = 0
else:
# Get the idx of the outer input corresponding to that
# inner input
j_inp_idx = outer_iidx_from_inner_iidx[j_inp_idx]
if connection_pattern[j_inp_idx][iidx] == True:
for k in xrange(len(connection_pattern)):
if connection_pattern[k][jidx]:
connection_pattern[k][iidx] = True
node.tag.connection_pattern = connection_pattern node.tag.connection_pattern = connection_pattern
return connection_pattern return connection_pattern
def get_outer_iidx_from_outer_oidx_seq(self):
""" Return a sequence where the value at the i-th position is the
index of the outer input corresponding to the i-th outer output
NOTE: mitmots, mitsots, sitsots and shared outputs have corresponding
outer inputs but not nitsots.
"""
nb_outer_outputs = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot +
self.n_nit_sot + self.n_shared_outs)
result = [-1] * nb_outer_outputs
# Process mitmots, mitsots and sitsots
input_offset = 1 + self.n_seqs
output_offset = 0
for i in range(len(self.tap_array)):
result[output_offset] = input_offset
input_offset += 1
output_offset += 1
# Process shared inputs/outputs
input_offset += self.n_nit_sot
for i in range(self.n_shared_outs):
result[output_offset + i] = input_offset
input_offset += 1
output_offset += 1
return result
def get_outer_iidx_from_inner_iidx_seq(self): def get_outer_iidx_from_inner_iidx_seq(self):
""" Return a sequence where the value at the i-th position is the """ Return a sequence where the value at the i-th position is the
index of the outer input corresponding to the i-th inner input index of the outer input corresponding to the i-th inner input
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论