提交 695db885 authored 作者: Roy Xue's avatar Roy Xue

maybe a better way but still need to modify and test

上级 9515ee54
......@@ -740,6 +740,13 @@ class ProfileStats(object):
dependencies = fgraph.profile.dependencies
print dependencies
compute = set()
last_user = {}
for node in node_list:
for ins in node.inputs:
last_user[ins] = node
for outs in node.outputs:
computed.add(outs)
# for node in node_list[0].inputs[0]:
# dependencies[node] = []
# if val.owner and val.clients:
......@@ -755,19 +762,34 @@ class ProfileStats(object):
new_exec_nodes = executables_nodes.copy()
new_exec_nodes.remove(node)
mem_created = 0
idx = 0
dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None)
mem_created = sum(nodes_mem[node])
# for i in nodes_mem[node]:
# if (dmap and idx in dmap) or (vmap and idx in vmap):
# continue
# elif not isinstance(i, str):
# mem_created += i
# idx += 1
mem_count += mem_created
# Add memory created
# dependencies = {}
mem_freed = 0
for ins in node.inputs:
if (ins in computed) and (ins not in fgraph.outputs) and (node == last_user[ins]):
mem_freed += var_mem[ins]
# for val in node.outputs:
# if (dependencies[val] and val.owner and val not in fgraph.outputs):
# # print dependencies[val]
# # print compute_map
# for v in dependencies[val]:
# if compute_map[v] == 1:
# print "yes"
# mem_freed += var_mem[val]
for val in node.inputs:
if (dependencies[val] and val.owner and val not in fgraph.outputs):
if all(compute_map[v][0] for v in dependencies[val]):
mem_freed += var_mem[val]
# print mem_freed
mem_count -= mem_freed
# Reduce memory freed
if mem_count > mem_bound:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论