提交 cca20baf authored 作者: abergeron's avatar abergeron

Merge pull request #3605 from fvisin/tag.test_value

Add print_test_value_by_default option
......@@ -861,6 +861,13 @@ import theano and print the config variable, as in:
optimization phase. Theano user's do not need to use this. This is
to help debug shape error in Theano optimization.
.. attribute:: print_test_value
Bool value, default: False
If ``'True'``, Theano will override the '__str__' method of its variables
to also print the tag.test_value when this is available.
.. attribute:: reoptimize_unpickled_function
Bool value, default: False (changed in master after Theano 0.7 release)
......
......@@ -644,6 +644,16 @@ AddConfigVar(
in_c_key=False)
AddConfigVar(
'print_test_value',
("If 'True', the __eval__ of a Theano variable will return its test_value "
"when this is available. This has the practical conseguence that, e.g., "
"in debugging `my_var` will print the same as `my_var.tag.test_value` "
"when a test value is defined."),
BoolParam(False),
in_c_key=False)
AddConfigVar('compute_test_value_opt',
("For debugging Theano optimization only."
" Same as compute_test_value, but is used"
......
......@@ -12,6 +12,7 @@ from copy import copy
from itertools import count
import theano
from theano import config
from theano.gof import utils
from six import string_types, integer_types, iteritems
from theano.misc.ordered_set import OrderedSet
......@@ -391,8 +392,7 @@ class Variable(Node):
self.auto_name = 'auto_' + str(next(self.__count__))
def __str__(self):
"""
WRITEME
"""Return a str representation of the Variable.
"""
if self.name is not None:
......@@ -406,8 +406,29 @@ class Variable(Node):
else:
return "<%s>" % str(self.type)
def __repr__(self):
return str(self)
def __repr_test_value__(self):
"""Return a repr of the test value.
Return a printable representation of the test value. It can be
overridden by classes with non printable test_value to provide a
suitable representation of the test_value.
"""
return repr(theano.gof.op.get_test_value(self))
def __repr__(self, firstPass=True):
"""Return a repr of the Variable.
Return a printable name or description of the Variable. If
config.print_test_value is True it will also print the test_value if
any.
"""
to_print = [str(self)]
if config.print_test_value and firstPass:
try:
to_print.append(self.__repr_test_value__())
except AttributeError:
pass
return '\n'.join(to_print)
def clone(self):
"""
......
......@@ -43,6 +43,10 @@ class _operators(tensor.basic._tensor_py_operators):
class CudaNdarrayVariable(_operators, Variable):
pass
# override default
def __repr_test_value__(self):
return repr(numpy.array(theano.gof.op.get_test_value(self)))
CudaNdarrayType.Variable = CudaNdarrayVariable
......
......@@ -420,6 +420,9 @@ class _operators(_tensor_py_operators):
class GpuArrayVariable(_operators, Variable):
# override the default
def __repr_test_value__(self):
return repr(numpy.array(theano.gof.op.get_test_value(self)))
pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论