提交 ce90645e authored 作者: Glexin's avatar Glexin 提交者: Pascal Lamblin

Fix MissingInputError when scan use hidden variable to calculate condition (#6597)

Fix scan MissingInputError cause by hidden var used in condition.
上级 f2a2df15
...@@ -820,6 +820,8 @@ def scan(fn, ...@@ -820,6 +820,8 @@ def scan(fn,
# extract still missing inputs (there still might be so) and add them # extract still missing inputs (there still might be so) and add them
# as non sequences at the end of our args # as non sequences at the end of our args
if condition is not None:
outputs.append(condition)
fake_nonseqs = [x.type() for x in non_seqs] fake_nonseqs = [x.type() for x in non_seqs]
fake_outputs = scan_utils.clone(outputs, fake_outputs = scan_utils.clone(outputs,
replace=OrderedDict(izip(non_seqs, replace=OrderedDict(izip(non_seqs,
...@@ -836,8 +838,6 @@ def scan(fn, ...@@ -836,8 +838,6 @@ def scan(fn,
dummy_args += extra_inputs dummy_args += extra_inputs
dummy_outs = outputs dummy_outs = outputs
if condition is not None:
dummy_outs.append(condition)
# Perform a try-except to provide a meaningful error message to the # Perform a try-except to provide a meaningful error message to the
# user if inputs of the inner function are missing. # user if inputs of the inner function are missing.
try: try:
......
...@@ -5718,3 +5718,23 @@ class TestGradUntil(unittest.TestCase): ...@@ -5718,3 +5718,23 @@ class TestGradUntil(unittest.TestCase):
numpy_grad = np.array([0, 0, 0, 5, 6, 10, 4, 5, 0, 0, 0, 0, 0, 0, 0]) numpy_grad = np.array([0, 0, 0, 5, 6, 10, 4, 5, 0, 0, 0, 0, 0, 0, 0])
numpy_grad = numpy_grad.astype(theano.config.floatX) numpy_grad = numpy_grad.astype(theano.config.floatX)
utt.assert_allclose(theano_gradient, numpy_grad) utt.assert_allclose(theano_gradient, numpy_grad)
def test_condition_hidden_inp():
max_value = theano.tensor.scalar("max_value")
n_steps = theano.tensor.iscalar("n_steps")
def accum(prev_value, step):
new_value = prev_value + step
new_step = step + 1
condition = theano.scan_module.until(new_value > max_value)
return [new_value, new_step], condition
rs, updates = theano.scan(
fn=accum,
outputs_info=[0, 0],
n_steps=n_steps)
f = theano.function(
inputs=[max_value, n_steps],
outputs=rs)
_sum, total_steps = f(100, 100)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论