提交 848e1501 authored 作者: Frederic Bastien's avatar Frederic Bastien

Less hacky versino

上级 8f17e281
...@@ -295,7 +295,8 @@ AddConfigVar('traceback.limit', ...@@ -295,7 +295,8 @@ AddConfigVar('traceback.limit',
"The number of stack to trace. -1 mean all.", "The number of stack to trace. -1 mean all.",
# We default to 6 to be able to know where v1 + v2 is created in the # We default to 6 to be able to know where v1 + v2 is created in the
# user script. The bigger this number is, the more run time it takes. # user script. The bigger this number is, the more run time it takes.
IntParam(6), # We need to default to 7 to support theano.tensor.tensor(...).
IntParam(7),
in_c_key=False) in_c_key=False)
AddConfigVar('experimental.mrg', AddConfigVar('experimental.mrg',
......
...@@ -307,7 +307,7 @@ class PureType(object): ...@@ -307,7 +307,7 @@ class PureType(object):
def make_constant(self, value, name=None): def make_constant(self, value, name=None):
return self.Constant(type=self, data=value, name=name) return self.Constant(type=self, data=value, name=name)
def __call__(self, name=None, limit=None): def __call__(self, name=None):
"""Return a new `Variable` instance of Type `self`. """Return a new `Variable` instance of Type `self`.
:Parameters: :Parameters:
...@@ -315,7 +315,7 @@ class PureType(object): ...@@ -315,7 +315,7 @@ class PureType(object):
A pretty string for printing and debugging. A pretty string for printing and debugging.
""" """
return utils.add_tag_trace(self.make_variable(name), limit=limit) return utils.add_tag_trace(self.make_variable(name))
def values_eq(self, a, b): def values_eq(self, a, b):
""" """
......
...@@ -50,15 +50,17 @@ if sys.version_info[:2] > (3, 4): ...@@ -50,15 +50,17 @@ if sys.version_info[:2] > (3, 4):
simple_extract_stack = traceback.extract_stack simple_extract_stack = traceback.extract_stack
def add_tag_trace(thing, limit=None): def add_tag_trace(thing, user_line=1):
"""Add tag.trace to an node or variable. """Add tag.trace to an node or variable.
The argument is returned after being affected (inplace). The argument is returned after being affected (inplace).
:param thing: the object where we add .tag.trace :param thing: the object where we add .tag.trace
:param limit: The limit of the stack size. :param user_line: The max number of user line to keep.
If None use, config.traceback.limit
:note: we alse use config.traceback.limit for the maximum number
of stack level we look.
""" """
if limit is None:
limit = config.traceback.limit limit = config.traceback.limit
if limit == -1: if limit == -1:
limit = None limit = None
...@@ -72,14 +74,21 @@ def add_tag_trace(thing, limit=None): ...@@ -72,14 +74,21 @@ def add_tag_trace(thing, limit=None):
file_path = tr[-1][0] file_path = tr[-1][0]
rm = False rm = False
for p in ["theano/tensor/", for p in ["theano/tensor/",
"theano/gof/"]: "theano/gof/",
"theano/scalar/basic.py",
"theano/sandbox/",
"theano/scan_module/",
"theano/sparse/",
"theano/typed_list/",
]:
if p in file_path: if p in file_path:
tr = tr[:-1] tr = tr[:-1]
rm = True rm = True
break break
if not rm: if not rm:
break break
if len(tr) > user_line:
tr = tr[:user_line]
thing.tag.trace = tr thing.tag.trace = tr
return thing return thing
......
...@@ -744,12 +744,7 @@ def get_scalar_constant_value(orig_v, elemwise=True): ...@@ -744,12 +744,7 @@ def get_scalar_constant_value(orig_v, elemwise=True):
def tensor(*args, **kwargs): def tensor(*args, **kwargs):
name = kwargs.pop('name', None) name = kwargs.pop('name', None)
# This add an indirection to the normal call stack. So raise the return TensorType(*args, **kwargs)(name=name)
# limit to keep the good user line.
limit = config.traceback.limit
if limit != -1:
limit += 1
return TensorType(*args, **kwargs)(name=name, limit=limit)
def _multi(*fns): def _multi(*fns):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论