提交 9048d63f authored 作者: vdumoulin's avatar vdumoulin

Merge pull request #2301 from mohammadpz/timing_info

timing information added in debugprint
...@@ -495,7 +495,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -495,7 +495,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
file=sys.stdout, print_destroy_map=False, file=sys.stdout, print_destroy_map=False,
print_view_map=False, order=None, ids='CHAR', print_view_map=False, order=None, ids='CHAR',
stop_on_name=False, prefix_child=None, stop_on_name=False, prefix_child=None,
scan_ops=None): scan_ops=None, profile=None):
"""Print the graph leading to `r` to given depth. """Print the graph leading to `r` to given depth.
:param r: Variable instance :param r: Variable instance
...@@ -585,22 +585,55 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -585,22 +585,55 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
already_printed = a in done # get_id_str put it in the dict already_printed = a in done # get_id_str put it in the dict
id_str = get_id_str(a) id_str = get_id_str(a)
if len(a.outputs) == 1: if profile == None:
print >> file, '%s%s %s%s \'%s\' %s %s %s' % (prefix, a.op, if len(a.outputs) == 1:
id_str, print >> file, '%s%s %s%s \'%s\' %s %s %s' % (prefix, a.op,
type_str, id_str,
r_name, type_str,
destroy_map_str, r_name,
view_map_str, destroy_map_str,
o) view_map_str,
o)
else:
print >> file, '%s%s.%i %s%s \'%s\' %s %s %s' % (prefix, a.op,
a.outputs.index(r),
id_str, type_str,
r_name,
destroy_map_str,
view_map_str,
o)
else: else:
print >> file, '%s%s.%i %s%s \'%s\' %s %s %s' % (prefix, a.op, op_time = profile.apply_time[a]
a.outputs.index(r), op_time_percent = (op_time / profile.fct_call_time) * 100
id_str, type_str, tot_time_dict = profile.compute_total_times()
r_name, tot_time = tot_time_dict[a]
destroy_map_str, tot_time_percent = (tot_time_dict[a] / profile.fct_call_time) * 100
view_map_str,
o) if len(a.outputs) == 1:
print >> file, '%s%s %s%s \'%s\' %s %s %s --> %8.2es %4.1f%% %8.2es %4.1f%%'\
% (prefix, a.op,
id_str,
type_str,
r_name,
destroy_map_str,
view_map_str,
o, op_time,
op_time_percent,
tot_time,
tot_time_percent)
else:
print >> file, '%s%s.%i %s%s \'%s\' %s %s %s --> %8.2es %4.1f%% %8.2es %4.1f%%'\
% (prefix, a.op,
a.outputs.index(r),
id_str, type_str,
r_name,
destroy_map_str,
view_map_str,
o, op_time,
op_time_percent,
tot_time,
tot_time_percent)
if not already_printed: if not already_printed:
if (not stop_on_name or if (not stop_on_name or
not (hasattr(r, 'name') and r.name is not None)): not (hasattr(r, 'name') and r.name is not None)):
...@@ -618,7 +651,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -618,7 +651,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
debugprint(i, new_prefix, depth=depth - 1, done=done, debugprint(i, new_prefix, depth=depth - 1, done=done,
print_type=print_type, file=file, order=order, print_type=print_type, file=file, order=order,
ids=ids, stop_on_name=stop_on_name, ids=ids, stop_on_name=stop_on_name,
prefix_child=new_prefix_child, scan_ops=scan_ops) prefix_child=new_prefix_child, scan_ops=scan_ops,
profile=profile)
else: else:
#this is an input variable #this is an input variable
......
...@@ -294,6 +294,25 @@ class ProfileStats(object): ...@@ -294,6 +294,25 @@ class ProfileStats(object):
rval[node.op] += t rval[node.op] += t
return rval return rval
def fill_node_total_time(self, node, total_times):
    """Store in `total_times` the cumulative time of `node`.

    The cumulative time of a node is its own entry in ``self.apply_time``
    plus the cumulative time of the owner of each of its parents, for
    every owner that has an entry in ``self.apply_time``. Ancestors that
    are missing from `total_times` are computed and stored along the way.

    An ancestor reachable through several parents is added once per
    parent, so the result may over-count and is an upper bound rather
    than an exact figure.

    :param node: node to compute; must be a key of ``self.apply_time``
    :param total_times: dict node -> cumulative time, updated in place

    Returns nothing.
    """
    apply_time = self.apply_time
    # An explicit stack replaces recursion so that long chains of nodes
    # cannot exhaust the interpreter's recursion limit.
    # NOTE(review): assumes the graph is acyclic -- confirm with callers.
    stack = [node]
    while stack:
        current = stack[-1]
        # Membership is tested on the dicts directly: `.keys()` builds a
        # list under Python 2, which made every test a linear scan.
        timed_owners = [parent.owner for parent in current.get_parents()
                        if parent.owner in apply_time]
        pending = [owner for owner in timed_owners
                   if owner not in total_times]
        if pending:
            # Resolve the ancestors first, then revisit `current`.
            stack.extend(pending)
            continue
        stack.pop()
        # Accumulate parent by parent, in parent order, so the floating
        # point result matches a straightforward running sum.
        total = apply_time[current]
        for owner in timed_owners:
            total += total_times[owner]
        total_times[current] = total
