提交 848e1501 authored 作者: Frederic Bastien's avatar Frederic Bastien

Less hacky versino

上级 8f17e281
......@@ -295,7 +295,8 @@ AddConfigVar('traceback.limit',
"The number of stack to trace. -1 mean all.",
# We default to 6 to be able to know where v1 + v2 is created in the
# user script. The bigger this number is, the more run time it takes.
IntParam(6),
# We need to default to 7 to support theano.tensor.tensor(...).
IntParam(7),
in_c_key=False)
AddConfigVar('experimental.mrg',
......
......@@ -307,7 +307,7 @@ class PureType(object):
def make_constant(self, value, name=None):
return self.Constant(type=self, data=value, name=name)
def __call__(self, name=None, limit=None):
def __call__(self, name=None):
"""Return a new `Variable` instance of Type `self`.
:Parameters:
......@@ -315,7 +315,7 @@ class PureType(object):
A pretty string for printing and debugging.
"""
return utils.add_tag_trace(self.make_variable(name), limit=limit)
return utils.add_tag_trace(self.make_variable(name))
def values_eq(self, a, b):
"""
......
......@@ -50,16 +50,18 @@ if sys.version_info[:2] > (3, 4):
simple_extract_stack = traceback.extract_stack
def add_tag_trace(thing, limit=None):
def add_tag_trace(thing, user_line=1):
"""Add tag.trace to an node or variable.
The argument is returned after being affected (inplace).
:param thing: the object where we add .tag.trace
:param limit: The limit of the stack size.
If None use, config.traceback.limit
:param user_line: The max number of user line to keep.
:note: we alse use config.traceback.limit for the maximum number
of stack level we look.
"""
if limit is None:
limit = config.traceback.limit
limit = config.traceback.limit
if limit == -1:
limit = None
tr = simple_extract_stack(limit=limit)[:-1]
......@@ -72,14 +74,21 @@ def add_tag_trace(thing, limit=None):
file_path = tr[-1][0]
rm = False
for p in ["theano/tensor/",
"theano/gof/"]:
"theano/gof/",
"theano/scalar/basic.py",
"theano/sandbox/",
"theano/scan_module/",
"theano/sparse/",
"theano/typed_list/",
]:
if p in file_path:
tr = tr[:-1]
rm = True
break
if not rm:
break
if len(tr) > user_line:
tr = tr[:user_line]
thing.tag.trace = tr
return thing
......
......@@ -744,12 +744,7 @@ def get_scalar_constant_value(orig_v, elemwise=True):
def tensor(*args, **kwargs):
name = kwargs.pop('name', None)
# This add an indirection to the normal call stack. So raise the
# limit to keep the good user line.
limit = config.traceback.limit
if limit != -1:
limit += 1
return TensorType(*args, **kwargs)(name=name, limit=limit)
return TensorType(*args, **kwargs)(name=name)
def _multi(*fns):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论