提交 fd695b52 authored 作者: Frederic's avatar Frederic

Add debugprint param print_storage and fix order printing with multiple objects

上级 897ac042
...@@ -580,7 +580,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -580,7 +580,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
print_view_map=False, order=None, ids='CHAR', print_view_map=False, order=None, ids='CHAR',
stop_on_name=False, prefix_child=None, stop_on_name=False, prefix_child=None,
scan_ops=None, profile=None, scan_ops=None, profile=None,
scan_inner_to_outer_inputs=None): scan_inner_to_outer_inputs=None, smap=None):
""" """
Print the graph leading to `r` to given depth. Print the graph leading to `r` to given depth.
...@@ -691,21 +691,19 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -691,21 +691,19 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
if profile is None or a not in profile.apply_time: if profile is None or a not in profile.apply_time:
if len(a.outputs) == 1: if len(a.outputs) == 1:
print('%s%s %s%s \'%s\' %s %s %s' % (prefix, a.op, idx = ""
id_str,
type_str,
r_name,
destroy_map_str,
view_map_str,
o), file=file)
else: else:
print('%s%s.%i %s%s \'%s\' %s %s %s' % (prefix, a.op, idx = ".%i" % a.outputs.index(r)
a.outputs.index(r), data = ""
id_str, type_str, if smap:
r_name, data = " " + str(smap[a.outputs[0]])
destroy_map_str, print('%s%s%s %s%s \'%s\' %s %s %s%s' % (prefix, a.op,
view_map_str, idx,
o), file=file) id_str, type_str,
r_name,
destroy_map_str,
view_map_str,
o, data), file=file)
else: else:
op_time = profile.apply_time[a] op_time = profile.apply_time[a]
op_time_percent = (op_time / profile.fct_call_time) * 100 op_time_percent = (op_time / profile.fct_call_time) * 100
...@@ -714,31 +712,21 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -714,31 +712,21 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
tot_time_percent = (tot_time_dict[a] / profile.fct_call_time) * 100 tot_time_percent = (tot_time_dict[a] / profile.fct_call_time) * 100
if len(a.outputs) == 1: if len(a.outputs) == 1:
print("%s%s %s%s '%s' %s %s %s --> " idx = ""
"%8.2es %4.1f%% %8.2es %4.1f%%"
% (prefix, a.op,
id_str,
type_str,
r_name,
destroy_map_str,
view_map_str,
o, op_time,
op_time_percent,
tot_time,
tot_time_percent), file=file)
else: else:
print("%s%s.%i %s%s '%s' %s %s %s --> " idx = ".%i" % a.outputs.index(r)
"%8.2es %4.1f%% %8.2es %4.1f%%" print("%s%s%s %s%s '%s' %s %s %s --> "
% (prefix, a.op, "%8.2es %4.1f%% %8.2es %4.1f%%"
a.outputs.index(r), % (prefix, a.op,
id_str, type_str, idx,
r_name, id_str, type_str,
destroy_map_str, r_name,
view_map_str, destroy_map_str,
o, op_time, view_map_str,
op_time_percent, o, op_time,
tot_time, op_time_percent,
tot_time_percent), file=file) tot_time,
tot_time_percent), file=file)
if not already_printed: if not already_printed:
if (not stop_on_name or if (not stop_on_name or
...@@ -761,7 +749,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -761,7 +749,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
ids=ids, stop_on_name=stop_on_name, ids=ids, stop_on_name=stop_on_name,
prefix_child=new_prefix_child, scan_ops=scan_ops, prefix_child=new_prefix_child, scan_ops=scan_ops,
profile=profile, profile=profile,
scan_inner_to_outer_inputs=scan_inner_to_outer_inputs) scan_inner_to_outer_inputs=scan_inner_to_outer_inputs,
smap=smap)
else: else:
if scan_inner_to_outer_inputs is not None and\ if scan_inner_to_outer_inputs is not None and\
r in scan_inner_to_outer_inputs: r in scan_inner_to_outer_inputs:
...@@ -777,8 +766,13 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -777,8 +766,13 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
outer_id_str), file=file) outer_id_str), file=file)
else: else:
# this is an input variable # this is an input variable
data = ""
if smap:
data = " " + str(smap[r])
id_str = get_id_str(r) id_str = get_id_str(r)
print('%s%s %s%s' % (prefix, r, id_str, type_str), file=file) print('%s%s %s%s%s' % (prefix, r, id_str,
type_str, data),
file=file)
return file return file
......
...@@ -48,7 +48,7 @@ VALID_ASSOC = set(['left', 'right', 'either']) ...@@ -48,7 +48,7 @@ VALID_ASSOC = set(['left', 'right', 'either'])
def debugprint(obj, depth=-1, print_type=False, def debugprint(obj, depth=-1, print_type=False,
file=None, ids='CHAR', stop_on_name=False, file=None, ids='CHAR', stop_on_name=False,
done=None): done=None, print_storage=False):
"""Print a computation graph as text to stdout or a file. """Print a computation graph as text to stdout or a file.
:type obj: Variable, Apply, or Function instance :type obj: Variable, Apply, or Function instance
...@@ -70,6 +70,10 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -70,6 +70,10 @@ def debugprint(obj, depth=-1, print_type=False,
:type done: None or dict :type done: None or dict
:param done: A dict where we store the ids of printed node. :param done: A dict where we store the ids of printed node.
Useful to have multiple call to debugprint share the same ids. Useful to have multiple call to debugprint share the same ids.
:type print_storage: bool
:param print_storage: If True, this will print the storage map
for Theano functions. Combined with allow_gc=False, after the
execution of a Theano function, we see the intermediate result.
:returns: string if `file` == 'str', else file arg :returns: string if `file` == 'str', else file arg
...@@ -101,7 +105,8 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -101,7 +105,8 @@ def debugprint(obj, depth=-1, print_type=False,
done = dict() done = dict()
results_to_print = [] results_to_print = []
profile_list = [] profile_list = []
order = [] order = [] # Toposort
smap = [] # storage_map
if isinstance(obj, (list, tuple, set)): if isinstance(obj, (list, tuple, set)):
lobj = obj lobj = obj
else: else:
...@@ -110,24 +115,41 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -110,24 +115,41 @@ def debugprint(obj, depth=-1, print_type=False,
if isinstance(obj, gof.Variable): if isinstance(obj, gof.Variable):
results_to_print.append(obj) results_to_print.append(obj)
profile_list.append(None) profile_list.append(None)
smap.append(None)
order.append(None)
elif isinstance(obj, gof.Apply): elif isinstance(obj, gof.Apply):
results_to_print.extend(obj.outputs) results_to_print.extend(obj.outputs)
profile_list.extend([None for item in obj.outputs]) profile_list.extend([None for item in obj.outputs])
smap.extend([None for item in obj.outputs])
order.extend([None for item in obj.outputs])
elif isinstance(obj, Function): elif isinstance(obj, Function):
results_to_print.extend(obj.maker.fgraph.outputs) results_to_print.extend(obj.maker.fgraph.outputs)
profile_list.extend( profile_list.extend(
[obj.profile for item in obj.maker.fgraph.outputs]) [obj.profile for item in obj.maker.fgraph.outputs])
order = obj.maker.fgraph.toposort() if print_storage:
smap.extend(
[obj.fn.storage_map for item in obj.maker.fgraph.outputs])
else:
smap.extend(
[None for item in obj.maker.fgraph.outputs])
topo = obj.maker.fgraph.toposort()
order.extend(
[topo for item in obj.maker.fgraph.outputs])
elif isinstance(obj, gof.FunctionGraph): elif isinstance(obj, gof.FunctionGraph):
results_to_print.extend(obj.outputs) results_to_print.extend(obj.outputs)
profile_list.extend([getattr(obj, 'profile', None) profile_list.extend([getattr(obj, 'profile', None)
for item in obj.outputs]) for item in obj.outputs])
order = obj.toposort() smap.extend([getattr(obj, 'storage_map', None)
for item in obj.outputs])
topo = obj.toposort()
order.extend([topo for item in obj.outputs])
elif isinstance(obj, (integer_types, float, np.ndarray)): elif isinstance(obj, (integer_types, float, np.ndarray)):
print(obj) print(obj)
elif isinstance(obj, (theano.In, theano.Out)): elif isinstance(obj, (theano.In, theano.Out)):
results_to_print.append(obj.variable) results_to_print.append(obj.variable)
profile_list.append(None) profile_list.append(None)
smap.append(None)
order.append(None)
else: else:
raise TypeError("debugprint cannot print an object of this type", raise TypeError("debugprint cannot print an object of this type",
obj) obj)
...@@ -152,16 +174,16 @@ N.B.: ...@@ -152,16 +174,16 @@ N.B.:
to remove when optimizing a graph because their <total time> is very low. to remove when optimizing a graph because their <total time> is very low.
""", file=_file) """, file=_file)
for r, p in zip(results_to_print, profile_list): for r, p, s, o in zip(results_to_print, profile_list, smap, order):
# Add the parent scan op to the list as well # Add the parent scan op to the list as well
if (hasattr(r.owner, 'op') and if (hasattr(r.owner, 'op') and
isinstance(r.owner.op, theano.scan_module.scan_op.Scan)): isinstance(r.owner.op, theano.scan_module.scan_op.Scan)):
scan_ops.append(r) scan_ops.append(r)
debugmode.debugprint(r, depth=depth, done=done, print_type=print_type, debugmode.debugprint(r, depth=depth, done=done, print_type=print_type,
file=_file, order=order, ids=ids, file=_file, order=o, ids=ids,
scan_ops=scan_ops, stop_on_name=stop_on_name, scan_ops=scan_ops, stop_on_name=stop_on_name,
profile=p) profile=p, smap=s)
if len(scan_ops) > 0: if len(scan_ops) > 0:
print("", file=_file) print("", file=_file)
......
...@@ -164,6 +164,8 @@ def test_debugprint(): ...@@ -164,6 +164,8 @@ def test_debugprint():
F = D + E F = D + E
G = C + F G = C + F
mode = theano.compile.get_default_mode().including('fusion')
g = theano.function([A, B, D, E], G, mode=mode)
# just test that it work # just test that it work
debugprint(G) debugprint(G)
...@@ -249,6 +251,24 @@ def test_debugprint(): ...@@ -249,6 +251,24 @@ def test_debugprint():
assert s == reference assert s == reference
# test print_storage=True
s = StringIO()
debugprint(g, file=s, ids='', print_storage=True)
s = s.getvalue()
# The additional white space are needed!
reference = '\n'.join([
"Elemwise{add,no_inplace} '' 0 [None]",
" |A [None]",
" |B [None]",
" |D [None]",
" |E [None]",
]) + '\n'
if s != reference:
print('--' + s + '--')
print('--' + reference + '--')
assert s == reference
def test_scan_debugprint1(): def test_scan_debugprint1():
k = tensor.iscalar("k") k = tensor.iscalar("k")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论