提交 7ed08878 authored 作者: James Bergstra's avatar James Bergstra

Added option for debugprint to return a string rather than printing it.

This was done so that the string could be passed to the logging module.
上级 3def9343
"""Pretty-printing (pprint()), the 'Print' Op, debugprint() and pydotprint(). """Pretty-printing (pprint()), the 'Print' Op, debugprint() and pydotprint().
They all allow different way to print a graph or the result of an Op in a graph(Print Op) They all allow different way to print a graph or the result of an Op in a graph(Print Op)
""" """
import gof import sys, os, StringIO
from copy import copy from copy import copy
import sys,os
import gof
from theano import config from theano import config
from gof import Op, Apply from gof import Op, Apply
from theano.gof.python25 import any from theano.gof.python25 import any
...@@ -17,11 +18,10 @@ def debugprint(obj, depth=-1, file=None): ...@@ -17,11 +18,10 @@ def debugprint(obj, depth=-1, file=None):
:param obj: symbolic thing to print :param obj: symbolic thing to print
:type depth: integer :type depth: integer
:param depth: print graph to this depth (-1 for unlimited) :param depth: print graph to this depth (-1 for unlimited)
:type file: None or file-like object :type file: None, 'str', or file-like object
:param file: print to this file (None means sys.stdout) :param file: print to this file ('str' means to return a string)
:rtype: None or file-like object :returns: str if `file`=='str', else file arg
:returns: `file` argument
Each line printed represents a Variable in the graph. Each line printed represents a Variable in the graph.
The indentation of each line corresponds to its depth in the symbolic graph. The indentation of each line corresponds to its depth in the symbolic graph.
...@@ -36,7 +36,9 @@ def debugprint(obj, depth=-1, file=None): ...@@ -36,7 +36,9 @@ def debugprint(obj, depth=-1, file=None):
identifier, to indicate which output a line corresponds to. identifier, to indicate which output a line corresponds to.
""" """
if file is None: if file == 'str':
_file = StringIO.StringIO()
elif file is None:
_file = sys.stdout _file = sys.stdout
else: else:
_file = file _file = file
...@@ -50,9 +52,12 @@ def debugprint(obj, depth=-1, file=None): ...@@ -50,9 +52,12 @@ def debugprint(obj, depth=-1, file=None):
results_to_print.extend(obj.maker.env.outputs) results_to_print.extend(obj.maker.env.outputs)
for r in results_to_print: for r in results_to_print:
debugmode.debugprint(r, depth=depth, done=done, file=_file) debugmode.debugprint(r, depth=depth, done=done, file=_file)
if file is None: if file is _file:
_file.flush()
return file return file
elif file=='str':
return _file.getvalue()
else:
_file.flush()
class Print(Op): class Print(Op):
"""This identity-like Op has the side effect of printing a message followed by its inputs """This identity-like Op has the side effect of printing a message followed by its inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论