提交 e8f5ae18 作者: Razvan Pascanu

Use the correct number of steps for the gradient scan

上级 c9856df8
......@@ -1251,6 +1251,16 @@ class Scan(PureOp):
outs = self(*inputs)
if not isinstance(outs, (list, tuple)):
outs = [outs]
if self.n_nit_sot > 0:
grad_steps = self.outer_nitsot_outs(outs)[0].shape[0]
elif self.n_sit_sot > 0:
grad_steps = self.outer_sitsot_outs(outs)[0].shape[0] - 1
elif self.n_mit_sot > 0:
grad_steps = self.outer_mitsot_outs(outs)[0].shape[0] +\
self.mintaps[self.n_mit_mot]
else:
grad_steps = inputs[0]
rval = scan_utils.reconstruct_graph(self.inputs,
self.outputs)
self_inputs = rval[0]
......@@ -1426,16 +1436,14 @@ class Scan(PureOp):
n_mitmot_inps += 2
if self.truncate_gradient != -1:
do_steps = tensor.minimum(inputs[0], self.truncate_gradient)
else:
do_steps = inputs[0]
grad_steps = tensor.minimum(grad_steps, self.truncate_gradient)
n_nit_sot = self.n_seqs
inner_out_nitsot = dC_dinps_t[:self.n_seqs]
inner_out_sitsot = dC_dinps_t[ins_pos:]
inner_inp_sitsot = dC_dXtm1s[ins_pos - self.n_seqs:]
outer_inp_sitsot = [
tensor.zeros([do_steps + 1] +
tensor.zeros([grad_steps + 1] +
[x.shape[i] for i in xrange(x.ndim)],
dtype = y.dtype)
for y, x in zip(inner_inp_sitsot,
......@@ -1457,7 +1465,7 @@ class Scan(PureOp):
info['n_sit_sot'] = n_sitsot_outs
info['n_shared_outs'] = 0
info['n_nit_sot'] = n_nit_sot
info['as_while'] = self.as_while
info['as_while'] = False
info['profile'] = self.profile
info['destroy_map'] = {}
if self.name:
......@@ -1466,7 +1474,7 @@ class Scan(PureOp):
info['name'] = None
info['mode'] = self.mode
outer_inputs = ([do_steps] +
outer_inputs = ([grad_steps] +
outer_inp_seqs +
outer_inp_mitmot +
outer_inp_sitsot +
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论