def compute_total_times(self):
    """Return a dict node -> total time including the time of its parents.

    Every node that has an entry in ``self.apply_time`` gets an entry in
    the result. The per-node work is delegated to
    ``self.fill_node_total_time``, which may also fill in other nodes as
    a by-product; nodes already filled that way are not recomputed.
    """
    totals = {}
    for apply_node in self.apply_time:
        if apply_node in totals:
            # Already filled in while handling another node.
            continue
        self.fill_node_total_time(apply_node, totals)
    return totals
def op_callcount(self): def op_callcount(self):
"""dict op -> total number of thunk calls""" """dict op -> total number of thunk calls"""
# timing is stored by node, we compute timing by Op on demand # timing is stored by node, we compute timing by Op on demand
......
...@@ -81,6 +81,7 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -81,6 +81,7 @@ def debugprint(obj, depth=-1, print_type=False,
_file = file _file = file
done = dict() done = dict()
results_to_print = [] results_to_print = []
profile_list = []
order = [] order = []
if isinstance(obj, (list, tuple)): if isinstance(obj, (list, tuple)):
lobj = obj lobj = obj
...@@ -89,32 +90,57 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -89,32 +90,57 @@ def debugprint(obj, depth=-1, print_type=False,
for obj in lobj: for obj in lobj:
if isinstance(obj, gof.Variable): if isinstance(obj, gof.Variable):
results_to_print.append(obj) results_to_print.append(obj)
profile_list.append(None)
elif isinstance(obj, gof.Apply): elif isinstance(obj, gof.Apply):
results_to_print.extend(obj.outputs) results_to_print.extend(obj.outputs)
profile_list.extend([None for item in obj.outputs])
elif isinstance(obj, Function): elif isinstance(obj, Function):
results_to_print.extend(obj.maker.fgraph.outputs) results_to_print.extend(obj.maker.fgraph.outputs)
profile_list.extend([obj.profile for item in obj.maker.fgraph.outputs])
order = obj.maker.fgraph.toposort() order = obj.maker.fgraph.toposort()
elif isinstance(obj, gof.FunctionGraph): elif isinstance(obj, gof.FunctionGraph):
results_to_print.extend(obj.outputs) results_to_print.extend(obj.outputs)
profile_list.extend([None for item in obj.outputs])
order = obj.toposort() order = obj.toposort()
elif isinstance(obj, (int, long, float, numpy.ndarray)): elif isinstance(obj, (int, long, float, numpy.ndarray)):
print obj print obj
elif isinstance(obj, (theano.In, theano.Out)): elif isinstance(obj, (theano.In, theano.Out)):
results_to_print.append(obj.variable) results_to_print.append(obj.variable)
profile_list.append(None)
else: else:
raise TypeError("debugprint cannot print an object of this type", raise TypeError("debugprint cannot print an object of this type",
obj) obj)
scan_ops = [] scan_ops = []
for r in results_to_print: for r, p in zip(results_to_print, profile_list):
# Add the parent scan op to the list as well # Add the parent scan op to the list as well
if (hasattr(r.owner, 'op') and if (hasattr(r.owner, 'op') and
isinstance(r.owner.op, theano.scan_module.scan_op.Scan)): isinstance(r.owner.op, theano.scan_module.scan_op.Scan)):
scan_ops.append(r) scan_ops.append(r)
if p != None:
print >> file, """
Timing Info
-----------
--> <time> <% time> - <total time> <% total time>'
<time> computation time for this node
<% time> fraction of total computation time for this node
<total time> time for this node + total times for this node's ancestors
<% total time> total time for this node over total computation time
N.B.:
* Times include the node time and the function overhead.
* <total time> and <% total time> may over-count computation times
if inputs to a node share a common ancestor and should be viewed as a
loose upper bound. Their intended use is to help rule out potential nodes
to remove when optimizing a graph because their <total time> is very low.
"""
debugmode.debugprint(r, depth=depth, done=done, print_type=print_type, debugmode.debugprint(r, depth=depth, done=done, print_type=print_type,
file=_file, order=order, ids=ids, file=_file, order=order, ids=ids,
scan_ops=scan_ops, stop_on_name=stop_on_name) scan_ops=scan_ops, stop_on_name=stop_on_name,
profile=p)
if len(scan_ops) > 0: if len(scan_ops) > 0:
print >> file, "" print >> file, ""
new_prefix = ' >' new_prefix = ' >'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论