提交 0a97ec6e 作者: Roy Xue

fix bugs

1. Define the global variables correctly (at module level). 2. Use fgraph.apply_nodes in place of the previous fgraph.nodes.
上级 ccf91498
...@@ -101,6 +101,9 @@ def _atexit_print_fn(): ...@@ -101,6 +101,9 @@ def _atexit_print_fn():
atexit.register(_atexit_print_fn) atexit.register(_atexit_print_fn)
current_mem = 0
min_mem = 0
# global variables used to store memory usage in generator
class ProfileStats(object): class ProfileStats(object):
""" """
...@@ -691,10 +694,6 @@ class ProfileStats(object): ...@@ -691,10 +694,6 @@ class ProfileStats(object):
return [node_memory_size, running_memory_size, running_max_memory_size, node_memory_saved_by_inplace, node_memory_saved_by_view] return [node_memory_size, running_memory_size, running_max_memory_size, node_memory_saved_by_inplace, node_memory_saved_by_view]
current_mem = 0
min_mem = 0
# variables used in count_minimum_peak
def count_minimum_peak(node_list, fgraph, nodes_mem): def count_minimum_peak(node_list, fgraph, nodes_mem):
mem_list = [] mem_list = []
order_index = 0 order_index = 0
...@@ -722,7 +721,7 @@ class ProfileStats(object): ...@@ -722,7 +721,7 @@ class ProfileStats(object):
return False return False
def min_memory_generator(node_list, b=False): def min_memory_generator(node_list, b=False):
global mem_list, current_mem, order_index, min_mem global current_mem, min_mem
''' '''
enumerate all valid order( node with inputs in its compute_map) enumerate all valid order( node with inputs in its compute_map)
compute the peak of all order and keep the order with the minimum peak. compute the peak of all order and keep the order with the minimum peak.
...@@ -736,14 +735,14 @@ class ProfileStats(object): ...@@ -736,14 +735,14 @@ class ProfileStats(object):
if check_node_state(v[0]): if check_node_state(v[0]):
if len(node_list) == 1: if len(node_list) == 1:
yield v yield v
current_mem += nodes_mem[v[0]] current_mem += sum(nodes_mem[v[0]])
b = True b = True
else: else:
b = False b = False
rest = node_list[ :i] + node_list[i+1: ] rest = node_list[ :i] + node_list[i+1: ]
for p in min_memory_generator(rest): for p in min_memory_generator(rest):
yield v+p yield v+p
current_mem += nodes_mem[v[0]] current_mem += sum(nodes_mem[v[0]])
if b: if b:
if current_mem != 0: if current_mem != 0:
mem_list.append(current_mem) mem_list.append(current_mem)
...@@ -814,7 +813,7 @@ class ProfileStats(object): ...@@ -814,7 +813,7 @@ class ProfileStats(object):
new_max_node_memory_saved_by_inplace, new_running_memory[3]) new_max_node_memory_saved_by_inplace, new_running_memory[3])
node_list = fgraph.nodes node_list = fgraph.apply_nodes
_, minimum_peak = count_minimum_peak(node_list, fgraph, nodes_mem) _, minimum_peak = count_minimum_peak(node_list, fgraph, nodes_mem)
# for the best order, we dont use it now # for the best order, we dont use it now
max_minimum_peak = max(max_minimum_peak, minimum_peak) max_minimum_peak = max(max_minimum_peak, minimum_peak)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论