提交 e8f5ae18 作者: Razvan Pascanu

correct number of steps

上级 c9856df8
...@@ -1251,6 +1251,16 @@ class Scan(PureOp): ...@@ -1251,6 +1251,16 @@ class Scan(PureOp):
outs = self(*inputs) outs = self(*inputs)
if not isinstance(outs, (list, tuple)): if not isinstance(outs, (list, tuple)):
outs = [outs] outs = [outs]
if self.n_nit_sot > 0:
grad_steps = self.outer_nitsot_outs(outs)[0].shape[0]
elif self.n_sit_sot > 0:
grad_steps = self.outer_sitsot_outs(outs)[0].shape[0] - 1
elif self.n_mit_sot > 0:
grad_steps = self.outer_mitsot_outs(outs)[0].shape[0] +\
self.mintaps[self.n_mit_mot]
else:
grad_steps = inputs[0]
rval = scan_utils.reconstruct_graph(self.inputs, rval = scan_utils.reconstruct_graph(self.inputs,
self.outputs) self.outputs)
self_inputs = rval[0] self_inputs = rval[0]
...@@ -1426,16 +1436,14 @@ class Scan(PureOp): ...@@ -1426,16 +1436,14 @@ class Scan(PureOp):
n_mitmot_inps += 2 n_mitmot_inps += 2
if self.truncate_gradient != -1: if self.truncate_gradient != -1:
do_steps = tensor.minimum(inputs[0], self.truncate_gradient) grad_steps = tensor.minimum(grad_steps, self.truncate_gradient)
else:
do_steps = inputs[0]
n_nit_sot = self.n_seqs n_nit_sot = self.n_seqs
inner_out_nitsot = dC_dinps_t[:self.n_seqs] inner_out_nitsot = dC_dinps_t[:self.n_seqs]
inner_out_sitsot = dC_dinps_t[ins_pos:] inner_out_sitsot = dC_dinps_t[ins_pos:]
inner_inp_sitsot = dC_dXtm1s[ins_pos - self.n_seqs:] inner_inp_sitsot = dC_dXtm1s[ins_pos - self.n_seqs:]
outer_inp_sitsot = [ outer_inp_sitsot = [
tensor.zeros([do_steps + 1] + tensor.zeros([grad_steps + 1] +
[x.shape[i] for i in xrange(x.ndim)], [x.shape[i] for i in xrange(x.ndim)],
dtype = y.dtype) dtype = y.dtype)
for y, x in zip(inner_inp_sitsot, for y, x in zip(inner_inp_sitsot,
...@@ -1457,7 +1465,7 @@ class Scan(PureOp): ...@@ -1457,7 +1465,7 @@ class Scan(PureOp):
info['n_sit_sot'] = n_sitsot_outs info['n_sit_sot'] = n_sitsot_outs
info['n_shared_outs'] = 0 info['n_shared_outs'] = 0
info['n_nit_sot'] = n_nit_sot info['n_nit_sot'] = n_nit_sot
info['as_while'] = self.as_while info['as_while'] = False
info['profile'] = self.profile info['profile'] = self.profile
info['destroy_map'] = {} info['destroy_map'] = {}
if self.name: if self.name:
...@@ -1466,7 +1474,7 @@ class Scan(PureOp): ...@@ -1466,7 +1474,7 @@ class Scan(PureOp):
info['name'] = None info['name'] = None
info['mode'] = self.mode info['mode'] = self.mode
outer_inputs = ([do_steps] + outer_inputs = ([grad_steps] +
outer_inp_seqs + outer_inp_seqs +
outer_inp_mitmot + outer_inp_mitmot +
outer_inp_sitsot + outer_inp_sitsot +
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论