提交 98a120b4 authored 作者: Reyhane Askari's avatar Reyhane Askari

changed concatenate to join to use view

上级 519ac73c
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import theano import theano
from theano.tensor.basic import Join
def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[], def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
...@@ -114,10 +115,12 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[], ...@@ -114,10 +115,12 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
# Pad the sequences if needed # Pad the sequences if needed
if padding: if padding:
# Since padding could be an empty tensor, Join returns a view of s.
join = Join(view=0)
for i, s in enumerate(sequences): for i, s in enumerate(sequences):
n = s.shape[0] % save_every_N n = s.shape[0] % save_every_N
z = theano.tensor.zeros((n, s.shape[1:]), dtype=s.dtype) z = theano.tensor.zeros((n, s.shape[1:]), dtype=s.dtype)
sequences[i] = theano.tensor.concatenate([s, z], axis=0) sequences[i] = join(0, [s, z])
# Establish the input variables of the outer scan # Establish the input variables of the outer scan
o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] + o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] +
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论