提交 ccca34c3 authored 作者: carriepl's avatar carriepl

Fix bug in support for In in new theano functions

上级 d5805bac
...@@ -478,7 +478,19 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -478,7 +478,19 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
'theano.clone(f(x), replace={x: g(x)}))`.' 'theano.clone(f(x), replace={x: g(x)}))`.'
% x) % x)
output_vars = rebuild_collect_shared(outputs, # Extend the outputs with the updates on input variables so they are also
# cloned
additional_outputs = [i.update for i in inputs if i.update]
if outputs is None:
out_list = []
else:
if isinstance(outputs, (list, tuple)):
out_list = list(outputs)
else:
out_list = [outputs]
extended_outputs = out_list + additional_outputs
output_vars = rebuild_collect_shared(extended_outputs,
in_variables, in_variables,
replace=givens, replace=givens,
updates=updates, updates=updates,
...@@ -486,12 +498,25 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -486,12 +498,25 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
copy_inputs_over=True, copy_inputs_over=True,
no_default_updates=no_default_updates) no_default_updates=no_default_updates)
# extracting the arguments # extracting the arguments
input_variables, cloned_outputs, other_stuff = output_vars input_variables, cloned_extended_outputs, other_stuff = output_vars
clone_d, update_d, update_expr, shared_inputs = other_stuff clone_d, update_d, update_expr, shared_inputs = other_stuff
# Recover only the clones of the original outputs
if outputs is None:
cloned_outputs = []
else:
if isinstance(outputs, (list, tuple)):
cloned_outputs = cloned_extended_outputs[:len(outputs)]
else:
cloned_outputs = cloned_extended_outputs[0]
for i, iv in zip(inputs, input_variables): for i, iv in zip(inputs, input_variables):
i.variable = iv i.variable = iv
# If needed, replace the input's update by its cloned equivalent
if i.update:
i.update = clone_d[i.update]
for sv in shared_inputs: for sv in shared_inputs:
# pass value of None # pass value of None
# value will be stored in the resulting functions' defaults # value will be stored in the resulting functions' defaults
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论