提交 a763418b authored 作者: Francesco Visin's avatar Francesco Visin

Reuse theano.gof.op.get_test_value() and fix cuda variables

上级 be939078
......@@ -861,6 +861,13 @@ import theano and print the config variable, as in:
optimization phase. Theano user's do not need to use this. This is
to help debug shape error in Theano optimization.
.. attribute:: print_test_value
Bool value, default: False
If ``'True'``, Theano will override the '__str__' method of its variables
to also print the tag.test_value when this is available.
.. attribute:: reoptimize_unpickled_function
Bool value, default: False (changed in master after Theano 0.7 release)
......
......@@ -645,7 +645,7 @@ AddConfigVar(
AddConfigVar(
'print_test_value_by_default',
'print_test_value',
("If 'True', Theano will override the '__str__' method of its "
"variables to also print the tag.test_value when this is available."),
BoolParam(False),
......
......@@ -391,38 +391,44 @@ class Variable(Node):
self.name = name
self.auto_name = 'auto_' + str(next(self.__count__))
def __str__(self):
"""
WRITEME
def __str__(self, firstPass=True):
"""Return a str representation of the Variable
Return a printable name or description of the Variable. If
config.print_test_value is True it will also print the test_value if
any.
"""
if config.print_test_value_by_default:
if self.__get_test_value__() is not None:
return '\n'.join([self.__str_name__(),
self.__get_test_value__()])
return self.__str_name__()
def __str_name__(self):
"""
WRITEME
to_print = []
if config.print_test_value and firstPass:
try:
to_print.append(self.__str_test_value__())
except AttributeError:
return self.__str__(False)
"""
if self.name is not None:
return self.name
return '\n'.join([self.name] + to_print)
if self.owner is not None:
op = self.owner.op
if self.index == op.default_output:
return str(self.owner.op) + ".out"
return '\n'.join(
[str(self.owner.op) + ".out"] + to_print)
else:
return str(self.owner.op) + "." + str(self.index)
return '\n'.join(
[str(self.owner.op) + "." + str(self.index)] + to_print)
else:
return "<%s>" % str(self.type)
return '\n'.join(["<%s>" % str(self.type)] + to_print)
def __str_test_value__(self):
"""Return a repr of the test value
def __get_test_value__(self):
Return a printable representation of the test value. It can be
overridden by classes with non printable test_value to provide a
suitable representation of the test_value.
"""
try:
return repr(self.tag.test_value)
except AttributeError:
return None
return repr(theano.gof.op.get_test_value(self))
except:
raise
def __repr__(self):
return str(self)
......
......@@ -43,6 +43,13 @@ class _operators(tensor.basic._tensor_py_operators):
class CudaNdarrayVariable(_operators, Variable):
pass
# override default
def __str_test_value__(self):
try:
return repr(numpy.array(theano.gof.op.get_test_value(self)))
except:
raise
CudaNdarrayType.Variable = CudaNdarrayVariable
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论