提交 0793fd51 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Improved scan error message

Two improvements: - An explicit error is raised if the lambda expression used in scan returns something that is not made of Theano variables (which may happen for instance when someone returns a constant value). - An error is always raised when failing to parse the return value of the lambda expression
上级 9691e746
...@@ -188,19 +188,47 @@ def get_updates_and_outputs(ls): ...@@ -188,19 +188,47 @@ def get_updates_and_outputs(ls):
else: else:
return [x] return [x]
def filter(x):
"""
Ensure `x` is made only of allowed data types.
Return True iff `x` is made only of lists, tuples, dictionaries, Theano
variables or `theano.scan_module.until` objects.
"""
# Is `x` a container we can iterate on?
iter_on = None
if isinstance(x, list) or isinstance(x, tuple):
iter_on = x
elif isinstance(x, dict):
iter_on = x.iteritems()
if iter_on is not None:
return all(filter(y) for y in iter_on)
else:
return (isinstance(x, theano.Variable) or
isinstance(x, theano.scan_module.until))
if not filter(ls):
raise ValueError(
'The return value of your scan lambda expression may only be '
'made of lists, tuples, or dictionaries containing Theano '
'variables (or `theano.scan_module.until` objects for '
'conditions). In particular if you need to use constant values, '
'you can use `tensor.constant` to turn them into Theano '
'variables.')
if is_outputs(ls): if is_outputs(ls):
return None, _list(ls), {} return None, _list(ls), {}
if is_updates(ls): if is_updates(ls):
return None, [], dict(ls) return None, [], dict(ls)
error_msg = ('Scan cannot parse the return value of your lambda '
'expression, which is: %s' % (ls,))
if not isinstance(ls, (list, tuple)): if not isinstance(ls, (list, tuple)):
raise ValueError(('Scan can not parse the return value' raise ValueError(error_msg)
' of your lambda expression'))
ls = list(ls) ls = list(ls)
deprication_msg = ('The return value of the lambda function' deprication_msg = ('The return value of the lambda function'
' has been restricted. you have to always return first the' ' has been restricted. you have to always return first the'
' outputs (if any), afterwards the updates (if any) and' ' outputs (if any), afterwards the updates (if any) and'
' at the end the conclusion') ' at the end the conclusion')
error_msg = 'Scan can not parse the return value of your lambda expression'
if len(ls) == 2: if len(ls) == 2:
if is_outputs(ls[0]): if is_outputs(ls[0]):
if is_updates(ls[1]): if is_updates(ls[1]):
...@@ -229,6 +257,8 @@ def get_updates_and_outputs(ls): ...@@ -229,6 +257,8 @@ def get_updates_and_outputs(ls):
raise ValueError(error_msg) raise ValueError(error_msg)
else: else:
raise ValueError(error_msg) raise ValueError(error_msg)
else:
raise ValueError(error_msg)
def isNaN_or_Inf_or_None(x): def isNaN_or_Inf_or_None(x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论