提交 99fcb5f9 authored 作者: Cesar Laurent's avatar Cesar Laurent

Simple version last step.

上级 7976c911
...@@ -104,18 +104,20 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[], ...@@ -104,18 +104,20 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
# Compute the number of steps of the outer scan # Compute the number of steps of the outer scan
o_n_steps = theano.tensor.cast(theano.tensor.ceil(n_steps / save_every_N), o_n_steps = theano.tensor.cast(theano.tensor.ceil(n_steps / save_every_N),
'int64') 'int64')
# Compute the number of steps of the inner scan
i_n_steps = save_every_N * theano.tensor.ones((o_n_steps,), 'int64') i_n_steps = save_every_N * theano.tensor.ones((o_n_steps,), 'int64')
mod = n_steps % save_every_N mod = n_steps % save_every_N
sign = theano.tensor.sgn(mod) last_n_steps = theano.tensor.switch(theano.tensor.eq(mod, 0),
rest = (1 - sign) * save_every_N + sign * mod save_every_N, mod)
i_n_steps = theano.tensor.set_subtensor(i_n_steps[-1], rest) i_n_steps = theano.tensor.set_subtensor(i_n_steps[-1], last_n_steps)
# Pad the sequences if needed # Pad the sequences if needed
if padding: if padding:
for i, s in enumerate(sequences): for i, s in enumerate(sequences):
n = s.shape[0] % save_every_N n = s.shape[0] % save_every_N
z = theano.tensor.zeros((n, s.shape[1:]), dtype=s.dtype) z = theano.tensor.zeros((n, s.shape[1:]), dtype=s.dtype)
sequences[i] = theano.tensor.concatenate([s, z]) sequences[i] = theano.tensor.concatenate([s, z], axis=0)
# Establish the input variables of the outer scan # Establish the input variables of the outer scan
o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] + o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] +
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论