提交 a5331484 authored 作者: Frederic's avatar Frederic

Allow debugprint to use int, char or remove the id of each printed line.

上级 9ca0033f
...@@ -480,32 +480,73 @@ class InvalidValueError(DebugModeError): ...@@ -480,32 +480,73 @@ class InvalidValueError(DebugModeError):
######################## ########################
def char_from_number(number):
""" Converts number to string by rendering it in base 26 using
capital letters as digits """
base = 26
rval = ""
if number == 0:
rval = 'A'
while number != 0:
remainder = number % base
new_char = chr(ord('A') + remainder)
rval = new_char + rval
number /= base
return rval
def debugprint(r, prefix='', depth=-1, done=None, print_type=False, def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
file=sys.stdout, print_destroy_map=False, print_view_map=False, file=sys.stdout, print_destroy_map=False, print_view_map=False,
order=[]): order=[], ids='id'):
"""Print the graph leading to `r` to given depth. """Print the graph leading to `r` to given depth.
:param r: Variable instance :param r: Variable instance
:param prefix: prefix to each line (typically some number of spaces) :param prefix: prefix to each line (typically some number of spaces)
:param depth: maximum recursion depth (Default -1 for unlimited). :param depth: maximum recursion depth (Default -1 for unlimited).
:param done: set of Apply instances that have already been printed :param done: dict of Apply instances that have already been printed
and there associated printed ids
:param print_type: wether to print the Variable type after the other infos :param print_type: wether to print the Variable type after the other infos
:param file: file-like object to which to print :param file: file-like object to which to print
:param print_destroy_map: wether to print the op destroy_map after ofther info :param print_destroy_map: wether to print the op destroy_map after ofther info
:param print_view_map: wether to print the op view_map after ofther info :param print_view_map: wether to print the op view_map after ofther info
:param order: If not empty will print the index in the toposort. :param order: If not empty will print the index in the toposort.
:param ids: How do we print the identifier of the variable
id - print the python id value
int - print integer character
CHAR - print capital character
"" - don't print an identifier
""" """
if depth == 0: if depth == 0:
return return
if done is None: if done is None:
done = set() done = dict()
if print_type: if print_type:
type_str = ' <%s>' % r.type type_str = ' <%s>' % r.type
else: else:
type_str = '' type_str = ''
def get_id_str(obj):
if obj in done:
id_str = "[@%s]" % done[obj]
elif ids == "id":
id_str = "[@%s]" % str(id(r))
elif ids == "int":
id_str = "[@%s]" % str(len(done))
elif ids == "CHAR":
id_str = "[@%s]" % char_from_number(len(done))
elif ids == "":
id_str = ""
done[obj] = id_str
return id_str
if hasattr(r.owner, 'op'): if hasattr(r.owner, 'op'):
# this variable is the output of computation, # this variable is the output of computation,
# so just print out the apply # so just print out the apply
...@@ -534,29 +575,33 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -534,29 +575,33 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
o = '' o = ''
if order: if order:
o = str(order.index(r.owner)) o = str(order.index(r.owner))
already_printed = a in done # get_id_str put it in the dict
id_str = get_id_str(a)
if len(a.outputs) == 1: if len(a.outputs) == 1:
print >> file, '%s%s [@%i]%s \'%s\' %s %s %s' % (prefix, a.op, print >> file, '%s%s %s%s \'%s\' %s %s %s' % (prefix, a.op,
id(r), id_str,
type_str, r_name, type_str, r_name,
destroy_map_str, destroy_map_str,
view_map_str, view_map_str,
o) o)
else: else:
print >> file, '%s%s.%i [@%i]%s \'%s\' %s %s %s' % (prefix, a.op, print >> file, '%s%s.%i %s%s \'%s\' %s %s %s' % (prefix, a.op,
a.outputs.index(r), a.outputs.index(r),
id(r), type_str, id_str, type_str,
r_name, r_name,
destroy_map_str, destroy_map_str,
view_map_str, view_map_str,
o) o)
if id(a) not in done: if not already_printed:
done.add(id(a))
for i in a.inputs: for i in a.inputs:
debugprint(i, prefix + ' |', depth=depth-1, done=done, debugprint(i, prefix + ' |', depth=depth - 1, done=done,
print_type=print_type, file=file, order=order) print_type=print_type, file=file, order=order,
ids=ids)
else: else:
#this is a variable #this is an input variable
print >> file, '%s%s [@%i]%s' % (prefix, r, id(r), type_str) id_str = get_id_str(r)
print >> file, '%s%s %s%s' % (prefix, r, id_str, type_str)
return file return file
......
...@@ -28,7 +28,8 @@ from theano.compile.profilemode import ProfileMode ...@@ -28,7 +28,8 @@ from theano.compile.profilemode import ProfileMode
_logger = logging.getLogger("theano.printing") _logger = logging.getLogger("theano.printing")
def debugprint(obj, depth=-1, print_type=False, file=None): def debugprint(obj, depth=-1, print_type=False,
file=None, ids='id'):
"""Print a computation graph to file """Print a computation graph to file
:type obj: Variable, Apply, or Function instance :type obj: Variable, Apply, or Function instance
...@@ -39,6 +40,12 @@ def debugprint(obj, depth=-1, print_type=False, file=None): ...@@ -39,6 +40,12 @@ def debugprint(obj, depth=-1, print_type=False, file=None):
:param print_type: wether to print the type of printed objects :param print_type: wether to print the type of printed objects
:type file: None, 'str', or file-like object :type file: None, 'str', or file-like object
:param file: print to this file ('str' means to return a string) :param file: print to this file ('str' means to return a string)
:type ids: str
:param ids: How do we print the identifier of the variable
id - print the python id value
CHAR - print capital character
CHAR - print capital character
"" - don't print an identifier
:returns: string if `file` == 'str', else file arg :returns: string if `file` == 'str', else file arg
...@@ -64,7 +71,7 @@ def debugprint(obj, depth=-1, print_type=False, file=None): ...@@ -64,7 +71,7 @@ def debugprint(obj, depth=-1, print_type=False, file=None):
_file = sys.stdout _file = sys.stdout
else: else:
_file = file _file = file
done = set() done = dict()
results_to_print = [] results_to_print = []
order = [] order = []
if isinstance(obj, gof.Variable): if isinstance(obj, gof.Variable):
...@@ -83,7 +90,7 @@ def debugprint(obj, depth=-1, print_type=False, file=None): ...@@ -83,7 +90,7 @@ def debugprint(obj, depth=-1, print_type=False, file=None):
raise TypeError("debugprint cannot print an object of this type", obj) raise TypeError("debugprint cannot print an object of this type", obj)
for r in results_to_print: for r in results_to_print:
debugmode.debugprint(r, depth=depth, done=done, print_type=print_type, debugmode.debugprint(r, depth=depth, done=done, print_type=print_type,
file=_file, order=order) file=_file, order=order, ids=ids)
if file is _file: if file is _file:
return file return file
elif file == 'str': elif file == 'str':
...@@ -904,31 +911,12 @@ class _TagGenerator: ...@@ -904,31 +911,12 @@ class _TagGenerator:
self.cur_tag_number = 0 self.cur_tag_number = 0
def get_tag(self): def get_tag(self):
rval = self.from_number(self.cur_tag_number) rval = debugmode.char_from_number(self.cur_tag_number)
self.cur_tag_number += 1 self.cur_tag_number += 1
return rval return rval
def from_number(self, number):
""" Converts number to string by rendering it in base 26 using
capital letters as digits """
base = 26
rval = ""
if number == 0:
rval = 'A'
while number != 0:
remainder = number % base
new_char = chr(ord('A') + remainder)
rval = new_char + rval
number /= base
return rval
def min_informative_str(obj, indent_level=0, def min_informative_str(obj, indent_level=0,
_prev_obs=None, _tag_generator=None): _prev_obs=None, _tag_generator=None):
......
...@@ -11,7 +11,7 @@ from nose.plugins.skip import SkipTest ...@@ -11,7 +11,7 @@ from nose.plugins.skip import SkipTest
import theano import theano
import theano.tensor as tensor import theano.tensor as tensor
from theano.printing import min_informative_str from theano.printing import min_informative_str, debugprint
def test_pydotprint_cond_highlight(): def test_pydotprint_cond_highlight():
...@@ -86,3 +86,78 @@ def test_min_informative_str(): ...@@ -86,3 +86,78 @@ def test_min_informative_str():
print '--' + reference + '--' print '--' + reference + '--'
assert mis == reference assert mis == reference
def test_debugprint():
A = tensor.matrix(name='A')
B = tensor.matrix(name='B')
C = A + B
C.name = 'C'
D = tensor.matrix(name='D')
E = tensor.matrix(name='E')
F = D + E
G = C + F
# just test that it work
debugprint(G)
# test ids=int
s = StringIO.StringIO()
debugprint(G, file=s, ids='int')
s = s.getvalue()
# The additional white space are needed!
reference = """Elemwise{add,no_inplace} [@0] ''
|Elemwise{add,no_inplace} [@1] 'C'
| |A [@2]
| |B [@3]
|Elemwise{add,no_inplace} [@4] ''
| |D [@5]
| |E [@6]
"""
if s != reference:
print '--'+s+'--'
print '--'+reference+'--'
assert s == reference
# test ids=CHAR
s = StringIO.StringIO()
debugprint(G, file=s, ids='CHAR')
s = s.getvalue()
# The additional white space are needed!
reference = """Elemwise{add,no_inplace} [@A] ''
|Elemwise{add,no_inplace} [@B] 'C'
| |A [@C]
| |B [@D]
|Elemwise{add,no_inplace} [@E] ''
| |D [@F]
| |E [@G]
"""
if s != reference:
print '--'+s+'--'
print '--'+reference+'--'
assert s == reference
# test ids=
s = StringIO.StringIO()
debugprint(G, file=s, ids='')
s = s.getvalue()
# The additional white space are needed!
reference = """Elemwise{add,no_inplace} ''
|Elemwise{add,no_inplace} 'C'
| |A
| |B
|Elemwise{add,no_inplace} ''
| |D
| |E
"""
if s != reference:
print '--'+s+'--'
print '--'+reference+'--'
assert s == reference
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论