提交 a61eb78f authored 作者: Razvan Pascanu's avatar Razvan Pascanu

updated functions to extract different arguments of scan

Scan arguments are ordered, and depending on the index, an argument can be a sequence, a state/output or a non sequence. The scan op has methods that given the full list of arguments returns only those that represent a certain category of inputs. The old functions used to look at self.inputs, self.outputs or node.inputs/node.outputs to determine those entries. This approach was restrictive, since all functions on scan most of the time work on clones of the original arguments. The new functions now take a list of arguments (could be either the original or clones).
上级 8cc574d4
......@@ -452,25 +452,29 @@ class Scan(PureOp):
rval.lazy = False
return rval
def inner_seqs(self):
return self.inputs[:self.n_seqs]
def inner_seqs(self, list_inputs):
# Given the list of inner inputs this function grabs those
# corresponding to sequences
return list_inputs[:self.n_seqs]
def outer_seqs(self, node):
return node.inputs[1:1 + self.n_seqs]
def outer_seqs(self, list_inputs):
# Given the list of outter inputs this function grabs those
# corresponding to sequences
return list_inputs[1:1 + self.n_seqs]
def inner_mitmot(self):
def inner_mitmot(self, list_inputs):
n_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot])
return self.inputs[self.n_seqs: self.n_seqs + n_taps]
return list_inputs[self.n_seqs: self.n_seqs + n_taps]
def outer_mitmot(self, node):
return node.inputs[1 + self.n_seqs:1 + self.n_seqs + self.n_mit_mot]
def outer_mitmot(self, list_inputs):
return list_inputs[1 + self.n_seqs:1 + self.n_seqs + self.n_mit_mot]
def inner_mitmot_outs(self):
def inner_mitmot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
return self.outputs[:n_taps]
return list_outputs[:n_taps]
def outer_mitmot_outs(self, node):
return node.outputs[:self.n_mit_mot]
def outer_mitmot_outs(self, list_outputs):
return list_outputs[:self.n_mit_mot]
def mitmot_taps(self):
return self.tap_array[:self.n_mit_mot]
......@@ -478,98 +482,98 @@ class Scan(PureOp):
def mitmot_out_taps(self):
return self.mit_mot_out_slices[:self.n_mit_mot]
def inner_mitsot(self):
def inner_mitsot(self, list_inputs):
n_mitmot_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot])
ntaps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)])
return self.inputs[self.n_seqs + n_mitmot_taps:
return list_inputs[self.n_seqs + n_mitmot_taps:
self.n_seqs + ntaps_upto_sit_sot]
def outer_mitsot(self, node):
def outer_mitsot(self, list_inputs):
offset = 1 + self.n_seqs + self.n_mit_mot
return node.inputs[offset:offset + self.n_mit_sot]
return list_inputs[offset:offset + self.n_mit_sot]
def inner_mitsot_outs(self):
def inner_mitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
return self.outputs[n_taps:n_taps + self.n_mit_sot]
return list_outputs[n_taps:n_taps + self.n_mit_sot]
def outer_mitsot_outs(self, node):
return node.outputs[self.n_mit_mot:
def outer_mitsot_outs(self, list_outputs):
return list_outputs[self.n_mit_mot:
self.n_mit_mot + self.n_mit_sot]
def mitsot_taps(self):
return self.tap_array[self.n_mit_mot:
self.n_mit_mot + self.n_mit_sot]
def inner_sitsot(self):
def inner_sitsot(self, list_inputs):
n_taps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)])
offset = self.n_seqs + n_taps_upto_sit_sot
return self.inputs[offset:offset + self.n_sit_sot]
return list_inputs[offset:offset + self.n_sit_sot]
def outer_sitsot(self, node):
def outer_sitsot(self, list_inputs):
offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot
return node.inputs[offset:offset + self.n_sit_sot]
return list_inputs[offset:offset + self.n_sit_sot]
def inner_sitsot_outs(self):
def inner_sitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps
return self.outputs[offset:offset + self.n_sit_sot]
return list_outputs[offset:offset + self.n_sit_sot]
def outer_sitsot_outs(self, node):
def outer_sitsot_outs(self, list_outputs):
offset = self.n_mit_mot + self.n_mit_sot
return node.outputs[offset:offset + self.n_sit_sot]
return list_outputs[offset:offset + self.n_sit_sot]
def outer_nitsot(self, node):
def outer_nitsot(self, list_inputs):
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_shared_outs)
return node.inputs[offset:offset + self.n_nit_sot]
return list_inputs[offset:offset + self.n_nit_sot]
def inner_nitsot_outs(self):
def inner_nitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps + self.n_sit_sot
return self.outputs[offset:offset + self.n_nit_sot]
return list_outputs[offset:offset + self.n_nit_sot]
def outer_nitsot_outs(self, node):
def outer_nitsot_outs(self, list_outputs):
offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot)
return node.outputs[offset:offset + self.n_nit_sot]
return list_outputs[offset:offset + self.n_nit_sot]
def inner_shared(self):
def inner_shared(self, list_inputs):
n_taps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)])
offset = self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot
return self.inputs[offset:offset + self.n_shared_outs]
return list_inputs[offset:offset + self.n_shared_outs]
def outer_shared(self, node):
def outer_shared(self, list_inputs):
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot)
return node.inputs[offset:offset + self.n_shared_outs]
return list_inputs[offset:offset + self.n_shared_outs]
def inner_shared_outs(self):
def inner_shared_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps + self.n_sit_sot + self.n_nit_sot
return self.outputs[offset:offset + self.n_shared_outs]
return list_outputs[offset:offset + self.n_shared_outs]
def outer_shared_outs(self, node):
def outer_shared_outs(self, list_outputs):
offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot +
self.n_nit_sot)
return node.outputs[offset:offset + self.n_shared_outs]
return list_outputs[offset:offset + self.n_shared_outs]
def inner_non_seqs(self):
def inner_non_seqs(self, list_inputs):
n_taps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)])
offset = (self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot +
self.n_shared_outs)
return self.inputs[offset:]
return list_inputs[offset:]
def outer_non_seqs(self, node):
def outer_non_seqs(self, list_inputs:
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_nit_sot + self.n_shared_outs)
return node.inputs[offset:]
return list_inputs[offset:]
def execute(self, node, args, outs):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论