提交 5ec8820f authored 作者: Roy Xue's avatar Roy Xue

Algo updates

Runtime: 2-3 mins
上级 965e5335
...@@ -749,6 +749,10 @@ class ProfileStats(object): ...@@ -749,6 +749,10 @@ class ProfileStats(object):
for node in executable_nodes: for node in executable_nodes:
new_exec_nodes = executable_nodes.copy() new_exec_nodes = executable_nodes.copy()
new_exec_nodes.remove(node) new_exec_nodes.remove(node)
if max_mem_count > mem_bound:
continue
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 1 compute_map[var][0] = 1
...@@ -758,13 +762,11 @@ class ProfileStats(object): ...@@ -758,13 +762,11 @@ class ProfileStats(object):
# {var1:[vars that view var1]} # {var1:[vars that view var1]}
viewed_by = {} viewed_by = {}
for i in node.inputs:
viewed_by[i] = []
# {var1: original var viewed by var1} # {var1: original var viewed by var1}
view_of = {} view_of = {}
# check if we cut path now
if max_mem_count > mem_bound:
continue
dmap = getattr(node.op, 'destroy_map', None) dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None) vmap = getattr(node.op, 'view_map', None)
...@@ -772,10 +774,7 @@ class ProfileStats(object): ...@@ -772,10 +774,7 @@ class ProfileStats(object):
if (dmap or vmap): if (dmap or vmap):
for i in node.inputs: for i in node.inputs:
view_of[out] = view_of.get(i, i) view_of[out] = view_of.get(i, i)
if i in viewed_by: viewed_by[i].append(out)
viewed_by[i].append[out]
else:
viewed_by[i] = [out]
else: else:
mem_created += var_mem[out] mem_created += var_mem[out]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论