提交 0db91eb0 authored 作者: Roy Xue's avatar Roy Xue

Code Style Fixes

上级 5ec8820f
...@@ -743,6 +743,8 @@ class ProfileStats(object): ...@@ -743,6 +743,8 @@ class ProfileStats(object):
""" """
Generate all valid node order from node_list Generate all valid node order from node_list
and compute its memory peaf and compute its memory peaf
:param executable_nodes: Set of executable nodes
""" """
global mem_count, mem_bound, max_mem_count global mem_count, mem_bound, max_mem_count
...@@ -750,6 +752,7 @@ class ProfileStats(object): ...@@ -750,6 +752,7 @@ class ProfileStats(object):
new_exec_nodes = executable_nodes.copy() new_exec_nodes = executable_nodes.copy()
new_exec_nodes.remove(node) new_exec_nodes.remove(node)
# Check if cut path now
if max_mem_count > mem_bound: if max_mem_count > mem_bound:
continue continue
...@@ -766,7 +769,6 @@ class ProfileStats(object): ...@@ -766,7 +769,6 @@ class ProfileStats(object):
viewed_by[i] = [] viewed_by[i] = []
# {var1: original var viewed by var1} # {var1: original var viewed by var1}
view_of = {} view_of = {}
dmap = getattr(node.op, 'destroy_map', None) dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None) vmap = getattr(node.op, 'view_map', None)
...@@ -818,6 +820,7 @@ class ProfileStats(object): ...@@ -818,6 +820,7 @@ class ProfileStats(object):
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 0 compute_map[var][0] = 0
# Loop all valid orders and find min peak(store in mem_bound)
for order in min_memory_generator(executable_nodes): for order in min_memory_generator(executable_nodes):
continue continue
...@@ -871,6 +874,7 @@ class ProfileStats(object): ...@@ -871,6 +874,7 @@ class ProfileStats(object):
new_max_node_memory_saved_by_inplace = max( new_max_node_memory_saved_by_inplace = max(
new_max_node_memory_saved_by_inplace, new_running_memory[3]) new_max_node_memory_saved_by_inplace, new_running_memory[3])
# Config: whether print min memory peak
if config.profiling.min_peak_memory: if config.profiling.min_peak_memory:
node_list = fgraph.apply_nodes node_list = fgraph.apply_nodes
min_peak = count_minimum_peak(node_list, fgraph, nodes_mem) min_peak = count_minimum_peak(node_list, fgraph, nodes_mem)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论