提交 50c3fd00 authored 作者: carriepl's avatar carriepl

Add equivalency between inner inputs and outer inputs for infer_shape

上级 79597a04
......@@ -1566,23 +1566,46 @@ class Scan(PureOp):
# Infer Shape
def infer_shape(self, node, input_shapes):
# input_shapes correspond to the shapes of node.inputs
# Here, we build a list inner_ins_shape, such that inner_ins_shape[i]
# is the shape of self.inputs[i]
for inp, inp_shp in izip(node.inputs, input_shapes):
assert inp_shp is None or len(inp_shp) == inp.type.ndim
# sequences
# We skip iputs_shapes[0] as it is the total or current number
# Here we build 2 variables;
# - A list `inner_ins_shapes`, such that inner_ins_shapes[i] is the
# shape of self.inputs[i]
# - A dictionary `out_equivalent` containing, for every inner input,
# an equivalent variable computed from the outer inputs.
# NOTE : For non-sequences, this equivalence is trivial. For
# sequences and recurrent states, there is no direct equivalence
# between outer and inner inputs. However, because every iteration
# of the Scan needs to give the same output shapes, we can give an
# equivalence between these inner inputs and the subelements of the
# corresponding outer inputs that the Scan would use as input for
# any given iteration. For simplicity, we use iteration 0.
inner_ins_shapes = []
out_equivalent = OrderedDict()
# We skip the first outer input as it is the total or current number
# of iterations.
# sequences
seqs_shape = [x[1:] for x in input_shapes[1:1 + self.n_seqs]]
inner_seqs = self.inputs[:self.n_seqs]
outer_seqs = node.inputs[1:1 + self.n_seqs]
for in_s, out_s in izip(inner_seqs, outer_seqs):
out_equivalent[in_s] = out_s[0]
# mit_mot, mit_sot, sit_sot
outer_inp_idx = 1 + self.n_seqs
inner_inp_idx = self.n_seqs
n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
outs_shape = []
for idx in xrange(n_outs):
mintap = abs(min(self.tap_array[idx]))
for k in self.tap_array[idx]:
outs_shape += [input_shapes[idx + self.n_seqs + 1][1:]]
corresponding_tap = node.inputs[outer_inp_idx][mintap + k]
out_equivalent[self.inputs[inner_inp_idx]] = corresponding_tap
inner_inp_idx += 1
outer_inp_idx += 1
# shared_outs
offset = 1 + self.n_seqs + n_outs
......@@ -1597,9 +1620,9 @@ class Scan(PureOp):
# Non-sequences have a direct equivalent from self.inputs in
# node.inputs
inner_non_sequences = self.inputs[len(seqs_shape) + len(outs_shape):]
out_equivalent = OrderedDict()
for in_ns, out_ns in izip(inner_non_sequences, node.inputs[offset:]):
out_equivalent[in_ns] = out_ns
if self.as_while:
self_outs = self.outputs[:-1]
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论