提交 503ef8db authored 作者: Roy Xue's avatar Roy Xue

Updates

1. Simulate execution step in generator. 2. Use new memory usage counting method
上级 a8668275
...@@ -692,7 +692,7 @@ class ProfileStats(object): ...@@ -692,7 +692,7 @@ class ProfileStats(object):
max_minimum_peak = 0 max_minimum_peak = 0
def count_minimum_peak(node_list, fgraph, nodes_mem): def count_minimum_peak(node_list, fgraph, nodes_mem):
global min_mem, current_mem global min_mem, current_mem, compute_map
mem_list = [] mem_list = []
order_index = 0 order_index = 0
order = [] order = []
...@@ -702,6 +702,9 @@ class ProfileStats(object): ...@@ -702,6 +702,9 @@ class ProfileStats(object):
compute_map = fgraph.profile.compute_map compute_map = fgraph.profile.compute_map
# compute_map use to check if a node is valid # compute_map use to check if a node is valid
for node in node_list:
for v in node.outputs:
compute_map[v][0] = 0
def check_node_state(node): def check_node_state(node):
""" """
...@@ -715,12 +718,12 @@ class ProfileStats(object): ...@@ -715,12 +718,12 @@ class ProfileStats(object):
computed_ins = all(compute_map[v][0] for v in deps) computed_ins = all(compute_map[v][0] for v in deps)
computed_outs = all(compute_map[v][0] for v in outputs) computed_outs = all(compute_map[v][0] for v in outputs)
# check if there could be a compute_map # check if there could be a compute_map
if computed_ins: if computed_ins and not computed_outs:
return True return True
else: else:
return False return False
def min_memory_generator(node_list, b=False): def min_memory_generator(node_list):
global current_mem, min_mem global current_mem, min_mem
''' '''
enumerate all valid order( node with inputs in its compute_map) enumerate all valid order( node with inputs in its compute_map)
...@@ -735,23 +738,29 @@ class ProfileStats(object): ...@@ -735,23 +738,29 @@ class ProfileStats(object):
if check_node_state(v[0]): if check_node_state(v[0]):
if len(node_list) == 1: if len(node_list) == 1:
yield v yield v
current_mem += sum(nodes_mem[v[0]]) for i in v[0].outputs:
b = True compute_map[i][0] = 1
# current_mem += sum(nodes_mem[v[0]])
else: else:
b = False
rest = node_list[ :i] + node_list[i+1: ] rest = node_list[ :i] + node_list[i+1: ]
for p in min_memory_generator(rest): for p in min_memory_generator(rest):
yield v+p yield v+p
current_mem += sum(nodes_mem[v[0]]) for i in v[0].outputs:
if b: compute_map[i][0] = 1
if current_mem != 0: # current_mem += sum(nodes_mem[v[0]])
mem_list.append(current_mem) # we would use the count_running_memory to calculate the memory usage
if not min_mem:
min_mem = current_mem # if len(node_list) == 1:
if current_mem < min_mem: # if current_mem != 0:
min_mem = current_mem # mem_list.append(current_mem)
order_index = mem_list.index(current_mem) # if not min_mem:
current_mem = 0 # min_mem = current_mem
# # intial the min_mem with current_mem,
# # for this step, order_index = 0
# if current_mem < min_mem:
# min_mem = current_mem
# order_index = mem_list.index(current_mem)
# current_mem = 0
...@@ -760,6 +769,22 @@ class ProfileStats(object): ...@@ -760,6 +769,22 @@ class ProfileStats(object):
for i in min_memory_generator(node_list): for i in min_memory_generator(node_list):
temp.append(i) temp.append(i)
for order in temp:
post_thunk_old_storage = []
computed, last_user = theano.gof.link.gc_helper(order)
for node in order:
post_thunk_old_storage.append([
input_idx
for input_idx, input in enumerate(node.inputs)
if (input in computed) and
(input not in fgraph.outputs) and
node == last_user[input]])
current_mem = count_running_memory(order, post_thunk_old_storage, nodes_mem)[2]
if current_mem < min_mem:
min_mem = current_mem
order_index = temp.index(order)
order = temp[order_index] order = temp[order_index]
return order, min_mem return order, min_mem
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论