提交 6b0a7569 authored 作者: Roy Xue's avatar Roy Xue

Merge pull request #5 from nouiz/GSoC2014_part2

G so c2014 part2
...@@ -671,8 +671,13 @@ class ProfileStats(object): ...@@ -671,8 +671,13 @@ class ProfileStats(object):
node_memory_saved_by_inplace = 0 node_memory_saved_by_inplace = 0
dependencies = fgraph.profile.dependencies dependencies = fgraph.profile.dependencies
# Initial compute_map which is used to check if a node is valid
compute_map = defaultdict(lambda: [0])
for var in fgraph.inputs:
compute_map[var][0] = 1
# two data structure used to mimic Python gc # two data structure used to mimic Python gc
viewed_by = {}# {var1: [vars that view var1]} viewed_by = {} # {var1: [vars that view var1]}
# The len of the list is the value of python ref count. But we use a list, not just the ref count value. # The len of the list is the value of python ref count. But we use a list, not just the ref count value.
# This is more safe to help detect potential bug in the algo # This is more safe to help detect potential bug in the algo
for var in fgraph.variables: for var in fgraph.variables:
...@@ -681,6 +686,8 @@ class ProfileStats(object): ...@@ -681,6 +686,8 @@ class ProfileStats(object):
# The orignal mean that we don't keep trac of all the intermediate relationship in the view. # The orignal mean that we don't keep trac of all the intermediate relationship in the view.
for node in order: for node in order:
for var in node.outputs:
compute_map[var][0] = 1
idx = 0 idx = 0
dmap = getattr(node.op, 'destroy_map', None) dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None) vmap = getattr(node.op, 'view_map', None)
...@@ -695,40 +702,63 @@ class ProfileStats(object): ...@@ -695,40 +702,63 @@ class ProfileStats(object):
node_memory_saved_by_view += v node_memory_saved_by_view += v
idx += 1 idx += 1
# Update the Python emulating dicts and add the memory allocated by the node # Update the Python emulating dicts and add the memory
# allocated by the node
idx2 = 0 idx2 = 0
for out in node.outputs: for out in node.outputs:
if (dmap and idx2 in dmap) or (vmap and idx2 in vmap): ins = None
# This is needed for destroy_map in case it return a partial view that is destroyed. if dmap and idx2 in dmap:
# So the output could be different then the input. vidx = dmap[idx2]
for ins in node.inputs: assert len(vidx) == 1, "Here we only support the possibility to destroy one input"
ins = node.inputs[vidx[0]]
if vmap and idx2 in vmap:
assert ins is None
vidx = vmap[idx2]
assert len(vidx) == 1, "Here we only support the possibility to view one input"
ins = node.inputs[vidx[0]]
if ins is not None:
# This is needed for destroy_map in case it
# return a partial view that is destroyed. So
# the output could be different then the
# input.
assert isinstance(ins, theano.Variable) assert isinstance(ins, theano.Variable)
view_of[out] = view_of.get(ins, ins)# This get make that we keep trac of view only again the original # we keep trac of view only again the origin
viewed_by[ins].append(out) origin = view_of.get(ins, ins)
view_of[out] = origin
viewed_by[origin].append(out)
else: else:
running_memory_size += var_mem[out] running_memory_size += var_mem[out]
node_memory_size += var_mem[out] node_memory_size += var_mem[out]
idx2 += 1 idx2 += 1
running_max_memory_size = max(running_max_memory_size, running_memory_size) running_max_memory_size = max(running_max_memory_size,
running_memory_size)
# Mimic the combination of Theano and Python gc # Mimic the combination of Theano and Python gc
for ins in node.inputs: for ins in node.inputs:
assert not (ins in view_of and viewed_by[ins]) assert not (ins in view_of and viewed_by[ins])
# we keep trac of the original var, so this shouldn't happen # we trac the original var, so this shouldn't happen
if dependencies[ins] and ins not in fgraph.outputs and ins.owner: if (dependencies[ins] and
ins not in fgraph.outputs and
ins.owner and
all([compute_map[v][0] for v in dependencies[ins]])):
if ins not in view_of and not viewed_by.get(ins, []): if ins not in view_of and not viewed_by.get(ins, []):
running_memory_size -= var_mem[ins] running_memory_size -= var_mem[ins]
elif ins in view_of: elif ins in view_of:
origin = view_of[ins] origin = view_of[ins]
viewed_by[origin].remove(ins) viewed_by[origin].remove(ins)
if not viewed_by[origin] and origin not in fgraph.inputs: if (not viewed_by[origin] and
origin not in fgraph.inputs):
running_memory_size -= var_mem[origin] running_memory_size -= var_mem[origin]
else: else:
# ins is viewed_by something else, so its memory isn't freed # ins is viewed_by something else, so its
# memory isn't freed
pass pass
return [node_memory_size, running_memory_size, running_max_memory_size, node_memory_saved_by_inplace, node_memory_saved_by_view] return [node_memory_size, running_memory_size,
running_max_memory_size, node_memory_saved_by_inplace,
node_memory_saved_by_view]
def count_minimum_peak(node_list, fgraph, nodes_mem): def count_minimum_peak(node_list, fgraph, nodes_mem):
global mem_count, mem_bound, max_mem_count global mem_count, mem_bound, max_mem_count
...@@ -770,15 +800,6 @@ class ProfileStats(object): ...@@ -770,15 +800,6 @@ class ProfileStats(object):
if c != "output" and check_node_state(c): if c != "output" and check_node_state(c):
executable_nodes.add(c) executable_nodes.add(c)
# two data structure used to mimic Python gc
viewed_by = {}# {var1: [vars that view var1]}
# The len of the list is the value of python ref count. But we use a list, not just the ref count value.
# This is more safe to help detect potential bug in the algo
for var in fgraph.variables:
viewed_by[var] = []
view_of = {}# {var1: original var viewed by var1}
# The orignal mean that we don't keep trac of all the intermediate relationship in the view.
def min_memory_generator(executable_nodes, viewed_by, view_of): def min_memory_generator(executable_nodes, viewed_by, view_of):
""" """
Generate all valid node order from node_list Generate all valid node order from node_list
...@@ -796,8 +817,13 @@ class ProfileStats(object): ...@@ -796,8 +817,13 @@ class ProfileStats(object):
if max_mem_count > mem_bound: if max_mem_count > mem_bound:
continue continue
viewed_by_temp = viewed_by.copy()
view_of_temp = view_of.copy() view_of_temp = view_of.copy()
# We don't want a shallow copy, but we don't want
# a deep copy. So this do a "middle" copy, where
# we copy the dict and the list, but not the var
viewed_by_temp = {}
for k, v in viewed_by.iteritems():
viewed_by_temp[k] = list(v)
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 1 compute_map[var][0] = 1
...@@ -810,15 +836,29 @@ class ProfileStats(object): ...@@ -810,15 +836,29 @@ class ProfileStats(object):
vmap = getattr(node.op, 'view_map', None) vmap = getattr(node.op, 'view_map', None)
idx = 0 idx = 0
# Update the Python emulating dicts and add the memory allocated by the node # Update the Python emulating dicts and add the
# memory allocated by the node
for out in node.outputs: for out in node.outputs:
if (dmap and idx in dmap) or (vmap and idx in vmap): ins = None
# This is needed for destroy_map in case it return a partial view that is destroyed. if dmap and idx in dmap:
# So the output could be different then the input. vidx = dmap[idx]
for ins in node.inputs: assert len(vidx) == 1, "Here we only support the possibility to destroy one input"
ins = node.inputs[vidx[0]]
if vmap and idx in vmap:
assert ins is None, "Here we only support the possibility to view one input"
vidx = vmap[idx]
assert len(vidx) == 1
ins = node.inputs[vidx[0]]
if ins is not None:
# This is needed for destroy_map in case it
# return a partial view that is destroyed. So
# the output could be different then the
# input.
assert isinstance(ins, theano.Variable) assert isinstance(ins, theano.Variable)
view_of_temp[out] = view_of_temp.get(ins, ins)# This get make that we keep trac of view only again the original # We keep trac of view only again the original
viewed_by_temp[ins].append(out) origin = view_of_temp.get(ins, ins)
view_of_temp[out] = origin
viewed_by_temp[origin].append(out)
else: else:
mem_created += var_mem[out] mem_created += var_mem[out]
idx += 1 idx += 1
...@@ -828,10 +868,13 @@ class ProfileStats(object): ...@@ -828,10 +868,13 @@ class ProfileStats(object):
# Mimic the combination of Theano and Python gc. # Mimic the combination of Theano and Python gc.
for ins in node.inputs: for ins in node.inputs:
assert not (ins in view_of_temp and viewed_by_temp[ins]) assert not (ins in view_of_temp and
# we keep track of the original var, so this shouldn't happen viewed_by_temp[ins])
if dependencies[ins] and ins not in fgraph.outputs and ins.owner: # We track of the original var, so this shouldn't happen
if all(compute_map[v] for v in dependencies[ins]): if (dependencies[ins] and
ins not in fgraph.outputs and
ins.owner and
all([compute_map[v][0] for v in dependencies[ins]])):
if ins not in view_of_temp and not viewed_by_temp.get(ins, []): if ins not in view_of_temp and not viewed_by_temp.get(ins, []):
mem_freed += var_mem[ins] mem_freed += var_mem[ins]
elif ins in view_of_temp: elif ins in view_of_temp:
...@@ -840,7 +883,8 @@ class ProfileStats(object): ...@@ -840,7 +883,8 @@ class ProfileStats(object):
if not viewed_by_temp[origin] and origin not in fgraph.inputs: if not viewed_by_temp[origin] and origin not in fgraph.inputs:
mem_freed += var_mem[origin] mem_freed += var_mem[origin]
else: else:
# ins is viewed_by something else, so its memory isn't freed # ins is viewed_by something else, so its
# memory isn't freed
pass pass
mem_count -= mem_freed mem_count -= mem_freed
...@@ -852,11 +896,13 @@ class ProfileStats(object): ...@@ -852,11 +896,13 @@ class ProfileStats(object):
if not new_exec_nodes: if not new_exec_nodes:
yield [node] yield [node]
#Check and Update mem_bound # Check and Update mem_bound
if max_mem_count < mem_bound: if max_mem_count < mem_bound:
mem_bound = max_mem_count mem_bound = max_mem_count
else: else:
for p in min_memory_generator(new_exec_nodes, viewed_by_temp, view_of_temp): for p in min_memory_generator(new_exec_nodes,
viewed_by_temp,
view_of_temp):
yield [node]+p yield [node]+p
# Reset track variables # Reset track variables
...@@ -866,8 +912,19 @@ class ProfileStats(object): ...@@ -866,8 +912,19 @@ class ProfileStats(object):
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 0 compute_map[var][0] = 0
# two data structure used to mimic Python gc
viewed_by = {} # {var1: [vars that view var1]}
# The len of the list is the value of python ref count. But we use a list, not just the ref count value.
# This is more safe to help detect potential bug in the algo
for var in fgraph.variables:
viewed_by[var] = []
view_of = {} # {var1: original var viewed by var1}
# The orignal mean that we don't keep trac of all the intermediate relationship in the view.
# Loop all valid orders and find min peak(store in mem_bound) # Loop all valid orders and find min peak(store in mem_bound)
for order in min_memory_generator(executable_nodes, viewed_by, view_of): for order in min_memory_generator(executable_nodes,
viewed_by,
view_of):
continue continue
return mem_bound return mem_bound
...@@ -888,11 +945,13 @@ class ProfileStats(object): ...@@ -888,11 +945,13 @@ class ProfileStats(object):
new_order = fgraph.profile.node_executed_order new_order = fgraph.profile.node_executed_order
# A list of new executed node order # A list of new executed node order
new_running_memory = count_running_memory(new_order, fgraph, nodes_mem) new_running_memory = count_running_memory(new_order,
fgraph, nodes_mem)
# Store the max of some stats by any function in this profile. # Store the max of some stats by any function in this profile.
max_sum_size = max(max_sum_size, sum_size) max_sum_size = max(max_sum_size, sum_size)
max_node_memory_size = max(max_node_memory_size, old_running_memory[0]) max_node_memory_size = max(max_node_memory_size,
old_running_memory[0])
max_running_max_memory_size = max(max_running_max_memory_size, max_running_max_memory_size = max(max_running_max_memory_size,
old_running_memory[2]) old_running_memory[2])
max_node_memory_saved_by_view = max(max_node_memory_saved_by_view, max_node_memory_saved_by_view = max(max_node_memory_saved_by_view,
...@@ -901,7 +960,8 @@ class ProfileStats(object): ...@@ -901,7 +960,8 @@ class ProfileStats(object):
max_node_memory_saved_by_inplace, old_running_memory[3]) max_node_memory_saved_by_inplace, old_running_memory[3])
# Store max of some stats with new order # Store max of some stats with new order
new_max_node_memory_size = max(new_max_node_memory_size, new_running_memory[0]) new_max_node_memory_size = max(new_max_node_memory_size,
new_running_memory[0])
new_max_running_max_memory_size = max(new_max_running_max_memory_size, new_max_running_max_memory_size = max(new_max_running_max_memory_size,
new_running_memory[2]) new_running_memory[2])
new_max_node_memory_saved_by_view = max(new_max_node_memory_saved_by_view, new_max_node_memory_saved_by_view = max(new_max_node_memory_saved_by_view,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论