提交 ee876bb1 authored 作者: Roy Xue's avatar Roy Xue

Updates

1. use cpmputed_map as a argument for min_memory_generator. however which it will works remains doubt, the question I will email Fred, and Arnaud
上级 503ef8db
......@@ -692,7 +692,7 @@ class ProfileStats(object):
max_minimum_peak = 0
def count_minimum_peak(node_list, fgraph, nodes_mem):
global min_mem, current_mem, compute_map
global min_mem, current_mem
mem_list = []
order_index = 0
order = []
......@@ -701,6 +701,7 @@ class ProfileStats(object):
current_mem = 0
compute_map = fgraph.profile.compute_map
# compute_map use to check if a node is valid
for node in node_list:
for v in node.outputs:
......@@ -723,7 +724,7 @@ class ProfileStats(object):
else:
return False
def min_memory_generator(node_list):
def min_memory_generator(node_list, compute_map):
global current_mem, min_mem
'''
enumerate all valid order( node with inputs in its compute_map)
......@@ -731,6 +732,8 @@ class ProfileStats(object):
return an order with minimum memory usage
:param node_list: a list of apply nodes
:param compute_map: simulate the node execution steps to update compute_map
'''
for i in range(len(node_list)):
......@@ -740,33 +743,16 @@ class ProfileStats(object):
yield v
for i in v[0].outputs:
compute_map[i][0] = 1
# current_mem += sum(nodes_mem[v[0]])
else:
rest = node_list[ :i] + node_list[i+1: ]
for p in min_memory_generator(rest):
yield v+p
for i in v[0].outputs:
compute_map[i][0] = 1
# current_mem += sum(nodes_mem[v[0]])
# we would use the count_running_memory to calculate the memory usage
# if len(node_list) == 1:
# if current_mem != 0:
# mem_list.append(current_mem)
# if not min_mem:
# min_mem = current_mem
# # intial the min_mem with current_mem,
# # for this step, order_index = 0
# if current_mem < min_mem:
# min_mem = current_mem
# order_index = mem_list.index(current_mem)
# current_mem = 0
temp = []
for i in min_memory_generator(node_list):
for i in min_memory_generator(node_list, compute_map):
temp.append(i)
for order in temp:
......@@ -784,7 +770,7 @@ class ProfileStats(object):
if current_mem < min_mem:
min_mem = current_mem
order_index = temp.index(order)
order = temp[order_index]
return order, min_mem
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论