提交 441f01cf authored 作者: Roy Xue's avatar Roy Xue

updates

上级 5fb8f003
...@@ -697,14 +697,14 @@ class ProfileStats(object): ...@@ -697,14 +697,14 @@ class ProfileStats(object):
order_index = 0 order_index = 0
order = [] order = []
node_list = list(node_list) node_list = list(node_list)
min_mem = 0 min_mem = sys.maxint
current_mem = 0 current_mem = 0
compute_map = defaultdict(lambda: [0]) compute_map = defaultdict(lambda: [0])
# compute_map use to check if a node is valid # compute_map use to check if a node is valid
for node in fgraph.inputs: for node in node_list:
compute_map[node][0] = 1 for val in node.inputs:
compute_map[val][0] = 1
def check_node_state(node): def check_node_state(node):
""" """
...@@ -715,7 +715,7 @@ class ProfileStats(object): ...@@ -715,7 +715,7 @@ class ProfileStats(object):
inputs = node.inputs inputs = node.inputs
outputs = node.outputs outputs = node.outputs
deps = inputs + node.destroy_dependencies deps = inputs + node.destroy_dependencies
computed_ins = all(compute_map[v][0] for v in deps) computed_ins = all(compute_map[v][0] for v in inputs)
computed_outs = all(compute_map[v][0] for v in outputs) computed_outs = all(compute_map[v][0] for v in outputs)
# check if there could be a compute_map # check if there could be a compute_map
if computed_ins and not computed_outs: if computed_ins and not computed_outs:
...@@ -723,7 +723,7 @@ class ProfileStats(object): ...@@ -723,7 +723,7 @@ class ProfileStats(object):
else: else:
return False return False
def min_memory_generator(node_list, compute_map): def min_memory_generator(node_list):
''' '''
enumerate all valid orders for the list of nodes in node_list enumerate all valid orders for the list of nodes in node_list
compute the peak of all order and keep the order with the minimum peak. compute the peak of all order and keep the order with the minimum peak.
...@@ -733,7 +733,6 @@ class ProfileStats(object): ...@@ -733,7 +733,6 @@ class ProfileStats(object):
:param compute_map: simulate the node execution steps to update compute_map :param compute_map: simulate the node execution steps to update compute_map
''' '''
for i in range(len(node_list)): for i in range(len(node_list)):
v = node_list[i:i+1] v = node_list[i:i+1]
if check_node_state(v[0]): if check_node_state(v[0]):
...@@ -743,14 +742,14 @@ class ProfileStats(object): ...@@ -743,14 +742,14 @@ class ProfileStats(object):
yield v yield v
else: else:
rest = node_list[ :i] + node_list[i+1: ] rest = node_list[ :i] + node_list[i+1: ]
for p in min_memory_generator(rest, compute_map): for p in min_memory_generator(rest):
yield v+p yield v+p
for node in v[0].outputs: for node in v[0].outputs:
compute_map[node][0] = 0 compute_map[node][0] = 0
min_order = [] min_order = []
for order in min_memory_generator(node_list, compute_map): for order in min_memory_generator(node_list):
post_thunk_old_storage = [] post_thunk_old_storage = []
computed, last_user = theano.gof.link.gc_helper(order) computed, last_user = theano.gof.link.gc_helper(order)
for node in order: for node in order:
...@@ -826,7 +825,6 @@ class ProfileStats(object): ...@@ -826,7 +825,6 @@ class ProfileStats(object):
del fgraph, nodes_mem, post_thunk_old_storage, node del fgraph, nodes_mem, post_thunk_old_storage, node
if len(fct_memory) > 1: if len(fct_memory) > 1:
print >> file, ("Memory Profile " print >> file, ("Memory Profile "
"(the max between all functions in that profile)") "(the max between all functions in that profile)")
...@@ -846,7 +844,7 @@ class ProfileStats(object): ...@@ -846,7 +844,7 @@ class ProfileStats(object):
print >> file, " Max if linker=cvm(default): %dKB (%dKB)" % (int(round( print >> file, " Max if linker=cvm(default): %dKB (%dKB)" % (int(round(
new_max_running_max_memory_size / 1024.)), int(round( new_max_running_max_memory_size / 1024.)), int(round(
max_running_max_memory_size / 1024.))) max_running_max_memory_size / 1024.)))
print >> file, " Minimum peak from all valid apply node order is %dKB" % int(round(minimum_peak / 1024.)) print >> file, " Minimum peak from all valid apply node order is %dKB" % int(round(max_minimum_peak / 1024.))
print >> file, " Memory saved if views are used: %dKB (%dKB)" % (int( print >> file, " Memory saved if views are used: %dKB (%dKB)" % (int(
round(new_max_node_memory_saved_by_view / 1024.)), int( round(new_max_node_memory_saved_by_view / 1024.)), int(
round(max_node_memory_saved_by_view / 1024.))) round(max_node_memory_saved_by_view / 1024.)))
...@@ -924,7 +922,6 @@ class ProfileStats(object): ...@@ -924,7 +922,6 @@ class ProfileStats(object):
" emitted in those cases.") " emitted in those cases.")
print >> file, '' print >> file, ''
print >> file, " The minimum peak from all valid apply node order is %dKB" % int(round(minimum_peak / 1024.))
def summary(self, file=sys.stderr, n_ops_to_print=20, def summary(self, file=sys.stderr, n_ops_to_print=20,
n_apply_to_print=20): n_apply_to_print=20):
......
...@@ -41,7 +41,9 @@ def test_profiling(): ...@@ -41,7 +41,9 @@ def test_profiling():
for line in buf.getvalue().split("\n"): for line in buf.getvalue().split("\n"):
if "Max if linker=cvm" in line: if "Max if linker=cvm" in line:
print line print line
elif "The minimum peak from all valid apply node" in line: elif "Minimum peak from all valid apply node" in line:
print line
elif "order" in line:
print line print line
finally: finally:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论