提交 a97ad2b6 authored 作者: Roy Xue's avatar Roy Xue

working version:)

上级 1c2f8c8d
......@@ -694,12 +694,13 @@ class ProfileStats(object):
max_minimum_peak = 0
def count_minimum_peak(node_list, fgraph, nodes_mem):
global maybe_executed, mem_count, mem_bound
global maybe_executed, mem_count, mem_bound, max_mem_count
order = []
min_order = []
node_list = list(node_list)
current_mem = 0
mem_count = 0
max_mem_count = 0
mem_bound = numpy.inf
def check_node_state(node):
......@@ -743,7 +744,7 @@ class ProfileStats(object):
# dependencies[val] += ls
def min_memory_generator(executables_nodes):
global mem_count, mem_bound
global mem_count, mem_bound, max_mem_count, max_mem_count
for node in executables_nodes:
new_exec_nodes = executables_nodes.copy()
new_exec_nodes.remove(node)
......@@ -771,6 +772,10 @@ class ProfileStats(object):
mem_created += i
idx += 1
mem_count += mem_created
if mem_count > max_mem_count:
max_mem_count = mem_count
#add mem_freed, this part is not working well
for val in node.inputs:
if (dependencies[val] and val.owner and val not in fgraph.outputs):
......@@ -778,12 +783,10 @@ class ProfileStats(object):
if not dmap and not vmap:
mem_freed += var_mem[val]
mem_count += mem_created
mem_count -= mem_freed
# check if cut path now
for var in node.outputs:
for c, _ in var.clients:
if c != "output" and check_node_state(c):
......@@ -792,14 +795,15 @@ class ProfileStats(object):
if not new_exec_nodes:
yield [node]
#update mem_bound
if mem_count < mem_bound:
mem_bound = mem_count
if max_mem_count < mem_bound:
mem_bound = max_mem_count
else:
for p in min_memory_generator(new_exec_nodes):
yield [node]+p
# resetting part
mem_count -= mem_created
max_mem_count -= mem_created
mem_count += mem_freed
for var in node.outputs:
compute_map[var][0] = 0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论