提交 cca20baf authored 作者: abergeron's avatar abergeron

Merge pull request #3605 from fvisin/tag.test_value

Add print_test_value_by_default option
...@@ -861,6 +861,13 @@ import theano and print the config variable, as in: ...@@ -861,6 +861,13 @@ import theano and print the config variable, as in:
optimization phase. Theano user's do not need to use this. This is optimization phase. Theano user's do not need to use this. This is
to help debug shape error in Theano optimization. to help debug shape error in Theano optimization.
.. attribute:: print_test_value
Bool value, default: False
If ``'True'``, Theano will override the '__str__' method of its variables
to also print the tag.test_value when this is available.
.. attribute:: reoptimize_unpickled_function .. attribute:: reoptimize_unpickled_function
Bool value, default: False (changed in master after Theano 0.7 release) Bool value, default: False (changed in master after Theano 0.7 release)
......
...@@ -644,6 +644,16 @@ AddConfigVar( ...@@ -644,6 +644,16 @@ AddConfigVar(
in_c_key=False) in_c_key=False)
AddConfigVar(
'print_test_value',
("If 'True', the __eval__ of a Theano variable will return its test_value "
"when this is available. This has the practical conseguence that, e.g., "
"in debugging `my_var` will print the same as `my_var.tag.test_value` "
"when a test value is defined."),
BoolParam(False),
in_c_key=False)
AddConfigVar('compute_test_value_opt', AddConfigVar('compute_test_value_opt',
("For debugging Theano optimization only." ("For debugging Theano optimization only."
" Same as compute_test_value, but is used" " Same as compute_test_value, but is used"
......
...@@ -12,6 +12,7 @@ from copy import copy ...@@ -12,6 +12,7 @@ from copy import copy
from itertools import count from itertools import count
import theano import theano
from theano import config
from theano.gof import utils from theano.gof import utils
from six import string_types, integer_types, iteritems from six import string_types, integer_types, iteritems
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
...@@ -391,8 +392,7 @@ class Variable(Node): ...@@ -391,8 +392,7 @@ class Variable(Node):
self.auto_name = 'auto_' + str(next(self.__count__)) self.auto_name = 'auto_' + str(next(self.__count__))
def __str__(self): def __str__(self):
""" """Return a str representation of the Variable.
WRITEME
""" """
if self.name is not None: if self.name is not None:
...@@ -406,8 +406,29 @@ class Variable(Node): ...@@ -406,8 +406,29 @@ class Variable(Node):
else: else:
return "<%s>" % str(self.type) return "<%s>" % str(self.type)
def __repr__(self): def __repr_test_value__(self):
return str(self) """Return a repr of the test value.
Return a printable representation of the test value. It can be
overridden by classes with non printable test_value to provide a
suitable representation of the test_value.
"""
return repr(theano.gof.op.get_test_value(self))
def __repr__(self, firstPass=True):
"""Return a repr of the Variable.
Return a printable name or description of the Variable. If
config.print_test_value is True it will also print the test_value if
any.
"""
to_print = [str(self)]
if config.print_test_value and firstPass:
try:
to_print.append(self.__repr_test_value__())
except AttributeError:
pass
return '\n'.join(to_print)
def clone(self): def clone(self):
""" """
......
...@@ -43,6 +43,10 @@ class _operators(tensor.basic._tensor_py_operators): ...@@ -43,6 +43,10 @@ class _operators(tensor.basic._tensor_py_operators):
class CudaNdarrayVariable(_operators, Variable): class CudaNdarrayVariable(_operators, Variable):
pass pass
# override default
def __repr_test_value__(self):
return repr(numpy.array(theano.gof.op.get_test_value(self)))
CudaNdarrayType.Variable = CudaNdarrayVariable CudaNdarrayType.Variable = CudaNdarrayVariable
......
...@@ -420,6 +420,9 @@ class _operators(_tensor_py_operators): ...@@ -420,6 +420,9 @@ class _operators(_tensor_py_operators):
class GpuArrayVariable(_operators, Variable): class GpuArrayVariable(_operators, Variable):
# override the default
def __repr_test_value__(self):
return repr(numpy.array(theano.gof.op.get_test_value(self)))
pass pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论