提交 6d5e4ff2 authored 作者: Roy Xue's avatar Roy Xue

Updates and fix algo

上级 ca78868f
...@@ -740,7 +740,7 @@ class ProfileStats(object): ...@@ -740,7 +740,7 @@ class ProfileStats(object):
Generate all valid node order from node_list Generate all valid node order from node_list
and compute its memory peaf and compute its memory peaf
""" """
global mem_count, mem_bound, max_mem_count, max_mem_count global mem_count, mem_bound, max_mem_count
for node in executable_nodes: for node in executable_nodes:
new_exec_nodes = executable_nodes.copy() new_exec_nodes = executable_nodes.copy()
...@@ -749,6 +749,8 @@ class ProfileStats(object): ...@@ -749,6 +749,8 @@ class ProfileStats(object):
mem_created = 0 mem_created = 0
mem_freed = 0 mem_freed = 0
max_storage = max_mem_count
# check if we cut path now # check if we cut path now
if max_mem_count > mem_bound: if max_mem_count > mem_bound:
continue continue
...@@ -776,7 +778,6 @@ class ProfileStats(object): ...@@ -776,7 +778,6 @@ class ProfileStats(object):
for val in node.inputs: for val in node.inputs:
if (dependencies[val] and val.owner and val not in fgraph.outputs): if (dependencies[val] and val.owner and val not in fgraph.outputs):
if all(compute_map[v] for v in dependencies[val]): if all(compute_map[v] for v in dependencies[val]):
if not dmap and not vmap:
mem_freed += var_mem[val] mem_freed += var_mem[val]
mem_count -= mem_freed mem_count -= mem_freed
...@@ -797,7 +798,7 @@ class ProfileStats(object): ...@@ -797,7 +798,7 @@ class ProfileStats(object):
# Reset track variables # Reset track variables
mem_count -= mem_created mem_count -= mem_created
max_mem_count -= mem_created max_mem_count = max_storage
mem_count += mem_freed mem_count += mem_freed
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 0 compute_map[var][0] = 0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论