提交 9c0344cb authored 作者: Frederic Bastien's avatar Frederic Bastien

Make debugprint print the clients information

上级 b1c4abae
......@@ -512,7 +512,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
print_view_map=False, order=None, ids='CHAR',
stop_on_name=False, prefix_child=None,
scan_ops=None, profile=None,
scan_inner_to_outer_inputs=None, smap=None):
scan_inner_to_outer_inputs=None, smap=None,
used_ids=None):
"""
Print the graph leading to `r` to given depth.
......@@ -575,20 +576,25 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
if prefix_child is None:
prefix_child = prefix
if used_ids is None:
used_ids = dict()
def get_id_str(obj, get_printed=True):
if obj in done:
id_str = done[obj]
if obj in used_ids:
id_str = used_ids[obj]
elif obj == 'output':
id_str = 'output'
elif ids == "id":
id_str = "[id %s]" % str(id(r))
elif ids == "int":
id_str = "[id %s]" % str(len(done))
id_str = "[id %s]" % str(len(used_ids))
elif ids == "CHAR":
id_str = "[id %s]" % char_from_number(len(done))
id_str = "[id %s]" % char_from_number(len(used_ids))
elif ids == "":
id_str = ""
if get_printed:
done[obj] = id_str
used_ids[obj] = id_str
return id_str
if hasattr(r.owner, 'op'):
......@@ -636,10 +642,9 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
try:
return order.index(c)
except ValueError:
return -1
return ""
clients = " clients:" + str([(get_id_str(c, False), get_index(c))
for c,i in r.clients
if c != 'output'])
for c,i in r.clients])
if profile is None or a not in profile.apply_time:
print('%s%s%s %s%s \'%s\' %s %s %s%s%s' % (prefix, a.op,
idx,
......@@ -648,8 +653,6 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
destroy_map_str,
view_map_str,
o, data, clients), file=file)
# if len(r.clients) > 1:
# import pdb;pdb.set_trace()
else:
op_time = profile.apply_time[a]
op_time_percent = (op_time / profile.fct_call_time) * 100
......@@ -661,7 +664,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
idx = ""
else:
idx = ".%i" % a.outputs.index(r)
print("%s%s%s %s%s '%s' %s %s %s%s --> "
print("%s%s%s %s%s '%s' %s %s %s%s%s --> "
"%8.2es %4.1f%% %8.2es %4.1f%%"
% (prefix, a.op,
idx,
......@@ -669,7 +672,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
r_name,
destroy_map_str,
view_map_str,
o, data,
o, data, clients,
op_time,
op_time_percent,
tot_time,
......@@ -697,7 +700,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
prefix_child=new_prefix_child, scan_ops=scan_ops,
profile=profile,
scan_inner_to_outer_inputs=scan_inner_to_outer_inputs,
smap=smap)
smap=smap, used_ids=used_ids)
else:
if scan_inner_to_outer_inputs is not None and\
r in scan_inner_to_outer_inputs:
......
......@@ -98,6 +98,7 @@ def debugprint(obj, depth=-1, print_type=False,
_file = file
if done is None:
done = dict()
used_ids = dict()
results_to_print = []
profile_list = []
order = [] # Toposort
......@@ -178,7 +179,7 @@ N.B.:
debugmode.debugprint(r, depth=depth, done=done, print_type=print_type,
file=_file, order=o, ids=ids,
scan_ops=scan_ops, stop_on_name=stop_on_name,
profile=p, smap=s)
profile=p, smap=s, used_ids=used_ids)
if len(scan_ops) > 0:
print("", file=_file)
......
......@@ -270,6 +270,26 @@ def test_debugprint():
assert s == reference
# test clients
s = StringIO()
f = theano.function([A, B, D], [A + B, A + B - D])
debugprint(f, file=s)
s = s.getvalue()
# The additional white space are needed!
reference = '\n'.join([
"Elemwise{add,no_inplace} [id A] '' 0 clients:[('output', ''), ('[id C]', 1)]",
" |A [id D]",
" |B [id E]",
"Elemwise{sub,no_inplace} [id C] '' 1",
" |Elemwise{add,no_inplace} [id A] '' 0 clients:[('output', ''), ('[id C]', 1)]",
" |D [id F]",
]) + '\n'
if s != reference:
print('--' + s + '--')
print('--' + reference + '--')
assert s == reference
def test_scan_debugprint1():
k = tensor.iscalar("k")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论