提交 fd695b52 authored 作者: Frederic's avatar Frederic

Add debugprint param print_storage and fix order printing with multiple objects

上级 897ac042
......@@ -580,7 +580,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
print_view_map=False, order=None, ids='CHAR',
stop_on_name=False, prefix_child=None,
scan_ops=None, profile=None,
scan_inner_to_outer_inputs=None):
scan_inner_to_outer_inputs=None, smap=None):
"""
Print the graph leading to `r` to given depth.
......@@ -691,21 +691,19 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
if profile is None or a not in profile.apply_time:
if len(a.outputs) == 1:
print('%s%s %s%s \'%s\' %s %s %s' % (prefix, a.op,
id_str,
type_str,
r_name,
destroy_map_str,
view_map_str,
o), file=file)
idx = ""
else:
print('%s%s.%i %s%s \'%s\' %s %s %s' % (prefix, a.op,
a.outputs.index(r),
id_str, type_str,
r_name,
destroy_map_str,
view_map_str,
o), file=file)
idx = ".%i" % a.outputs.index(r)
data = ""
if smap:
data = " " + str(smap[a.outputs[0]])
print('%s%s%s %s%s \'%s\' %s %s %s%s' % (prefix, a.op,
idx,
id_str, type_str,
r_name,
destroy_map_str,
view_map_str,
o, data), file=file)
else:
op_time = profile.apply_time[a]
op_time_percent = (op_time / profile.fct_call_time) * 100
......@@ -714,31 +712,21 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
tot_time_percent = (tot_time_dict[a] / profile.fct_call_time) * 100
if len(a.outputs) == 1:
print("%s%s %s%s '%s' %s %s %s --> "
"%8.2es %4.1f%% %8.2es %4.1f%%"
% (prefix, a.op,
id_str,
type_str,
r_name,
destroy_map_str,
view_map_str,
o, op_time,
op_time_percent,
tot_time,
tot_time_percent), file=file)
idx = ""
else:
print("%s%s.%i %s%s '%s' %s %s %s --> "
"%8.2es %4.1f%% %8.2es %4.1f%%"
% (prefix, a.op,
a.outputs.index(r),
id_str, type_str,
r_name,
destroy_map_str,
view_map_str,
o, op_time,
op_time_percent,
tot_time,
tot_time_percent), file=file)
idx = ".%i" % a.outputs.index(r)
print("%s%s%s %s%s '%s' %s %s %s --> "
"%8.2es %4.1f%% %8.2es %4.1f%%"
% (prefix, a.op,
idx,
id_str, type_str,
r_name,
destroy_map_str,
view_map_str,
o, op_time,
op_time_percent,
tot_time,
tot_time_percent), file=file)
if not already_printed:
if (not stop_on_name or
......@@ -761,7 +749,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
ids=ids, stop_on_name=stop_on_name,
prefix_child=new_prefix_child, scan_ops=scan_ops,
profile=profile,
scan_inner_to_outer_inputs=scan_inner_to_outer_inputs)
scan_inner_to_outer_inputs=scan_inner_to_outer_inputs,
smap=smap)
else:
if scan_inner_to_outer_inputs is not None and\
r in scan_inner_to_outer_inputs:
......@@ -777,8 +766,13 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
outer_id_str), file=file)
else:
# this is an input variable
data = ""
if smap:
data = " " + str(smap[r])
id_str = get_id_str(r)
print('%s%s %s%s' % (prefix, r, id_str, type_str), file=file)
print('%s%s %s%s%s' % (prefix, r, id_str,
type_str, data),
file=file)
return file
......
......@@ -48,7 +48,7 @@ VALID_ASSOC = set(['left', 'right', 'either'])
def debugprint(obj, depth=-1, print_type=False,
file=None, ids='CHAR', stop_on_name=False,
done=None):
done=None, print_storage=False):
"""Print a computation graph as text to stdout or a file.
:type obj: Variable, Apply, or Function instance
......@@ -70,6 +70,10 @@ def debugprint(obj, depth=-1, print_type=False,
:type done: None or dict
:param done: A dict where we store the ids of printed node.
Useful to have multiple call to debugprint share the same ids.
:type print_storage: bool
:param print_storage: If True, this will print the storage map
for Theano functions. Combined with allow_gc=False, after the
execution of a Theano function, we see the intermediate result.
:returns: string if `file` == 'str', else file arg
......@@ -101,7 +105,8 @@ def debugprint(obj, depth=-1, print_type=False,
done = dict()
results_to_print = []
profile_list = []
order = []
order = [] # Toposort
smap = [] # storage_map
if isinstance(obj, (list, tuple, set)):
lobj = obj
else:
......@@ -110,24 +115,41 @@ def debugprint(obj, depth=-1, print_type=False,
if isinstance(obj, gof.Variable):
results_to_print.append(obj)
profile_list.append(None)
smap.append(None)
order.append(None)
elif isinstance(obj, gof.Apply):
results_to_print.extend(obj.outputs)
profile_list.extend([None for item in obj.outputs])
smap.extend([None for item in obj.outputs])
order.extend([None for item in obj.outputs])
elif isinstance(obj, Function):
results_to_print.extend(obj.maker.fgraph.outputs)
profile_list.extend(
[obj.profile for item in obj.maker.fgraph.outputs])
order = obj.maker.fgraph.toposort()
if print_storage:
smap.extend(
[obj.fn.storage_map for item in obj.maker.fgraph.outputs])
else:
smap.extend(
[None for item in obj.maker.fgraph.outputs])
topo = obj.maker.fgraph.toposort()
order.extend(
[topo for item in obj.maker.fgraph.outputs])
elif isinstance(obj, gof.FunctionGraph):
results_to_print.extend(obj.outputs)
profile_list.extend([getattr(obj, 'profile', None)
for item in obj.outputs])
order = obj.toposort()
smap.extend([getattr(obj, 'storage_map', None)
for item in obj.outputs])
topo = obj.toposort()
order.extend([topo for item in obj.outputs])
elif isinstance(obj, (integer_types, float, np.ndarray)):
print(obj)
elif isinstance(obj, (theano.In, theano.Out)):
results_to_print.append(obj.variable)
profile_list.append(None)
smap.append(None)
order.append(None)
else:
raise TypeError("debugprint cannot print an object of this type",
obj)
......@@ -152,16 +174,16 @@ N.B.:
to remove when optimizing a graph because their <total time> is very low.
""", file=_file)
for r, p in zip(results_to_print, profile_list):
for r, p, s, o in zip(results_to_print, profile_list, smap, order):
# Add the parent scan op to the list as well
if (hasattr(r.owner, 'op') and
isinstance(r.owner.op, theano.scan_module.scan_op.Scan)):
scan_ops.append(r)
debugmode.debugprint(r, depth=depth, done=done, print_type=print_type,
file=_file, order=order, ids=ids,
file=_file, order=o, ids=ids,
scan_ops=scan_ops, stop_on_name=stop_on_name,
profile=p)
profile=p, smap=s)
if len(scan_ops) > 0:
print("", file=_file)
......
......@@ -164,6 +164,8 @@ def test_debugprint():
F = D + E
G = C + F
mode = theano.compile.get_default_mode().including('fusion')
g = theano.function([A, B, D, E], G, mode=mode)
# just test that it work
debugprint(G)
......@@ -249,6 +251,24 @@ def test_debugprint():
assert s == reference
# test print_storage=True
s = StringIO()
debugprint(g, file=s, ids='', print_storage=True)
s = s.getvalue()
# The additional white space are needed!
reference = '\n'.join([
"Elemwise{add,no_inplace} '' 0 [None]",
" |A [None]",
" |B [None]",
" |D [None]",
" |E [None]",
]) + '\n'
if s != reference:
print('--' + s + '--')
print('--' + reference + '--')
assert s == reference
def test_scan_debugprint1():
k = tensor.iscalar("k")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论