提交 42e6a77b 作者: Roy Xue

fixes

上级 6879fb89
...@@ -661,6 +661,7 @@ class ProfileStats(object): ...@@ -661,6 +661,7 @@ class ProfileStats(object):
running_max_memory_size = 0 running_max_memory_size = 0
node_memory_saved_by_view = 0 node_memory_saved_by_view = 0
node_memory_saved_by_inplace = 0 node_memory_saved_by_inplace = 0
node_idx = 0
for node in order: for node in order:
val = nodes_mem[node] val = nodes_mem[node]
...@@ -680,11 +681,12 @@ class ProfileStats(object): ...@@ -680,11 +681,12 @@ class ProfileStats(object):
idx += 1 idx += 1
if running_memory_size > running_max_memory_size: if running_memory_size > running_max_memory_size:
running_max_memory_size = running_memory_size running_max_memory_size = running_memory_size
old_storage = thunk_old_storage[order.index(node)] old_storage = thunk_old_storage[node_idx]
for old_s in old_storage: for old_s in old_storage:
old_v = var_mem[node.inputs[old_s]] old_v = var_mem[node.inputs[old_s]]
if not isinstance(old_v, str): if not isinstance(old_v, str):
running_memory_size -= old_v running_memory_size -= old_v
node_idx += 1
return [node_memory_size, running_memory_size, running_max_memory_size, node_memory_saved_by_inplace, node_memory_saved_by_view] return [node_memory_size, running_memory_size, running_max_memory_size, node_memory_saved_by_inplace, node_memory_saved_by_view]
...@@ -716,9 +718,7 @@ class ProfileStats(object): ...@@ -716,9 +718,7 @@ class ProfileStats(object):
if isinstance(node, graph.Constant): if isinstance(node, graph.Constant):
compute_map[node][0] = 1 compute_map[node][0] = 1
computed_ins = all(compute_map[v][0] for v in inputs) computed_ins = all(compute_map[v][0] for v in inputs)
computed_outs = all(compute_map[v][0] for v in outputs) if computed_ins:
# check if there could be a compute_map
if computed_ins and not computed_outs:
return True return True
else: else:
return False return False
...@@ -755,25 +755,31 @@ class ProfileStats(object): ...@@ -755,25 +755,31 @@ class ProfileStats(object):
def count_min_memory(order, thunk_old_storage, nodes_mem): def count_min_memory(order, thunk_old_storage, nodes_mem):
running_memory_size = 0 running_memory_size = 0
running_max_memory_size = 0 running_max_memory_size = 0
node_idx = 0
for node in order: for node in order:
val = nodes_mem[node] val = nodes_mem[node]
dmap = getattr(node.op, 'destroy_map', None) dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None) vmap = getattr(node.op, 'view_map', None)
idx = o idx = 0
for v in val: for v in val:
# TODO check the op returned a view # TODO check the op returned a view
if idx not in dmap and idx not in vmap and not isinstance(v, str): if dmap and idx in dmap:
continue
elif vmap and idx in vmap:
continue
elif not isinstance(v, str):
running_memory_size += v running_memory_size += v
idx += 1 idx += 1
if running_memory_size > running_max_memory_size: if running_memory_size > running_max_memory_size:
running_max_memory_size = running_memory_size running_max_memory_size = running_memory_size
old_storage = thunk_old_storage[order.index(node)] old_storage = thunk_old_storage[node_idx]
for old_s in old_storage: for old_s in old_storage:
old_v = var_mem[node.inputs[old_s]] old_v = var_mem[node.inputs[old_s]]
if not isinstance(old_v, str): if not isinstance(old_v, str):
running_memory_size -= old_v running_memory_size -= old_v
node_idx += 1
return running_max_memory_size return running_max_memory_size
...@@ -787,7 +793,7 @@ class ProfileStats(object): ...@@ -787,7 +793,7 @@ class ProfileStats(object):
if (input in computed) and if (input in computed) and
(input not in fgraph.outputs) and (input not in fgraph.outputs) and
node == last_user[input]]) node == last_user[input]])
current_mem = count_running_memory(order, post_thunk_old_storage, nodes_mem)[2] current_mem = count_min_memory(order, post_thunk_old_storage, nodes_mem)
if current_mem < min_mem: if current_mem < min_mem:
min_mem = current_mem min_mem = current_mem
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论