提交 23022cd8 authored 作者: Roy Xue's avatar Roy Xue

New algo with count running memory method

上级 22c94252
...@@ -645,7 +645,7 @@ class ProfileStats(object): ...@@ -645,7 +645,7 @@ class ProfileStats(object):
# track min peak memory usage # track min peak memory usage
max_min_peak = 0 max_min_peak = 0
def count_running_memory(order, thunk_old_storage, nodes_mem): def count_running_memory(order, fgraph, nodes_mem):
""" """
Calculate memory with specific node order Calculate memory with specific node order
Return a list including the following values Return a list including the following values
...@@ -662,38 +662,53 @@ class ProfileStats(object): ...@@ -662,38 +662,53 @@ class ProfileStats(object):
5. node_memory_saved_by_inplace 5. node_memory_saved_by_inplace
The sum of memory saved by reusing the input instead of The sum of memory saved by reusing the input instead of
new allocation new allocation
""" """
node_memory_size = 0 node_memory_size = 0
running_memory_size = 0 running_memory_size = 0
running_max_memory_size = 0 running_max_memory_size = 0
node_memory_saved_by_view = 0 node_memory_saved_by_view = 0
node_memory_saved_by_inplace = 0 node_memory_saved_by_inplace = 0
node_idx = 0 dependencies = fgraph.profile.dependencies
for node in order: for node in order:
val = nodes_mem[node] idx = 0
dmap = getattr(node.op, 'destroy_map', None) dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None) vmap = getattr(node.op, 'view_map', None)
idx = 0 val = nodes_mem[node]
viewed_by = {}
for i in node.inputs:
viewed_by[i] = []
view_of = {}
for v in val: for v in val:
# TODO check the op returned a view
if dmap and idx in dmap: if dmap and idx in dmap:
node_memory_saved_by_inplace += v node_memory_saved_by_inplace += v
# TODO check the op returned a view
elif vmap and idx in vmap: elif vmap and idx in vmap:
node_memory_saved_by_view += v node_memory_saved_by_view += v
elif not isinstance(v, str): idx = 1
node_memory_size += v
running_memory_size += v for out in node.outputs:
idx += 1 if (dmap or vmap):
if running_memory_size > running_max_memory_size: for i in node.inputs:
running_max_memory_size = running_memory_size view_of[out] = view_of.get(i, i)
old_storage = thunk_old_storage[node_idx] viewed_by[i].append(out)
for old_s in old_storage: else:
old_v = var_mem[node.inputs[old_s]] running_memory_size += var_mem[out]
if not isinstance(old_v, str): node_memory_size += var_mem[out]
running_memory_size -= old_v
node_idx += 1 running_max_memory_size = max(running_max_memory_size, running_memory_size)
for ins in node.inputs:
assert not (ins in view_of and i in viewed_by)
if dependencies[ins] and ins not in fgraph.outputs:
if ins not in view_of and not viewed_by.get(ins, []):
running_memory_size -= var_mem[ins]
elif ins in view_of:
origin = view_of[ins]
viewed_by[origin].remove(ins)
if not viewed_by[origin]:
running_memory_size -= var_mem[origin]
return [node_memory_size, running_memory_size, running_max_memory_size, node_memory_saved_by_inplace, node_memory_saved_by_view] return [node_memory_size, running_memory_size, running_max_memory_size, node_memory_saved_by_inplace, node_memory_saved_by_view]
...@@ -833,24 +848,15 @@ class ProfileStats(object): ...@@ -833,24 +848,15 @@ class ProfileStats(object):
# after the execution of the corresponding node. # after the execution of the corresponding node.
# It mean that after executing the node, # It mean that after executing the node,
# the corresponding variable can be gc. # the corresponding variable can be gc.
post_thunk_old_storage = []
computed, last_user = theano.gof.link.gc_helper(order)
for node in order:
post_thunk_old_storage.append([
input_idx
for input_idx, input in enumerate(node.inputs)
if (input in computed) and
(input not in fgraph.outputs) and
node == last_user[input]])
old_running_memory = count_running_memory(order, post_thunk_old_storage, nodes_mem) old_running_memory = count_running_memory(order, fgraph, nodes_mem)
new_order = fgraph.profile.node_executed_order new_order = fgraph.profile.node_executed_order
# A list of new executed node order # A list of new executed node order
new_storage = fgraph.profile.node_cleared_order
# A list of variables that get freed
new_running_memory = count_running_memory(new_order, new_storage, nodes_mem) new_running_memory = count_running_memory(new_order, fgraph, nodes_mem)
print old_running_memory
# Store the max of some stats by any function in this profile. # Store the max of some stats by any function in this profile.
max_sum_size = max(max_sum_size, sum_size) max_sum_size = max(max_sum_size, sum_size)
...@@ -877,8 +883,7 @@ class ProfileStats(object): ...@@ -877,8 +883,7 @@ class ProfileStats(object):
min_peak = count_minimum_peak(node_list, fgraph, nodes_mem) min_peak = count_minimum_peak(node_list, fgraph, nodes_mem)
max_min_peak = max(max_min_peak, min_peak) max_min_peak = max(max_min_peak, min_peak)
del fgraph, nodes_mem, post_thunk_old_storage, node del fgraph, nodes_mem, node
if len(fct_memory) > 1: if len(fct_memory) > 1:
print >> file, ("Memory Profile " print >> file, ("Memory Profile "
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论