提交 99696ef9 作者: Roy Xue

change to dependencies solution

上级 7c7a5f6d
...@@ -694,19 +694,13 @@ class ProfileStats(object): ...@@ -694,19 +694,13 @@ class ProfileStats(object):
max_minimum_peak = 0 max_minimum_peak = 0
def count_minimum_peak(node_list, fgraph, nodes_mem): def count_minimum_peak(node_list, fgraph, nodes_mem):
global maybe_executed, mem_count, mem_bound, last_user, freed global maybe_executed, mem_count, mem_bound
order = [] order = []
min_order = [] min_order = []
node_list = list(node_list) node_list = list(node_list)
min_mem = sys.maxint
current_mem = 0 current_mem = 0
mem_count = 0 mem_count = 0
mem_bound = numpy.inf mem_bound = numpy.inf
computed = set()
last_user = {}
for node in node_list:
for ins in node.inputs:
last_user[ins] = 0
def check_node_state(node): def check_node_state(node):
""" """
...@@ -738,14 +732,18 @@ class ProfileStats(object): ...@@ -738,14 +732,18 @@ class ProfileStats(object):
if c != "output" and check_node_state(c): if c != "output" and check_node_state(c):
executables_nodes.add(c) executables_nodes.add(c)
for node in node_list: dependencies = fgraph.profile.dependencies
for ins in node.inputs: # for node in node_list[0].inputs[0]:
last_user[ins] += 1 # dependencies[node] = []
for outs in node.outputs: # if val.owner and val.clients:
computed.add(outs) # ls = []
# for c in val.clients:
# if c[0] is not 'output':
# ls += c[0].outputs
# dependencies[val] += ls
def min_memory_generator(executables_nodes): def min_memory_generator(executables_nodes):
global mem_count, mem_bound, last_user, freed global mem_count, mem_bound
for node in executables_nodes: for node in executables_nodes:
new_exec_nodes = executables_nodes.copy() new_exec_nodes = executables_nodes.copy()
new_exec_nodes.remove(node) new_exec_nodes.remove(node)
...@@ -759,8 +757,6 @@ class ProfileStats(object): ...@@ -759,8 +757,6 @@ class ProfileStats(object):
idx = 0 idx = 0
dmap = getattr(node.op, 'destroy_map', None) dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None) vmap = getattr(node.op, 'view_map', None)
for ins in node.inputs:
last_user[ins] -= 1
# add mem_create # add mem_create
for i in nodes_mem[node]: for i in nodes_mem[node]:
...@@ -771,10 +767,11 @@ class ProfileStats(object): ...@@ -771,10 +767,11 @@ class ProfileStats(object):
idx += 1 idx += 1
#add mem_freed, this part is not working well #add mem_freed, this part is not working well
for ins in node.inputs: print type(node)
if (ins in computed) and (ins not in fgraph.outputs) and (last_user[ins] == 0): for val in node.inputs:
if not isinstance(var_mem[ins], str): if (dependencies[val] and val.owner and val not in fgraph.outputs):
mem_freed += var_mem[ins] if all(compute_map[v] for v in dependencies[val]):
mem_freed += var_mem[val]
mem_count += mem_created mem_count += mem_created
mem_count -= mem_freed mem_count -= mem_freed
...@@ -783,8 +780,6 @@ class ProfileStats(object): ...@@ -783,8 +780,6 @@ class ProfileStats(object):
if mem_count > mem_bound: if mem_count > mem_bound:
mem_count -= mem_created mem_count -= mem_created
mem_count += mem_freed mem_count += mem_freed
for ins in node.inputs:
last_user[ins] += 1
continue continue
for var in node.outputs: for var in node.outputs:
...@@ -804,8 +799,6 @@ class ProfileStats(object): ...@@ -804,8 +799,6 @@ class ProfileStats(object):
# resetting part # resetting part
mem_count -= mem_created mem_count -= mem_created
mem_count += mem_freed mem_count += mem_freed
for ins in node.inputs:
last_user[ins] += 1
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 0 compute_map[var][0] = 0
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论