提交 7d4616b9 authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

computing total time for children tree is now more efficient

上级 74ee0d01
...@@ -294,17 +294,20 @@ class ProfileStats(object): ...@@ -294,17 +294,20 @@ class ProfileStats(object):
rval[node.op] += t rval[node.op] += t
return rval return rval
def get_node_total_time(self, node): def get_node_total_time(self, node, total_times):
total = self.apply_time[node] if node not in total_times.keys():
for parent in node.get_parents(): total = self.apply_time[node]
if parent.owner in self.apply_time.keys(): for parent in node.get_parents():
total += self.get_node_total_time(parent.owner) if parent.owner in self.apply_time.keys():
return total total += self.get_node_total_time(parent.owner)
return total
else:
return total_times[node]
def compute_total_times(self): def compute_total_times(self):
rval = {} rval = {}
for node in self.apply_time.keys(): for node in self.apply_time.keys():
rval[node] = self.get_node_total_time(node) rval[node] = self.get_node_total_time(node, rval)
return rval return rval
def op_callcount(self): def op_callcount(self):
......
...@@ -115,21 +115,20 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -115,21 +115,20 @@ def debugprint(obj, depth=-1, print_type=False,
obj) obj)
scan_ops = [] scan_ops = []
for r in results_to_print: for r, p in zip(results_to_print, profile_list):
# Add the parent scan op to the list as well # Add the parent scan op to the list as well
if (hasattr(r.owner, 'op') and if (hasattr(r.owner, 'op') and
isinstance(r.owner.op, theano.scan_module.scan_op.Scan)): isinstance(r.owner.op, theano.scan_module.scan_op.Scan)):
scan_ops.append(r) scan_ops.append(r)
profile = obj.profile if p != None:
if profile != None:
print 'Timing Info\n-----------\n\t \ print 'Timing Info\n-----------\n\t \
--> <time> <% time> - <total time> <% total time>' --> <time> <% time> - <total time> <% total time>'
debugmode.debugprint(r, depth=depth, done=done, print_type=print_type, debugmode.debugprint(r, depth=depth, done=done, print_type=print_type,
file=_file, order=order, ids=ids, file=_file, order=order, ids=ids,
scan_ops=scan_ops, stop_on_name=stop_on_name, scan_ops=scan_ops, stop_on_name=stop_on_name,
profile=profile) profile=p)
if len(scan_ops) > 0: if len(scan_ops) > 0:
print >> file, "" print >> file, ""
new_prefix = ' >' new_prefix = ' >'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论