提交 233eb68e authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

timing information added

上级 aef67878
...@@ -495,7 +495,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -495,7 +495,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
file=sys.stdout, print_destroy_map=False, file=sys.stdout, print_destroy_map=False,
print_view_map=False, order=None, ids='CHAR', print_view_map=False, order=None, ids='CHAR',
stop_on_name=False, prefix_child=None, stop_on_name=False, prefix_child=None,
scan_ops=None): scan_ops=None, profile=None):
"""Print the graph leading to `r` to given depth. """Print the graph leading to `r` to given depth.
:param r: Variable instance :param r: Variable instance
...@@ -585,22 +585,55 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -585,22 +585,55 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
already_printed = a in done # get_id_str put it in the dict already_printed = a in done # get_id_str put it in the dict
id_str = get_id_str(a) id_str = get_id_str(a)
if len(a.outputs) == 1: if profile == None:
print >> file, '%s%s %s%s \'%s\' %s %s %s' % (prefix, a.op, if len(a.outputs) == 1:
id_str, print >> file, '%s%s %s%s \'%s\' %s %s %s' % (prefix, a.op,
type_str, id_str,
r_name, type_str,
destroy_map_str, r_name,
view_map_str, destroy_map_str,
o) view_map_str,
o)
else:
print >> file, '%s%s.%i %s%s \'%s\' %s %s %s' % (prefix, a.op,
a.outputs.index(r),
id_str, type_str,
r_name,
destroy_map_str,
view_map_str,
o)
else: else:
print >> file, '%s%s.%i %s%s \'%s\' %s %s %s' % (prefix, a.op, op_time = profile.apply_time[a]
a.outputs.index(r), op_time_percent = (op_time / profile.fct_call_time) * 100
id_str, type_str, tot_time_dict = profile.compute_total_times()
r_name, tot_time = tot_time_dict[a]
destroy_map_str, tot_time_percent = (tot_time / max(tot_time_dict.values())) * 100
view_map_str,
o) if len(a.outputs) == 1:
print >> file, '%s%s %s%s \'%s\' %s %s %s --> %8.2es %4.1f%% %8.2es %4.1f%%'\
% (prefix, a.op,
id_str,
type_str,
r_name,
destroy_map_str,
view_map_str,
o, op_time,
op_time_percent,
tot_time,
tot_time_percent)
else:
print >> file, '%s%s.%i %s%s \'%s\' %s %s %s --> %8.2es %4.1f%% %8.2es %4.1f%%'\
% (prefix, a.op,
a.outputs.index(r),
id_str, type_str,
r_name,
destroy_map_str,
view_map_str,
o, op_time,
op_time_percent,
tot_time,
tot_time_percent)
if not already_printed: if not already_printed:
if (not stop_on_name or if (not stop_on_name or
not (hasattr(r, 'name') and r.name is not None)): not (hasattr(r, 'name') and r.name is not None)):
...@@ -618,7 +651,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -618,7 +651,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
debugprint(i, new_prefix, depth=depth - 1, done=done, debugprint(i, new_prefix, depth=depth - 1, done=done,
print_type=print_type, file=file, order=order, print_type=print_type, file=file, order=order,
ids=ids, stop_on_name=stop_on_name, ids=ids, stop_on_name=stop_on_name,
prefix_child=new_prefix_child, scan_ops=scan_ops) prefix_child=new_prefix_child, scan_ops=scan_ops,
profile=profile)
else: else:
#this is an input variable #this is an input variable
......
...@@ -294,6 +294,19 @@ class ProfileStats(object): ...@@ -294,6 +294,19 @@ class ProfileStats(object):
rval[node.op] += t rval[node.op] += t
return rval return rval
def get_node_total_time(self, node):
total = self.apply_time[node]
for parent in node.get_parents():
if parent.owner in self.apply_time.keys():
total += self.get_node_total_time(parent.owner)
return total
def compute_total_times(self):
rval = {}
for node in self.apply_time.keys():
rval[node] = self.get_node_total_time(node)
return rval
def op_callcount(self): def op_callcount(self):
"""dict op -> total number of thunk calls""" """dict op -> total number of thunk calls"""
# timing is stored by node, we compute timing by Op on demand # timing is stored by node, we compute timing by Op on demand
......
...@@ -35,7 +35,7 @@ _logger = logging.getLogger("theano.printing") ...@@ -35,7 +35,7 @@ _logger = logging.getLogger("theano.printing")
def debugprint(obj, depth=-1, print_type=False, def debugprint(obj, depth=-1, print_type=False,
file=None, ids='CHAR', stop_on_name=False): file=None, ids='CHAR', stop_on_name=False, profile=None):
"""Print a computation graph as text to stdout or a file. """Print a computation graph as text to stdout or a file.
:type obj: Variable, Apply, or Function instance :type obj: Variable, Apply, or Function instance
...@@ -94,6 +94,10 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -94,6 +94,10 @@ def debugprint(obj, depth=-1, print_type=False,
elif isinstance(obj, Function): elif isinstance(obj, Function):
results_to_print.extend(obj.maker.fgraph.outputs) results_to_print.extend(obj.maker.fgraph.outputs)
order = obj.maker.fgraph.toposort() order = obj.maker.fgraph.toposort()
profile=obj.profile
if profile != None:
print 'Timing Info\n-----------\n\t \
--> <time> <% time> - <total time> <% total time>'
elif isinstance(obj, gof.FunctionGraph): elif isinstance(obj, gof.FunctionGraph):
results_to_print.extend(obj.outputs) results_to_print.extend(obj.outputs)
order = obj.toposort() order = obj.toposort()
...@@ -114,7 +118,8 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -114,7 +118,8 @@ def debugprint(obj, depth=-1, print_type=False,
debugmode.debugprint(r, depth=depth, done=done, print_type=print_type, debugmode.debugprint(r, depth=depth, done=done, print_type=print_type,
file=_file, order=order, ids=ids, file=_file, order=order, ids=ids,
scan_ops=scan_ops, stop_on_name=stop_on_name) scan_ops=scan_ops, stop_on_name=stop_on_name,
profile=profile)
if len(scan_ops) > 0: if len(scan_ops) > 0:
print >> file, "" print >> file, ""
new_prefix = ' >' new_prefix = ' >'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论