提交 3299fe59 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

correct _get_inner_inps

上级 97b1df31
...@@ -1279,15 +1279,19 @@ class Scan(PureOp): ...@@ -1279,15 +1279,19 @@ class Scan(PureOp):
return self.outputs[s:e] return self.outputs[s:e]
def _get_inner_inps(iidx): def _get_inner_inps(iidx):
s = 0 s = 0
e = 1 if self.n_seqs > 0:
e = 1
else:
e = len(self.tap_array[0])
p = iidx p = iidx
if (node.inputs[iidx] in self.outer_nitsot(node) or if (node.inputs[iidx + 1] in self.outer_nitsot(node) or
node.inputs[iidx] in self.outer_shared(node)): node.inputs[iidx + 1] in self.outer_shared(node)):
return None return None
if node.inputs[iidx] in self.outer_non_seqs(node): if node.inputs[iidx + 1] in self.outer_non_seqs(node):
loc_idx = self.outer_non_seqs(node).index( loc_idx = self.outer_non_seqs(node).index(
node.inputs[iidx]) node.inputs[iidx + 1])
return [self.inner_non_seqs()[loc_idx]] return [self.inner_non_seqs(self.inputs)[loc_idx]]
for p in xrange(iidx): for p in xrange(iidx):
s = e s = e
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论