提交 5e11d066 authored 作者: --global's avatar --global

Reorganize helper functions in scan op

上级 5ef9fec9
......@@ -1408,37 +1408,6 @@ class Scan(PureOp):
if hasattr(node.tag, 'connection_pattern'):
return node.tag.connection_pattern
# Define helper functions
def _get_inner_outs_idx(oidx):
"""Given the index of an outer output, return the indices of the
corresponding inner output(s) in a sequence.
"""
s = 0
e = 0
for p in xrange(oidx + 1):
s = e
if p < self.n_mit_mot:
e += len(self.mitmot_out_taps()[p])
else:
e += 1
return range(s, e)
def _get_inner_inps_idx(outer_iidx):
"""Given the index of an outer input, return the indices of the
corresponding inner input(s) in a sequence.
"""
outer_iidx_from_inner_iidx = self.get_outer_iidx_from_inner_iidx_seq()
# For every inner input, if the corresponding outer input is the
# desired one, store the index
inner_iidxs = []
for i in xrange(len(outer_iidx_from_inner_iidx)):
if outer_iidx_from_inner_iidx[i] == outer_iidx:
inner_iidxs.append(i)
return inner_iidxs
# Obtain the connection pattern of the inner function.
inner_connect_pattern = self.inner_connection_pattern()
......@@ -1451,10 +1420,10 @@ class Scan(PureOp):
# and inner outputs and, if one such pair of inner variables is
# connected than the pair of outer variables is connected.
for outer_oidx in range(len(node.outputs)):
inner_oidxs = _get_inner_outs_idx(outer_oidx)
inner_oidxs = self.get_inner_oidx_from_outer_oidx(outer_oidx)
for outer_iidx in range(len(node.inputs)):
inner_iidxs = _get_inner_inps_idx(outer_iidx)
inner_iidxs = self.get_inner_iidx_from_outer_iidx(outer_iidx)
for inner_oidx in inner_oidxs:
for inner_iidx in inner_iidxs:
......@@ -1490,6 +1459,36 @@ class Scan(PureOp):
node.tag.connection_pattern = connection_pattern
return connection_pattern
def get_inner_oidx_from_outer_oidx(self, outer_oidx):
"""Given the index of an outer output, return the indices of the
corresponding inner output(s) in a sequence.
"""
s = 0
e = 0
for p in xrange(outer_oidx + 1):
s = e
if p < self.n_mit_mot:
e += len(self.mitmot_out_taps()[p])
else:
e += 1
return range(s, e)
def get_inner_iidx_from_outer_iidx(self, outer_oidx):
"""Given the index of an outer input, return the indices of the
corresponding inner input(s) in a sequence.
"""
outer_iidx_from_inner_iidx = self.get_outer_iidx_from_inner_iidx_seq()
# For every inner input, if the corresponding outer input is the
# desired one, store the index
inner_iidxs = []
for i in xrange(len(outer_iidx_from_inner_iidx)):
if outer_iidx_from_inner_iidx[i] == outer_oidx:
inner_iidxs.append(i)
return inner_iidxs
def get_outer_iidx_from_outer_oidx_seq(self):
""" Return a sequence where the value at the i-th position is the
index of the outer input corresponding to the i-th outer output
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论