提交 e2b5805f authored 作者: Frederic Bastien's avatar Frederic Bastien

backport to python 2.4

上级 48affe9c
...@@ -320,14 +320,20 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -320,14 +320,20 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
for i,init_out in enumerate(outs_info): for i,init_out in enumerate(outs_info):
if init_out.get('taps', None) == [-1]: if init_out.get('taps', None) == [-1]:
args += [init_out['initial'].type()] args += [init_out['initial'].type()]
val = slice_to_seqs[-1] if slice_to_seqs else -1 if slice_to_seqs:
val = slice_to_seqs[-1]
else:
val = -1
slice_to_seqs += [ val+1 ] slice_to_seqs += [ val+1 ]
dummy_notshared_init_outs += 1 dummy_notshared_init_outs += 1
elif init_out.get('taps',None): elif init_out.get('taps',None):
if numpy.any(numpy.array(init_out.get('taps',[])) > 0): if numpy.any(numpy.array(init_out.get('taps',[])) > 0):
raise ValueError('Can not use future taps of outputs', init_out) raise ValueError('Can not use future taps of outputs', init_out)
slices = [ init_out['initial'][0].type() for k in init_out['taps'] ] slices = [ init_out['initial'][0].type() for k in init_out['taps'] ]
val = slice_to_seqs[-1] if slice_to_seqs else -1 if slice_to_seqs:
val = slice_to_seqs[-1]
else:
val = -1
slice_to_seqs += [ val+1 for k in init_out['taps'] ] slice_to_seqs += [ val+1 for k in init_out['taps'] ]
args += slices args += slices
dummy_notshared_init_outs += len(init_out['taps']) dummy_notshared_init_outs += len(init_out['taps'])
...@@ -423,7 +429,9 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -423,7 +429,9 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
if isinstance(input.variable, theano.compile.SharedVariable) and input.update: if isinstance(input.variable, theano.compile.SharedVariable) and input.update:
new_var = input.variable.type() new_var = input.variable.type()
inner_fn_inputs.append(new_var) inner_fn_inputs.append(new_var)
val = slice_to_seqs[-1] if slice_to_seqs else -1 if slice_to_seqs:
val = slice_to_seqs[-1]
else: val = -1
slice_to_seqs += [ val+1 ] slice_to_seqs += [ val+1 ]
inner_fn_out_states += [input.update] inner_fn_out_states += [input.update]
update_map[ input.variable ] = n_extended_outs update_map[ input.variable ] = n_extended_outs
...@@ -438,7 +446,9 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -438,7 +446,9 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
if isinstance(input.variable, theano.compile.SharedVariable) and not input.update: if isinstance(input.variable, theano.compile.SharedVariable) and not input.update:
shared_non_seqs += [input.variable] shared_non_seqs += [input.variable]
inner_fn_inputs += [input.variable.type() ] inner_fn_inputs += [input.variable.type() ]
val = slice_to_seqs[-1] if slice_to_seqs else -1 if slice_to_seqs:
val = slice_to_seqs[-1]
else: val = -1
slice_to_seqs += [val +1] slice_to_seqs += [val +1]
givens[input.variable] = inner_fn_inputs[-1] givens[input.variable] = inner_fn_inputs[-1]
elif not isinstance(input.variable, theano.compile.SharedVariable): elif not isinstance(input.variable, theano.compile.SharedVariable):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论