提交 9c0344cb authored 作者: Frederic Bastien's avatar Frederic Bastien

Make debugprint print the clients information

上级 b1c4abae
...@@ -512,7 +512,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -512,7 +512,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
print_view_map=False, order=None, ids='CHAR', print_view_map=False, order=None, ids='CHAR',
stop_on_name=False, prefix_child=None, stop_on_name=False, prefix_child=None,
scan_ops=None, profile=None, scan_ops=None, profile=None,
scan_inner_to_outer_inputs=None, smap=None): scan_inner_to_outer_inputs=None, smap=None,
used_ids=None):
""" """
Print the graph leading to `r` to given depth. Print the graph leading to `r` to given depth.
...@@ -575,20 +576,25 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -575,20 +576,25 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
if prefix_child is None: if prefix_child is None:
prefix_child = prefix prefix_child = prefix
if used_ids is None:
used_ids = dict()
def get_id_str(obj, get_printed=True): def get_id_str(obj, get_printed=True):
if obj in done: if obj in used_ids:
id_str = done[obj] id_str = used_ids[obj]
elif obj == 'output':
id_str = 'output'
elif ids == "id": elif ids == "id":
id_str = "[id %s]" % str(id(r)) id_str = "[id %s]" % str(id(r))
elif ids == "int": elif ids == "int":
id_str = "[id %s]" % str(len(done)) id_str = "[id %s]" % str(len(used_ids))
elif ids == "CHAR": elif ids == "CHAR":
id_str = "[id %s]" % char_from_number(len(done)) id_str = "[id %s]" % char_from_number(len(used_ids))
elif ids == "": elif ids == "":
id_str = "" id_str = ""
if get_printed: if get_printed:
done[obj] = id_str done[obj] = id_str
used_ids[obj] = id_str
return id_str return id_str
if hasattr(r.owner, 'op'): if hasattr(r.owner, 'op'):
...@@ -636,10 +642,9 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -636,10 +642,9 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
try: try:
return order.index(c) return order.index(c)
except ValueError: except ValueError:
return -1 return ""
clients = " clients:" + str([(get_id_str(c, False), get_index(c)) clients = " clients:" + str([(get_id_str(c, False), get_index(c))
for c,i in r.clients for c,i in r.clients])
if c != 'output'])
if profile is None or a not in profile.apply_time: if profile is None or a not in profile.apply_time:
print('%s%s%s %s%s \'%s\' %s %s %s%s%s' % (prefix, a.op, print('%s%s%s %s%s \'%s\' %s %s %s%s%s' % (prefix, a.op,
idx, idx,
...@@ -648,8 +653,6 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -648,8 +653,6 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
destroy_map_str, destroy_map_str,
view_map_str, view_map_str,
o, data, clients), file=file) o, data, clients), file=file)
# if len(r.clients) > 1:
# import pdb;pdb.set_trace()
else: else:
op_time = profile.apply_time[a] op_time = profile.apply_time[a]
op_time_percent = (op_time / profile.fct_call_time) * 100 op_time_percent = (op_time / profile.fct_call_time) * 100
...@@ -661,7 +664,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -661,7 +664,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
idx = "" idx = ""
else: else:
idx = ".%i" % a.outputs.index(r) idx = ".%i" % a.outputs.index(r)
print("%s%s%s %s%s '%s' %s %s %s%s --> " print("%s%s%s %s%s '%s' %s %s %s%s%s --> "
"%8.2es %4.1f%% %8.2es %4.1f%%" "%8.2es %4.1f%% %8.2es %4.1f%%"
% (prefix, a.op, % (prefix, a.op,
idx, idx,
...@@ -669,7 +672,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -669,7 +672,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
r_name, r_name,
destroy_map_str, destroy_map_str,
view_map_str, view_map_str,
o, data, o, data, clients,
op_time, op_time,
op_time_percent, op_time_percent,
tot_time, tot_time,
...@@ -697,7 +700,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -697,7 +700,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
prefix_child=new_prefix_child, scan_ops=scan_ops, prefix_child=new_prefix_child, scan_ops=scan_ops,
profile=profile, profile=profile,
scan_inner_to_outer_inputs=scan_inner_to_outer_inputs, scan_inner_to_outer_inputs=scan_inner_to_outer_inputs,
smap=smap) smap=smap, used_ids=used_ids)
else: else:
if scan_inner_to_outer_inputs is not None and\ if scan_inner_to_outer_inputs is not None and\
r in scan_inner_to_outer_inputs: r in scan_inner_to_outer_inputs:
......
...@@ -98,6 +98,7 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -98,6 +98,7 @@ def debugprint(obj, depth=-1, print_type=False,
_file = file _file = file
if done is None: if done is None:
done = dict() done = dict()
used_ids = dict()
results_to_print = [] results_to_print = []
profile_list = [] profile_list = []
order = [] # Toposort order = [] # Toposort
...@@ -178,7 +179,7 @@ N.B.: ...@@ -178,7 +179,7 @@ N.B.:
debugmode.debugprint(r, depth=depth, done=done, print_type=print_type, debugmode.debugprint(r, depth=depth, done=done, print_type=print_type,
file=_file, order=o, ids=ids, file=_file, order=o, ids=ids,
scan_ops=scan_ops, stop_on_name=stop_on_name, scan_ops=scan_ops, stop_on_name=stop_on_name,
profile=p, smap=s) profile=p, smap=s, used_ids=used_ids)
if len(scan_ops) > 0: if len(scan_ops) > 0:
print("", file=_file) print("", file=_file)
......
...@@ -270,6 +270,26 @@ def test_debugprint(): ...@@ -270,6 +270,26 @@ def test_debugprint():
assert s == reference assert s == reference
# test clients
s = StringIO()
f = theano.function([A, B, D], [A + B, A + B - D])
debugprint(f, file=s)
s = s.getvalue()
# The additional white space are needed!
reference = '\n'.join([
"Elemwise{add,no_inplace} [id A] '' 0 clients:[('output', ''), ('[id C]', 1)]",
" |A [id D]",
" |B [id E]",
"Elemwise{sub,no_inplace} [id C] '' 1",
" |Elemwise{add,no_inplace} [id A] '' 0 clients:[('output', ''), ('[id C]', 1)]",
" |D [id F]",
]) + '\n'
if s != reference:
print('--' + s + '--')
print('--' + reference + '--')
assert s == reference
def test_scan_debugprint1(): def test_scan_debugprint1():
k = tensor.iscalar("k") k = tensor.iscalar("k")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论