提交 888485ef authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Merge pull request #377 from delallea/better_scan_error

Improved scan error message I think that we can't really use iter(x) as suggested by David, because we want to iterate over the items.
......@@ -188,19 +188,47 @@ def get_updates_and_outputs(ls):
else:
return [x]
def filter(x):
"""
Ensure `x` is made only of allowed data types.
Return True iff `x` is made only of lists, tuples, dictionaries, Theano
variables or `theano.scan_module.until` objects.
"""
# Is `x` a container we can iterate on?
iter_on = None
if isinstance(x, list) or isinstance(x, tuple):
iter_on = x
elif isinstance(x, dict):
iter_on = x.iteritems()
if iter_on is not None:
return all(filter(y) for y in iter_on)
else:
return (isinstance(x, theano.Variable) or
isinstance(x, theano.scan_module.until))
if not filter(ls):
raise ValueError(
'The return value of your scan lambda expression may only be '
'made of lists, tuples, or dictionaries containing Theano '
'variables (or `theano.scan_module.until` objects for '
'conditions). In particular if you need to use constant values, '
'you can use `tensor.constant` to turn them into Theano '
'variables.')
if is_outputs(ls):
return None, _list(ls), {}
if is_updates(ls):
return None, [], dict(ls)
error_msg = ('Scan cannot parse the return value of your lambda '
'expression, which is: %s' % (ls,))
if not isinstance(ls, (list, tuple)):
raise ValueError(('Scan can not parse the return value'
' of your lambda expression'))
raise ValueError(error_msg)
ls = list(ls)
deprication_msg = ('The return value of the lambda function'
' has been restricted. you have to always return first the'
' outputs (if any), afterwards the updates (if any) and'
' at the end the conclusion')
error_msg = 'Scan can not parse the return value of your lambda expression'
if len(ls) == 2:
if is_outputs(ls[0]):
if is_updates(ls[1]):
......@@ -229,6 +257,8 @@ def get_updates_and_outputs(ls):
raise ValueError(error_msg)
else:
raise ValueError(error_msg)
else:
raise ValueError(error_msg)
def isNaN_or_Inf_or_None(x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论