提交 7e58a6d8 authored 作者: --global's avatar --global

Avoid using ScanOp.get_output_pos()

上级 a37785c0
...@@ -1937,10 +1937,15 @@ class Scan(PureOp): ...@@ -1937,10 +1937,15 @@ class Scan(PureOp):
for pos, inp in enumerate(states): for pos, inp in enumerate(states):
if inp in theano.gof.graph.inputs([Xt]): if inp in theano.gof.graph.inputs([Xt]):
oidx = self.get_output_pos(pos) # Get the index of the outer output that to which
if not isinstance(dC_douts[oidx].type, # the state variable 'inp' corresponds.
outer_iidx = self.get_outer_iidx_from_inner_iidx_seq()[self.n_seqs +
pos]
outer_oidx = self.get_outer_iidx_from_outer_oidx_seq().index(outer_iidx)
if not isinstance(dC_douts[outer_oidx].type,
DisconnectedType): DisconnectedType):
dtypes.append(dC_douts[oidx].dtype) dtypes.append(dC_douts[outer_oidx].dtype)
if dtypes: if dtypes:
new_dtype = theano.scalar.upcast(*dtypes) new_dtype = theano.scalar.upcast(*dtypes)
else: else:
...@@ -1984,14 +1989,25 @@ class Scan(PureOp): ...@@ -1984,14 +1989,25 @@ class Scan(PureOp):
# construct dX_dtm1 # construct dX_dtm1
dC_dXtm1s = [] dC_dXtm1s = []
for pos, x in enumerate(dC_dinps_t[self.n_seqs:]): for pos, x in enumerate(dC_dinps_t[self.n_seqs:]):
opos = self.get_output_pos(pos)
if opos >= 0: # Get the index of the first outer input corresponding to the
# pos-ieth inner input state
idxs = self.get_inner_oidx_from_inner_iidx_seq()[self.n_seqs +
pos]
# Check if the pos-th input is associated with one of the
# recurrent states
x_is_state = pos < sum([len(t) for t in self.tap_array])
if x_is_state and len(idxs) > 0:
opos = idxs[0]
dC_dXtm1s.append(safe_new(dC_dXts[opos])) dC_dXtm1s.append(safe_new(dC_dXts[opos]))
if hasattr(x, 'dtype') and x.dtype != dC_dXts[opos].dtype: if hasattr(x, 'dtype') and x.dtype != dC_dXts[opos].dtype:
dC_dinps_t[pos + self.n_seqs] = \ dC_dinps_t[pos + self.n_seqs] = \
x.astype(dC_dXts[opos].dtype) x.astype(dC_dXts[opos].dtype)
else: else:
dC_dXtm1s.append(safe_new(x)) dC_dXtm1s.append(safe_new(x))
for dx, dC_dXtm1 in enumerate(dC_dXtm1s): for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
if isinstance(dC_dinps_t[dx + self.n_seqs].type, NullType): if isinstance(dC_dinps_t[dx + self.n_seqs].type, NullType):
# The accumulated gradient is undefined # The accumulated gradient is undefined
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论