提交 8a60518e authored 作者: --global's avatar --global

Use new mappings instead of old helper functions

上级 33247199
...@@ -242,14 +242,11 @@ class Scan(PureOp): ...@@ -242,14 +242,11 @@ class Scan(PureOp):
# For every recurrent output, iterate over the associated inner # For every recurrent output, iterate over the associated inner
# inputs and output and ensure that they have the same dtype # inputs and output and ensure that they have the same dtype
nb_recurr_outputs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot nb_recurr_outputs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
outer_iidx_from_outer_oidx = self.get_outer_iidx_from_outer_oidx_seq()
for outer_oidx in range(nb_recurr_outputs): for outer_oidx in range(nb_recurr_outputs):
outer_iidx = outer_iidx_from_outer_oidx[outer_oidx] inner_iidxs = self.var_mappings['inner_inp_from_outer_out'][outer_oidx]
inner_oidxs = self.var_mappings['inner_out_from_outer_out'][outer_oidx]
inner_iidxs = self.get_inner_iidx_from_outer_iidx(outer_iidx)
inner_oidxs = self.get_inner_oidx_from_outer_oidx(outer_oidx)
for (inner_iidx, inner_oidx) in itertools.product(inner_iidxs, for (inner_iidx, inner_oidx) in itertools.product(inner_iidxs,
inner_oidxs): inner_oidxs):
...@@ -1567,10 +1564,10 @@ class Scan(PureOp): ...@@ -1567,10 +1564,10 @@ class Scan(PureOp):
# and inner outputs and, if one such pair of inner variables is # and inner outputs and, if one such pair of inner variables is
# connected than the pair of outer variables is connected. # connected than the pair of outer variables is connected.
for outer_oidx in range(len(node.outputs)): for outer_oidx in range(len(node.outputs)):
inner_oidxs = self.get_inner_oidx_from_outer_oidx(outer_oidx) inner_oidxs = self.var_mappings['inner_out_from_outer_out'][outer_oidx]
for outer_iidx in range(len(node.inputs)): for outer_iidx in range(len(node.inputs)):
inner_iidxs = self.get_inner_iidx_from_outer_iidx(outer_iidx) inner_iidxs = self.var_mappings['inner_inp_from_outer_inp'][outer_iidx]
for inner_oidx in inner_oidxs: for inner_oidx in inner_oidxs:
for inner_iidx in inner_iidxs: for inner_iidx in inner_iidxs:
...@@ -1587,7 +1584,6 @@ class Scan(PureOp): ...@@ -1587,7 +1584,6 @@ class Scan(PureOp):
# input to `z_t` then `x` is an input to `z_t`. # input to `z_t` then `x` is an input to `z_t`.
n_outs = len(node.outputs) n_outs = len(node.outputs)
outer_iidx_from_outer_oidx = self.get_outer_iidx_from_outer_oidx_seq()
for steps in xrange(n_outs): for steps in xrange(n_outs):
for iidx in xrange(n_outs): for iidx in xrange(n_outs):
...@@ -1595,7 +1591,7 @@ class Scan(PureOp): ...@@ -1595,7 +1591,7 @@ class Scan(PureOp):
# Get the idx of the outer input corresponding to that # Get the idx of the outer input corresponding to that
# outer output # outer output
j_inp_idx = outer_iidx_from_outer_oidx[jidx] j_inp_idx = self.var_mappings["outer_inp_from_outer_out"][jidx]
if j_inp_idx != -1: if j_inp_idx != -1:
if connection_pattern[j_inp_idx][iidx] == True: if connection_pattern[j_inp_idx][iidx] == True:
...@@ -2045,9 +2041,8 @@ class Scan(PureOp): ...@@ -2045,9 +2041,8 @@ class Scan(PureOp):
if inp in theano.gof.graph.inputs([Xt]): if inp in theano.gof.graph.inputs([Xt]):
# Get the index of the outer output that to which # Get the index of the outer output that to which
# the state variable 'inp' corresponds. # the state variable 'inp' corresponds.
outer_iidx = self.get_outer_iidx_from_inner_iidx_seq()[self.n_seqs + outer_oidx = self.var_mappings['outer_out_from_inner_inp'][self.n_seqs +
pos] pos]
outer_oidx = self.get_outer_iidx_from_outer_oidx_seq().index(outer_iidx)
if not isinstance(dC_douts[outer_oidx].type, if not isinstance(dC_douts[outer_oidx].type,
DisconnectedType): DisconnectedType):
...@@ -2098,8 +2093,8 @@ class Scan(PureOp): ...@@ -2098,8 +2093,8 @@ class Scan(PureOp):
# Get the index of the first inner input corresponding to the # Get the index of the first inner input corresponding to the
# pos-ieth inner input state # pos-ieth inner input state
idxs = self.get_inner_oidx_from_inner_iidx_seq()[self.n_seqs + idxs = self.var_mappings['inner_out_from_inner_inp'][self.n_seqs +
pos] pos]
# Check if the pos-th input is associated with one of the # Check if the pos-th input is associated with one of the
# recurrent states # recurrent states
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论