提交 2aca2e14 authored 作者: Roy Xue's avatar Roy Xue

updates

1. remove global variables 2. modify discription 3. use new way to track order
上级 c84d0261
...@@ -692,7 +692,6 @@ class ProfileStats(object): ...@@ -692,7 +692,6 @@ class ProfileStats(object):
max_minimum_peak = 0 max_minimum_peak = 0
def count_minimum_peak(node_list, fgraph, nodes_mem): def count_minimum_peak(node_list, fgraph, nodes_mem):
global min_mem, current_mem
mem_list = [] mem_list = []
order_index = 0 order_index = 0
order = [] order = []
...@@ -725,12 +724,11 @@ class ProfileStats(object): ...@@ -725,12 +724,11 @@ class ProfileStats(object):
return False return False
def min_memory_generator(node_list, compute_map): def min_memory_generator(node_list, compute_map):
global current_mem, min_mem
''' '''
enumerate all valid order( node with inputs in its compute_map) enumerate all valid orders for the list of nodes in node_list
compute the peak of all order and keep the order with the minimum peak. compute the peak of all order and keep the order with the minimum peak.
return an order with minimum memory usage return minimum memory usage
:param node_list: a list of apply nodes :param node_list: a list of apply nodes
:param compute_map: simulate the node execution steps to update compute_map :param compute_map: simulate the node execution steps to update compute_map
...@@ -752,6 +750,7 @@ class ProfileStats(object): ...@@ -752,6 +750,7 @@ class ProfileStats(object):
temp = [] temp = []
min_order = []
for order in min_memory_generator(node_list, compute_map): for order in min_memory_generator(node_list, compute_map):
temp.append(order) temp.append(order)
...@@ -767,14 +766,13 @@ class ProfileStats(object): ...@@ -767,14 +766,13 @@ class ProfileStats(object):
(input not in fgraph.outputs) and (input not in fgraph.outputs) and
node == last_user[input]]) node == last_user[input]])
current_mem = count_running_memory(order, post_thunk_old_storage, nodes_mem)[2] current_mem = count_running_memory(order, post_thunk_old_storage, nodes_mem)[2]
current_order = order
if current_mem < min_mem: if current_mem < min_mem:
min_mem = current_mem min_mem = current_mem
order_index = temp.index(order) min_order = current_order
order = temp[order_index]
return order, min_mem return min_order, min_mem
for fgraph, nodes_mem in fct_memory.iteritems(): for fgraph, nodes_mem in fct_memory.iteritems():
# Sum of the size of all variables in bytes # Sum of the size of all variables in bytes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论