提交 a61eb78f authored 作者: Razvan Pascanu's avatar Razvan Pascanu

updated functions to extract different arguments of scan

Scan arguments are ordered, and depending on the index, an argument can be a sequence, a state/output or a non sequence. The scan op has methods that given the full list of arguments returns only those that represent a certain category of inputs. The old functions used to look at self.inputs, self.outputs or node.inputs/node.outputs to determine those entries. This approach was restrictive, since all functions on scan most of the time work on clones of the original arguments. The new functions now take a list of arguments (could be either the original or clones).
上级 8cc574d4
...@@ -452,25 +452,29 @@ class Scan(PureOp): ...@@ -452,25 +452,29 @@ class Scan(PureOp):
rval.lazy = False rval.lazy = False
return rval return rval
def inner_seqs(self): def inner_seqs(self, list_inputs):
return self.inputs[:self.n_seqs] # Given the list of inner inputs this function grabs those
# corresponding to sequences
return list_inputs[:self.n_seqs]
def outer_seqs(self, node): def outer_seqs(self, list_inputs):
return node.inputs[1:1 + self.n_seqs] # Given the list of outter inputs this function grabs those
# corresponding to sequences
return list_inputs[1:1 + self.n_seqs]
def inner_mitmot(self): def inner_mitmot(self, list_inputs):
n_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot]) n_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot])
return self.inputs[self.n_seqs: self.n_seqs + n_taps] return list_inputs[self.n_seqs: self.n_seqs + n_taps]
def outer_mitmot(self, node): def outer_mitmot(self, list_inputs):
return node.inputs[1 + self.n_seqs:1 + self.n_seqs + self.n_mit_mot] return list_inputs[1 + self.n_seqs:1 + self.n_seqs + self.n_mit_mot]
def inner_mitmot_outs(self): def inner_mitmot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.mit_mot_out_slices)
return self.outputs[:n_taps] return list_outputs[:n_taps]
def outer_mitmot_outs(self, node): def outer_mitmot_outs(self, list_outputs):
return node.outputs[:self.n_mit_mot] return list_outputs[:self.n_mit_mot]
def mitmot_taps(self): def mitmot_taps(self):
return self.tap_array[:self.n_mit_mot] return self.tap_array[:self.n_mit_mot]
...@@ -478,98 +482,98 @@ class Scan(PureOp): ...@@ -478,98 +482,98 @@ class Scan(PureOp):
def mitmot_out_taps(self): def mitmot_out_taps(self):
return self.mit_mot_out_slices[:self.n_mit_mot] return self.mit_mot_out_slices[:self.n_mit_mot]
def inner_mitsot(self): def inner_mitsot(self, list_inputs):
n_mitmot_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot]) n_mitmot_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot])
ntaps_upto_sit_sot = sum(len(x) for x in ntaps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot + self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)]) self.n_mit_sot)])
return self.inputs[self.n_seqs + n_mitmot_taps: return list_inputs[self.n_seqs + n_mitmot_taps:
self.n_seqs + ntaps_upto_sit_sot] self.n_seqs + ntaps_upto_sit_sot]
def outer_mitsot(self, node): def outer_mitsot(self, list_inputs):
offset = 1 + self.n_seqs + self.n_mit_mot offset = 1 + self.n_seqs + self.n_mit_mot
return node.inputs[offset:offset + self.n_mit_sot] return list_inputs[offset:offset + self.n_mit_sot]
def inner_mitsot_outs(self): def inner_mitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.mit_mot_out_slices)
return self.outputs[n_taps:n_taps + self.n_mit_sot] return list_outputs[n_taps:n_taps + self.n_mit_sot]
def outer_mitsot_outs(self, node): def outer_mitsot_outs(self, list_outputs):
return node.outputs[self.n_mit_mot: return list_outputs[self.n_mit_mot:
self.n_mit_mot + self.n_mit_sot] self.n_mit_mot + self.n_mit_sot]
def mitsot_taps(self): def mitsot_taps(self):
return self.tap_array[self.n_mit_mot: return self.tap_array[self.n_mit_mot:
self.n_mit_mot + self.n_mit_sot] self.n_mit_mot + self.n_mit_sot]
def inner_sitsot(self): def inner_sitsot(self, list_inputs):
n_taps_upto_sit_sot = sum(len(x) for x in n_taps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot + self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)]) self.n_mit_sot)])
offset = self.n_seqs + n_taps_upto_sit_sot offset = self.n_seqs + n_taps_upto_sit_sot
return self.inputs[offset:offset + self.n_sit_sot] return list_inputs[offset:offset + self.n_sit_sot]
def outer_sitsot(self, node): def outer_sitsot(self, list_inputs):
offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot
return node.inputs[offset:offset + self.n_sit_sot] return list_inputs[offset:offset + self.n_sit_sot]
def inner_sitsot_outs(self): def inner_sitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps offset = self.n_mit_sot + n_taps
return self.outputs[offset:offset + self.n_sit_sot] return list_outputs[offset:offset + self.n_sit_sot]
def outer_sitsot_outs(self, node): def outer_sitsot_outs(self, list_outputs):
offset = self.n_mit_mot + self.n_mit_sot offset = self.n_mit_mot + self.n_mit_sot
return node.outputs[offset:offset + self.n_sit_sot] return list_outputs[offset:offset + self.n_sit_sot]
def outer_nitsot(self, node): def outer_nitsot(self, list_inputs):
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot + offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_shared_outs) self.n_sit_sot + self.n_shared_outs)
return node.inputs[offset:offset + self.n_nit_sot] return list_inputs[offset:offset + self.n_nit_sot]
def inner_nitsot_outs(self): def inner_nitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps + self.n_sit_sot offset = self.n_mit_sot + n_taps + self.n_sit_sot
return self.outputs[offset:offset + self.n_nit_sot] return list_outputs[offset:offset + self.n_nit_sot]
def outer_nitsot_outs(self, node): def outer_nitsot_outs(self, list_outputs):
offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot) offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot)
return node.outputs[offset:offset + self.n_nit_sot] return list_outputs[offset:offset + self.n_nit_sot]
def inner_shared(self): def inner_shared(self, list_inputs):
n_taps_upto_sit_sot = sum(len(x) for x in n_taps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot + self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)]) self.n_mit_sot)])
offset = self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot offset = self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot
return self.inputs[offset:offset + self.n_shared_outs] return list_inputs[offset:offset + self.n_shared_outs]
def outer_shared(self, node): def outer_shared(self, list_inputs):
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot + offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot) self.n_sit_sot)
return node.inputs[offset:offset + self.n_shared_outs] return list_inputs[offset:offset + self.n_shared_outs]
def inner_shared_outs(self): def inner_shared_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps + self.n_sit_sot + self.n_nit_sot offset = self.n_mit_sot + n_taps + self.n_sit_sot + self.n_nit_sot
return self.outputs[offset:offset + self.n_shared_outs] return list_outputs[offset:offset + self.n_shared_outs]
def outer_shared_outs(self, node): def outer_shared_outs(self, list_outputs):
offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot + offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot +
self.n_nit_sot) self.n_nit_sot)
return node.outputs[offset:offset + self.n_shared_outs] return list_outputs[offset:offset + self.n_shared_outs]
def inner_non_seqs(self): def inner_non_seqs(self, list_inputs):
n_taps_upto_sit_sot = sum(len(x) for x in n_taps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot + self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)]) self.n_mit_sot)])
offset = (self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot + offset = (self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot +
self.n_shared_outs) self.n_shared_outs)
return self.inputs[offset:] return list_inputs[offset:]
def outer_non_seqs(self, node): def outer_non_seqs(self, list_inputs:
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot + offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_nit_sot + self.n_shared_outs) self.n_sit_sot + self.n_nit_sot + self.n_shared_outs)
return node.inputs[offset:] return list_inputs[offset:]
def execute(self, node, args, outs): def execute(self, node, args, outs):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论