提交 a763418b authored 作者: Francesco Visin's avatar Francesco Visin

Reuse theano.gof.op.get_test_value() and fix cuda variables

上级 be939078
...@@ -861,6 +861,13 @@ import theano and print the config variable, as in: ...@@ -861,6 +861,13 @@ import theano and print the config variable, as in:
optimization phase. Theano user's do not need to use this. This is optimization phase. Theano user's do not need to use this. This is
to help debug shape error in Theano optimization. to help debug shape error in Theano optimization.
.. attribute:: print_test_value
Bool value, default: False
If ``'True'``, Theano will override the '__str__' method of its variables
to also print the tag.test_value when this is available.
.. attribute:: reoptimize_unpickled_function .. attribute:: reoptimize_unpickled_function
Bool value, default: False (changed in master after Theano 0.7 release) Bool value, default: False (changed in master after Theano 0.7 release)
......
...@@ -645,7 +645,7 @@ AddConfigVar( ...@@ -645,7 +645,7 @@ AddConfigVar(
AddConfigVar( AddConfigVar(
'print_test_value_by_default', 'print_test_value',
("If 'True', Theano will override the '__str__' method of its " ("If 'True', Theano will override the '__str__' method of its "
"variables to also print the tag.test_value when this is available."), "variables to also print the tag.test_value when this is available."),
BoolParam(False), BoolParam(False),
......
...@@ -391,38 +391,44 @@ class Variable(Node): ...@@ -391,38 +391,44 @@ class Variable(Node):
self.name = name self.name = name
self.auto_name = 'auto_' + str(next(self.__count__)) self.auto_name = 'auto_' + str(next(self.__count__))
def __str__(self): def __str__(self, firstPass=True):
""" """Return a str representation of the Variable
WRITEME
Return a printable name or description of the Variable. If
config.print_test_value is True it will also print the test_value if
any.
""" """
if config.print_test_value_by_default: to_print = []
if self.__get_test_value__() is not None: if config.print_test_value and firstPass:
return '\n'.join([self.__str_name__(), try:
self.__get_test_value__()]) to_print.append(self.__str_test_value__())
return self.__str_name__() except AttributeError:
return self.__str__(False)
def __str_name__(self):
"""
WRITEME
"""
if self.name is not None: if self.name is not None:
return self.name return '\n'.join([self.name] + to_print)
if self.owner is not None: if self.owner is not None:
op = self.owner.op op = self.owner.op
if self.index == op.default_output: if self.index == op.default_output:
return str(self.owner.op) + ".out" return '\n'.join(
[str(self.owner.op) + ".out"] + to_print)
else: else:
return str(self.owner.op) + "." + str(self.index) return '\n'.join(
[str(self.owner.op) + "." + str(self.index)] + to_print)
else: else:
return "<%s>" % str(self.type) return '\n'.join(["<%s>" % str(self.type)] + to_print)
def __get_test_value__(self): def __str_test_value__(self):
"""Return a repr of the test value
Return a printable representation of the test value. It can be
overridden by classes with non printable test_value to provide a
suitable representation of the test_value.
"""
try: try:
return repr(self.tag.test_value) return repr(theano.gof.op.get_test_value(self))
except AttributeError: except:
return None raise
def __repr__(self): def __repr__(self):
return str(self) return str(self)
......
...@@ -43,6 +43,13 @@ class _operators(tensor.basic._tensor_py_operators): ...@@ -43,6 +43,13 @@ class _operators(tensor.basic._tensor_py_operators):
class CudaNdarrayVariable(_operators, Variable): class CudaNdarrayVariable(_operators, Variable):
pass pass
# override default
def __str_test_value__(self):
try:
return repr(numpy.array(theano.gof.op.get_test_value(self)))
except:
raise
CudaNdarrayType.Variable = CudaNdarrayVariable CudaNdarrayType.Variable = CudaNdarrayVariable
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论