提交 968de8cd authored 作者: Francesco Visin's avatar Francesco Visin

Avoid __repr__ recursion

上级 a763418b
...@@ -764,7 +764,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -764,7 +764,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
outer_id_str = get_id_str(outer_r.owner) outer_id_str = get_id_str(outer_r.owner)
else: else:
outer_id_str = get_id_str(outer_r) outer_id_str = get_id_str(outer_r)
print('%s%s %s%s -> %s' % (prefix, r.__str_name__(), id_str, type_str, print('%s%s %s%s -> %s' % (prefix, r, id_str, type_str,
outer_id_str), file=file) outer_id_str), file=file)
else: else:
# this is an input variable # this is an input variable
...@@ -772,7 +772,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -772,7 +772,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
if smap: if smap:
data = " " + str(smap.get(r, '')) data = " " + str(smap.get(r, ''))
id_str = get_id_str(r) id_str = get_id_str(r)
print('%s%s %s%s%s' % (prefix, r.__str_name__(), id_str, print('%s%s %s%s%s' % (prefix, r, id_str,
type_str, data), type_str, data),
file=file) file=file)
......
...@@ -646,8 +646,10 @@ AddConfigVar( ...@@ -646,8 +646,10 @@ AddConfigVar(
AddConfigVar( AddConfigVar(
'print_test_value', 'print_test_value',
("If 'True', Theano will override the '__str__' method of its " ("If 'True', the __eval__ of a Theano variable will return its test_value "
"variables to also print the tag.test_value when this is available."), "when this is available. This has the practical conseguence that, e.g., "
"in debugging `my_var` will print the same as `my_var.tag.test_value` "
"when a test value is defined."),
BoolParam(False), BoolParam(False),
in_c_key=False) in_c_key=False)
......
...@@ -391,47 +391,44 @@ class Variable(Node): ...@@ -391,47 +391,44 @@ class Variable(Node):
self.name = name self.name = name
self.auto_name = 'auto_' + str(next(self.__count__)) self.auto_name = 'auto_' + str(next(self.__count__))
def __str__(self, firstPass=True): def __str__(self):
"""Return a str representation of the Variable """Return a str representation of the Variable.
Return a printable name or description of the Variable. If
config.print_test_value is True it will also print the test_value if
any.
""" """
to_print = []
if config.print_test_value and firstPass:
try:
to_print.append(self.__str_test_value__())
except AttributeError:
return self.__str__(False)
if self.name is not None: if self.name is not None:
return '\n'.join([self.name] + to_print) return self.name
if self.owner is not None: if self.owner is not None:
op = self.owner.op op = self.owner.op
if self.index == op.default_output: if self.index == op.default_output:
return '\n'.join( return str(self.owner.op) + ".out"
[str(self.owner.op) + ".out"] + to_print)
else: else:
return '\n'.join( return str(self.owner.op) + "." + str(self.index)
[str(self.owner.op) + "." + str(self.index)] + to_print)
else: else:
return '\n'.join(["<%s>" % str(self.type)] + to_print) return "<%s>" % str(self.type)
def __str_test_value__(self): def __repr_test_value__(self):
"""Return a repr of the test value """Return a repr of the test value.
Return a printable representation of the test value. It can be Return a printable representation of the test value. It can be
overridden by classes with non printable test_value to provide a overridden by classes with non printable test_value to provide a
suitable representation of the test_value. suitable representation of the test_value.
""" """
try: return repr(theano.gof.op.get_test_value(self))
return repr(theano.gof.op.get_test_value(self))
except:
raise
def __repr__(self): def __repr__(self, firstPass=True):
return str(self) """Return a repr of the Variable.
Return a printable name or description of the Variable. If
config.print_test_value is True it will also print the test_value if
any.
"""
to_print = [str(self)]
if config.print_test_value and firstPass:
try:
to_print.append(self.__repr_test_value__())
except AttributeError:
pass
return '\n'.join(to_print)
def clone(self): def clone(self):
""" """
......
...@@ -45,11 +45,8 @@ class CudaNdarrayVariable(_operators, Variable): ...@@ -45,11 +45,8 @@ class CudaNdarrayVariable(_operators, Variable):
pass pass
# override default # override default
def __str_test_value__(self): def __repr_test_value__(self):
try: return repr(numpy.array(theano.gof.op.get_test_value(self)))
return repr(numpy.array(theano.gof.op.get_test_value(self)))
except:
raise
CudaNdarrayType.Variable = CudaNdarrayVariable CudaNdarrayType.Variable = CudaNdarrayVariable
......
...@@ -420,6 +420,9 @@ class _operators(_tensor_py_operators): ...@@ -420,6 +420,9 @@ class _operators(_tensor_py_operators):
class GpuArrayVariable(_operators, Variable): class GpuArrayVariable(_operators, Variable):
# override the default
def __repr_test_value__(self):
return repr(numpy.array(theano.gof.op.get_test_value(self)))
pass pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论