提交 8fc19902 authored 作者: abergeron's avatar abergeron

Merge pull request #2504 from nouiz/err_msg

Better error message when the test_value do not have the right type
...@@ -295,7 +295,8 @@ AddConfigVar('traceback.limit', ...@@ -295,7 +295,8 @@ AddConfigVar('traceback.limit',
"The number of stack to trace. -1 mean all.", "The number of stack to trace. -1 mean all.",
# We default to 6 to be able to know where v1 + v2 is created in the # We default to 6 to be able to know where v1 + v2 is created in the
# user script. The bigger this number is, the more run time it takes. # user script. The bigger this number is, the more run time it takes.
IntParam(6), # We need to default to 7 to support theano.tensor.tensor(...).
IntParam(7),
in_c_key=False) in_c_key=False)
AddConfigVar('experimental.mrg', AddConfigVar('experimental.mrg',
......
...@@ -16,8 +16,10 @@ import inspect ...@@ -16,8 +16,10 @@ import inspect
import logging import logging
import numpy import numpy
import os import os
import sys
import re import re
import StringIO
import sys
import traceback
import warnings import warnings
import theano import theano
...@@ -448,7 +450,31 @@ class PureOp(object): ...@@ -448,7 +450,31 @@ class PureOp(object):
return v.get_value(borrow=True, return_internal_type=True) return v.get_value(borrow=True, return_internal_type=True)
elif isinstance(v, graph.Variable) and hasattr(v.tag, 'test_value'): elif isinstance(v, graph.Variable) and hasattr(v.tag, 'test_value'):
# ensure that the test value is correct # ensure that the test value is correct
return v.type.filter(v.tag.test_value) try:
ret = v.type.filter(v.tag.test_value)
except Exception, e:
# Better error message.
detailed_err_msg = (
"For compute_test_value, one input test value does not"
" have the requested type.\n")
tr = getattr(v.tag, 'trace', None)
if tr:
sio = StringIO.StringIO()
traceback.print_list(tr, sio)
tr = sio.getvalue()
detailed_err_msg += (
" \nBacktrace when that variable is created:\n")
detailed_err_msg += str(tr)
detailed_err_msg += (
"\nThe error when converting the test value to that"
" variable type:")
# We need to only have 1 args and it should be of type
# string. Otherwise, it print the tuple and so the
# new line do not get printed.
args = (detailed_err_msg,) + tuple(str(arg) for arg in e.args)
e.args = ("\n".join(args),)
raise
return ret
raise AttributeError('%s has no test value' % v) raise AttributeError('%s has no test value' % v)
......
...@@ -307,8 +307,7 @@ class PureType(object): ...@@ -307,8 +307,7 @@ class PureType(object):
def make_constant(self, value, name=None): def make_constant(self, value, name=None):
return self.Constant(type=self, data=value, name=name) return self.Constant(type=self, data=value, name=name)
def __call__(self, name=None):
def __call__(self, name = None):
"""Return a new `Variable` instance of Type `self`. """Return a new `Variable` instance of Type `self`.
:Parameters: :Parameters:
......
...@@ -50,10 +50,16 @@ if sys.version_info[:2] > (3, 4): ...@@ -50,10 +50,16 @@ if sys.version_info[:2] > (3, 4):
simple_extract_stack = traceback.extract_stack simple_extract_stack = traceback.extract_stack
def add_tag_trace(thing): def add_tag_trace(thing, user_line=1):
"""Add tag.trace to an node or variable. """Add tag.trace to an node or variable.
The argument is returned after being affected (inplace). The argument is returned after being affected (inplace).
:param thing: the object where we add .tag.trace
:param user_line: The max number of user line to keep.
:note: we alse use config.traceback.limit for the maximum number
of stack level we look.
""" """
limit = config.traceback.limit limit = config.traceback.limit
if limit == -1: if limit == -1:
...@@ -68,14 +74,21 @@ def add_tag_trace(thing): ...@@ -68,14 +74,21 @@ def add_tag_trace(thing):
file_path = tr[-1][0] file_path = tr[-1][0]
rm = False rm = False
for p in ["theano/tensor/", for p in ["theano/tensor/",
"theano/gof/"]: "theano/gof/",
"theano/scalar/basic.py",
"theano/sandbox/",
"theano/scan_module/",
"theano/sparse/",
"theano/typed_list/",
]:
if p in file_path: if p in file_path:
tr = tr[:-1] tr = tr[:-1]
rm = True rm = True
break break
if not rm: if not rm:
break break
if len(tr) > user_line:
tr = tr[:user_line]
thing.tag.trace = tr thing.tag.trace = tr
return thing return thing
......
...@@ -744,7 +744,7 @@ def get_scalar_constant_value(orig_v, elemwise=True): ...@@ -744,7 +744,7 @@ def get_scalar_constant_value(orig_v, elemwise=True):
def tensor(*args, **kwargs): def tensor(*args, **kwargs):
name = kwargs.pop('name', None) name = kwargs.pop('name', None)
return TensorType(*args, **kwargs).make_variable(name=name) return TensorType(*args, **kwargs)(name=name)
def _multi(*fns): def _multi(*fns):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论