提交 8ea6f992 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add "print_type" parameter to debugprint, allowing to display r.type

上级 cf2116ba
...@@ -339,37 +339,43 @@ class InvalidValueError(DebugModeError): ...@@ -339,37 +339,43 @@ class InvalidValueError(DebugModeError):
def debugprint(r, prefix='', depth=-1, done=None, file=sys.stdout): def debugprint(r, prefix='', depth=-1, done=None, print_type=False, file=sys.stdout):
"""Print the graph leading to `r` to given depth. """Print the graph leading to `r` to given depth.
:param r: Variable instance :param r: Variable instance
:param prefix: prefix to each line (typically some number of spaces) :param prefix: prefix to each line (typically some number of spaces)
:param depth: maximum recursion depth (Default -1 for unlimited). :param depth: maximum recursion depth (Default -1 for unlimited).
:param done: set of Apply instances that have already been printed :param done: set of Apply instances that have already been printed
:param print_type: wether to print the Variable type after the other infos
:param file: file-like object to which to print :param file: file-like object to which to print
""" """
if depth==0: if depth==0:
return return
#backport
if done is None: if done is None:
done = set() done = set()
#done = set() if done is None else done
if print_type:
type_str = ' <%s>' % r.type
else:
type_str = ''
if hasattr(r.owner, 'op'): if hasattr(r.owner, 'op'):
# this variable is the output of computation, # this variable is the output of computation,
# so just print out the apply # so just print out the apply
a = r.owner a = r.owner
if len(a.outputs) == 1: if len(a.outputs) == 1:
print >> file, '%s%s [@%i]' % (prefix, a.op, id(r)) print >> file, '%s%s [@%i]%s' % (prefix, a.op, id(r), type_str)
else: else:
print >> file, '%s%s.%i [@%i]' % (prefix, a.op, a.outputs.index(r), id(r)) print >> file, '%s%s.%i [@%i]%s' % (prefix, a.op, a.outputs.index(r), id(r), type_str)
if id(a) not in done: if id(a) not in done:
done.add(id(a)) done.add(id(a))
for i in a.inputs: for i in a.inputs:
debugprint(i, prefix+' |', depth=depth-1, done=done, file=file) debugprint(i, prefix+' |', depth=depth-1, done=done, print_type=print_type, file=file)
else: else:
#this is a variable #this is a variable
print >> file, '%s%s [@%i]' % (prefix, r, id(r)) print >> file, '%s%s [@%i]%s' % (prefix, r, id(r), type_str)
return file return file
......
...@@ -11,13 +11,15 @@ from theano.gof.python25 import any ...@@ -11,13 +11,15 @@ from theano.gof.python25 import any
from theano.compile import Function, debugmode from theano.compile import Function, debugmode
from theano.compile.profilemode import ProfileMode from theano.compile.profilemode import ProfileMode
def debugprint(obj, depth=-1, file=None): def debugprint(obj, depth=-1, print_type=False, file=None):
"""Print a computation graph to file """Print a computation graph to file
:type obj: Variable, Apply, or Function instance :type obj: Variable, Apply, or Function instance
:param obj: symbolic thing to print :param obj: symbolic thing to print
:type depth: integer :type depth: integer
:param depth: print graph to this depth (-1 for unlimited) :param depth: print graph to this depth (-1 for unlimited)
:type print_type: boolean
:param print_type: wether to print the type of printed objects
:type file: None, 'str', or file-like object :type file: None, 'str', or file-like object
:param file: print to this file ('str' means to return a string) :param file: print to this file ('str' means to return a string)
...@@ -28,6 +30,7 @@ def debugprint(obj, depth=-1, file=None): ...@@ -28,6 +30,7 @@ def debugprint(obj, depth=-1, file=None):
The first part of the text identifies whether it is an input (if a name or type is printed) The first part of the text identifies whether it is an input (if a name or type is printed)
or the output of some Apply (in which case the Op is printed). or the output of some Apply (in which case the Op is printed).
The second part of the text is the memory location of the Variable. The second part of the text is the memory location of the Variable.
If print_type is True, there is a third part, containing the type of the Variable
If a Variable is encountered multiple times in the depth-first search, it is only printed If a Variable is encountered multiple times in the depth-first search, it is only printed
recursively the first time. Later, just the Variable and its memory location are printed. recursively the first time. Later, just the Variable and its memory location are printed.
...@@ -51,7 +54,7 @@ def debugprint(obj, depth=-1, file=None): ...@@ -51,7 +54,7 @@ def debugprint(obj, depth=-1, file=None):
elif isinstance(obj, Function): elif isinstance(obj, Function):
results_to_print.extend(obj.maker.env.outputs) results_to_print.extend(obj.maker.env.outputs)
for r in results_to_print: for r in results_to_print:
debugmode.debugprint(r, depth=depth, done=done, file=_file) debugmode.debugprint(r, depth=depth, done=done, print_type=print_type, file=_file)
if file is _file: if file is _file:
return file return file
elif file=='str': elif file=='str':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论