提交 83cc7ef3 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5667 from Thrandis/ccw

Added meaningful message when missing inputs to scan.
...@@ -7,7 +7,6 @@ types that it can raise. ...@@ -7,7 +7,6 @@ types that it can raise.
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from collections import OrderedDict from collections import OrderedDict
import time import time
import traceback
import theano import theano
from theano.gof import graph from theano.gof import graph
...@@ -16,7 +15,6 @@ from theano.gof import toolbox ...@@ -16,7 +15,6 @@ from theano.gof import toolbox
from theano import config from theano import config
from six import iteritems, itervalues from six import iteritems, itervalues
from six.moves import StringIO
from theano.gof.utils import get_variable_trace_string from theano.gof.utils import get_variable_trace_string
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
NullType = None NullType = None
...@@ -52,13 +50,9 @@ class MissingInputError(Exception): ...@@ -52,13 +50,9 @@ class MissingInputError(Exception):
if kwargs: if kwargs:
# The call to list is needed for Python 3 # The call to list is needed for Python 3
assert list(kwargs.keys()) == ["variable"] assert list(kwargs.keys()) == ["variable"]
tr = getattr(list(kwargs.values())[0].tag, 'trace', []) error_msg = get_variable_trace_string(kwargs["variable"])
if isinstance(tr, list) and len(tr) > 0: if error_msg:
sio = StringIO() args = args + (error_msg,)
print("\nBacktrace when the variable is created:", file=sio)
for subtr in list(kwargs.values())[0].tag.trace:
traceback.print_list(subtr, sio)
args = args + (str(sio.getvalue()),)
s = '\n'.join(args) # Needed to have the new line print correctly s = '\n'.join(args) # Needed to have the new line print correctly
Exception.__init__(self, s) Exception.__init__(self, s)
...@@ -393,7 +387,6 @@ class FunctionGraph(utils.object2): ...@@ -393,7 +387,6 @@ class FunctionGraph(utils.object2):
"Theano flag exception_verbosity='high', " "Theano flag exception_verbosity='high', "
"for more information on this error." "for more information on this error."
% (node.inputs.index(r), str(node))) % (node.inputs.index(r), str(node)))
error_msg += get_variable_trace_string(r)
raise MissingInputError(error_msg, variable=r) raise MissingInputError(error_msg, variable=r)
for node in new_nodes: for node in new_nodes:
......
...@@ -836,14 +836,21 @@ def scan(fn, ...@@ -836,14 +836,21 @@ def scan(fn,
dummy_outs = outputs dummy_outs = outputs
if condition is not None: if condition is not None:
dummy_outs.append(condition) dummy_outs.append(condition)
dummy_f = function(dummy_args, # Perform a try-except to provide a meaningful error message to the
dummy_outs, # user if inputs of the inner function are missing.
updates=updates, try:
mode=compile.mode.Mode(linker='py', dummy_f = function(dummy_args,
optimizer=None), dummy_outs,
on_unused_input='ignore', updates=updates,
profile=False) mode=compile.mode.Mode(linker='py',
optimizer=None),
on_unused_input='ignore',
profile=False)
except gof.fg.MissingInputError as err:
msg = ("\nPlease pass this variable to the scan's inner function. Do "
"not forget to also pass it to the `non_sequences` attribute "
"of scan.")
raise gof.fg.MissingInputError(err.args[0] + msg)
## ##
# Step 5. Re-arange inputs of scan into a more strict order # Step 5. Re-arange inputs of scan into a more strict order
## ##
......
...@@ -5518,3 +5518,17 @@ class TestInconsistentBroadcast(unittest.TestCase): ...@@ -5518,3 +5518,17 @@ class TestInconsistentBroadcast(unittest.TestCase):
sequences=x, sequences=x,
outputs_info=[dict(initial=initial_x)]) outputs_info=[dict(initial=initial_x)])
gs = tensor.grad(y.sum(), x) gs = tensor.grad(y.sum(), x)
class TestMissingInputError(unittest.TestCase):
@raises(theano.gof.fg.MissingInputError)
def test_raise_error(self):
c = theano.shared(0.)
inc = tensor.scalar('inc')
def count_up():
return tensor.zeros(()), {c: c + inc}
_, updates = theano.scan(count_up, n_steps=20)
func = theano.function(inputs=[inc], outputs=[], updates=updates)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论