提交 0cffdb28 authored 作者: Frederic's avatar Frederic

Added debugprint the parameter stop_on_name.

上级 9c407da9
......@@ -502,7 +502,7 @@ def char_from_number(number):
def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
file=sys.stdout, print_destroy_map=False, print_view_map=False,
order=[], ids='CHAR'):
order=[], ids='CHAR', stop_on_name=False):
"""Print the graph leading to `r` to given depth.
:param r: Variable instance
......@@ -520,6 +520,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
int - print integer character
CHAR - print capital character
"" - don't print an identifier
:param stop_on_name: When True, if a node in the graph have a name,
we don't print anything below it.
"""
if depth == 0:
......@@ -594,10 +596,12 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
view_map_str,
o)
if not already_printed:
for i in a.inputs:
debugprint(i, prefix + ' |', depth=depth - 1, done=done,
print_type=print_type, file=file, order=order,
ids=ids)
if (not stop_on_name or
not (hasattr(r, 'name') and r.name is not None)):
for i in a.inputs:
debugprint(i, prefix + ' |', depth=depth - 1, done=done,
print_type=print_type, file=file, order=order,
ids=ids, stop_on_name=stop_on_name)
else:
#this is an input variable
id_str = get_id_str(r)
......
......@@ -29,7 +29,7 @@ _logger = logging.getLogger("theano.printing")
def debugprint(obj, depth=-1, print_type=False,
file=None, ids='CHAR'):
file=None, ids='CHAR', stop_on_name=False):
"""Print a computation graph to file
:type obj: Variable, Apply, or Function instance
......@@ -46,6 +46,8 @@ def debugprint(obj, depth=-1, print_type=False,
CHAR - print capital character
CHAR - print capital character
"" - don't print an identifier
:param stop_on_name: When True, if a node in the graph have a name,
we don't print anything below it.
:returns: string if `file` == 'str', else file arg
......@@ -90,7 +92,8 @@ def debugprint(obj, depth=-1, print_type=False,
raise TypeError("debugprint cannot print an object of this type", obj)
for r in results_to_print:
debugmode.debugprint(r, depth=depth, done=done, print_type=print_type,
file=_file, order=order, ids=ids)
file=_file, order=order, ids=ids,
stop_on_name=stop_on_name)
if file is _file:
return file
elif file == 'str':
......
......@@ -143,6 +143,25 @@ def test_debugprint():
assert s == reference
# test ids=CHAR, stop_on_name=True
s = StringIO.StringIO()
debugprint(G, file=s, ids='CHAR', stop_on_name=True)
s = s.getvalue()
# The additional white space are needed!
reference = """Elemwise{add,no_inplace} [@A] ''
|Elemwise{add,no_inplace} [@B] 'C'
|Elemwise{add,no_inplace} [@C] ''
| |D [@D]
| |E [@E]
"""
if s != reference:
print '--'+s+'--'
print '--'+reference+'--'
assert s == reference
# test ids=
s = StringIO.StringIO()
debugprint(G, file=s, ids='')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论