提交 99fcb5f9 authored 作者: Cesar Laurent's avatar Cesar Laurent

Simple version last step.

上级 7976c911
......@@ -104,18 +104,20 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
# Compute the number of steps of the outer scan
o_n_steps = theano.tensor.cast(theano.tensor.ceil(n_steps / save_every_N),
'int64')
# Compute the number of steps of the inner scan
i_n_steps = save_every_N * theano.tensor.ones((o_n_steps,), 'int64')
mod = n_steps % save_every_N
sign = theano.tensor.sgn(mod)
rest = (1 - sign) * save_every_N + sign * mod
i_n_steps = theano.tensor.set_subtensor(i_n_steps[-1], rest)
last_n_steps = theano.tensor.switch(theano.tensor.eq(mod, 0),
save_every_N, mod)
i_n_steps = theano.tensor.set_subtensor(i_n_steps[-1], last_n_steps)
# Pad the sequences if needed
if padding:
for i, s in enumerate(sequences):
n = s.shape[0] % save_every_N
z = theano.tensor.zeros((n, s.shape[1:]), dtype=s.dtype)
sequences[i] = theano.tensor.concatenate([s, z])
sequences[i] = theano.tensor.concatenate([s, z], axis=0)
# Establish the input variables of the outer scan
o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] +
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论