提交 441f01cf authored 作者: Roy Xue's avatar Roy Xue

updates

上级 5fb8f003
......@@ -697,14 +697,14 @@ class ProfileStats(object):
order_index = 0
order = []
node_list = list(node_list)
min_mem = 0
min_mem = sys.maxint
current_mem = 0
compute_map = defaultdict(lambda: [0])
# compute_map use to check if a node is valid
for node in fgraph.inputs:
compute_map[node][0] = 1
for node in node_list:
for val in node.inputs:
compute_map[val][0] = 1
def check_node_state(node):
"""
......@@ -715,15 +715,15 @@ class ProfileStats(object):
inputs = node.inputs
outputs = node.outputs
deps = inputs + node.destroy_dependencies
computed_ins = all(compute_map[v][0] for v in deps)
computed_ins = all(compute_map[v][0] for v in inputs)
computed_outs = all(compute_map[v][0] for v in outputs)
# check if there could be a compute_map
if computed_ins and not computed_outs:
if computed_ins and not computed_outs:
return True
else:
return False
def min_memory_generator(node_list, compute_map):
def min_memory_generator(node_list):
'''
enumerate all valid orders for the list of nodes in node_list
compute the peak of all order and keep the order with the minimum peak.
......@@ -733,7 +733,6 @@ class ProfileStats(object):
:param compute_map: simulate the node execution steps to update compute_map
'''
for i in range(len(node_list)):
v = node_list[i:i+1]
if check_node_state(v[0]):
......@@ -743,14 +742,14 @@ class ProfileStats(object):
yield v
else:
rest = node_list[ :i] + node_list[i+1: ]
for p in min_memory_generator(rest, compute_map):
for p in min_memory_generator(rest):
yield v+p
for node in v[0].outputs:
compute_map[node][0] = 0
for node in v[0].outputs:
compute_map[node][0] = 0
min_order = []
for order in min_memory_generator(node_list, compute_map):
for order in min_memory_generator(node_list):
post_thunk_old_storage = []
computed, last_user = theano.gof.link.gc_helper(order)
for node in order:
......@@ -825,7 +824,6 @@ class ProfileStats(object):
del fgraph, nodes_mem, post_thunk_old_storage, node
if len(fct_memory) > 1:
print >> file, ("Memory Profile "
......@@ -846,7 +844,7 @@ class ProfileStats(object):
print >> file, " Max if linker=cvm(default): %dKB (%dKB)" % (int(round(
new_max_running_max_memory_size / 1024.)), int(round(
max_running_max_memory_size / 1024.)))
print >> file, " Minimum peak from all valid apply node order is %dKB" % int(round(minimum_peak / 1024.))
print >> file, " Minimum peak from all valid apply node order is %dKB" % int(round(max_minimum_peak / 1024.))
print >> file, " Memory saved if views are used: %dKB (%dKB)" % (int(
round(new_max_node_memory_saved_by_view / 1024.)), int(
round(max_node_memory_saved_by_view / 1024.)))
......@@ -924,7 +922,6 @@ class ProfileStats(object):
" emitted in those cases.")
print >> file, ''
print >> file, " The minimum peak from all valid apply node order is %dKB" % int(round(minimum_peak / 1024.))
def summary(self, file=sys.stderr, n_ops_to_print=20,
n_apply_to_print=20):
......
......@@ -41,7 +41,9 @@ def test_profiling():
for line in buf.getvalue().split("\n"):
if "Max if linker=cvm" in line:
print line
elif "The minimum peak from all valid apply node" in line:
elif "Minimum peak from all valid apply node" in line:
print line
elif "order" in line:
print line
finally:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论