提交 5ee08ced authored 作者: Roy Xue's avatar Roy Xue

Something wrong with the mem_freed

上级 adc26d0e
...@@ -25,7 +25,7 @@ import numpy ...@@ -25,7 +25,7 @@ import numpy
import theano import theano
from theano.gof import graph from theano.gof import graph
from theano.gof.vm import compute_gc_dependencies from theano.gof import vm
from theano.configparser import AddConfigVar, BoolParam, IntParam from theano.configparser import AddConfigVar, BoolParam, IntParam
...@@ -692,19 +692,16 @@ class ProfileStats(object): ...@@ -692,19 +692,16 @@ class ProfileStats(object):
return [node_memory_size, running_memory_size, running_max_memory_size, node_memory_saved_by_inplace, node_memory_saved_by_view] return [node_memory_size, running_memory_size, running_max_memory_size, node_memory_saved_by_inplace, node_memory_saved_by_view]
# count the minimum peak # count the minimum peak
best_order = []
minimum_peak = 0 minimum_peak = 0
max_minimum_peak = 0 max_minimum_peak = 0
def count_minimum_peak(node_list, fgraph, nodes_mem): def count_minimum_peak(node_list, fgraph, nodes_mem):
global maybe_executed global maybe_executed, mem_count, mem_bound
global mem_count, mem_bound
order = [] order = []
min_order = [] min_order = []
node_list = list(node_list) node_list = list(node_list)
min_mem = sys.maxint min_mem = sys.maxint
current_mem = 0 current_mem = 0
check_len = len(node_list)
def check_node_state(node): def check_node_state(node):
""" """
...@@ -737,20 +734,47 @@ class ProfileStats(object): ...@@ -737,20 +734,47 @@ class ProfileStats(object):
executables_nodes.add(c) executables_nodes.add(c)
mem_count = 0 mem_count = 0
mem_bound = 0 mem_bound = numpy.inf
# dependencies = {}
dependencies = fgraph.profile.dependencies
print dependencies
# for node in node_list[0].inputs[0]:
# dependencies[node] = []
# if val.owner and val.clients:
# ls = []
# for c in val.clients:
# if c[0] is not 'output':
# ls += c[0].outputs
# dependencies[val] += ls
def min_memory_generator(executables_nodes): def min_memory_generator(executables_nodes):
global mem_count, mem_bound global mem_count, mem_bound
for node in executables_nodes: for node in executables_nodes:
new_exec_nodes = executables_nodes.copy() new_exec_nodes = executables_nodes.copy()
new_exec_nodes.remove(node) new_exec_nodes.remove(node)
mem = count_node_memory(node)
mem_count += mem mem_created = sum(nodes_mem[node])
if mem_bound: mem_count += mem_created
# check if at this time, mem_current and mem_bound # Add memory created
# if higher use 'continue'
# dependencies = {}
mem_freed = 0
for val in node.outputs:
if (dependencies[val] and val.owner and
val not in fgraph.outputs):
if all(compute_map[v][0]
for v in dependencies[val]):
mem_freed += var_mem[val]
print mem_freed
mem_count -= mem_freed
# Reduce memory freed
if mem_count > mem_bound: if mem_count > mem_bound:
mem_count -= mem mem_count -= mem_created
mem_count += mem_freed
continue continue
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 1 compute_map[var][0] = 1
...@@ -760,77 +784,22 @@ class ProfileStats(object): ...@@ -760,77 +784,22 @@ class ProfileStats(object):
new_exec_nodes.add(c) new_exec_nodes.add(c)
if not new_exec_nodes: if not new_exec_nodes:
yield [node] yield [node]
if not mem_bound: if mem_count < mem_bound:
# initial the mem_bound
mem_bound = mem_count
elif mem_current < mem_bound:
# update the mem_bound # update the mem_bound
mem_bound = mem_current mem_bound = mem_count
else: else:
for p in min_memory_generator(new_exec_nodes): for p in min_memory_generator(new_exec_nodes):
yield [node]+p yield [node]+p
mem_count -= mem mem_count -= mem_created
mem_count += mem_freed
# reset
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 0 compute_map[var][0] = 0
def count_node_memory(node):
dependencies = compute_gc_dependencies(node)
mem = 0
for val in node.inputs:
if (dependencies[val]
and val.owner
and val not in fgraph.outputs):
mem += node.inputs.index(val)
return mem
def count_min_memory(order, thunk_old_storage, nodes_mem):
running_memory_size = 0
running_max_memory_size = 0
node_idx = 0
for node in order:
val = nodes_mem[node]
dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None)
idx = 0
for v in val:
# TODO check the op returned a view
if dmap and idx in dmap:
continue
elif vmap and idx in vmap:
continue
elif not isinstance(v, str):
running_memory_size += v
idx += 1
if running_memory_size > running_max_memory_size:
running_max_memory_size = running_memory_size
old_storage = thunk_old_storage[node_idx]
for old_s in old_storage:
old_v = var_mem[node.inputs[old_s]]
if not isinstance(old_v, str):
running_memory_size -= old_v
node_idx += 1
return running_max_memory_size
for order in min_memory_generator(executables_nodes): for order in min_memory_generator(executables_nodes):
post_thunk_old_storage = [] continue
computed, last_user = theano.gof.link.gc_helper(order)
for node in order:
post_thunk_old_storage.append([
input_idx
for input_idx, input in enumerate(node.inputs)
if (input in computed) and
(input not in fgraph.outputs) and
node == last_user[input]])
current_mem = count_min_memory(order, post_thunk_old_storage, nodes_mem)
if current_mem < min_mem:
min_mem = current_mem
min_order = order
return min_order, mem_bound return mem_bound
for fgraph, nodes_mem in fct_memory.iteritems(): for fgraph, nodes_mem in fct_memory.iteritems():
# Sum of the size of all variables in bytes # Sum of the size of all variables in bytes
...@@ -882,7 +851,7 @@ class ProfileStats(object): ...@@ -882,7 +851,7 @@ class ProfileStats(object):
node_list = fgraph.apply_nodes node_list = fgraph.apply_nodes
_, minimum_peak = count_minimum_peak(node_list, fgraph, nodes_mem) minimum_peak = count_minimum_peak(node_list, fgraph, nodes_mem)
# for the best order, we dont use it now # for the best order, we dont use it now
max_minimum_peak = max(max_minimum_peak, minimum_peak) max_minimum_peak = max(max_minimum_peak, minimum_peak)
......
...@@ -147,6 +147,9 @@ class VM(object): ...@@ -147,6 +147,9 @@ class VM(object):
if hasattr(self, 'node_cleared_order'): if hasattr(self, 'node_cleared_order'):
profile.node_cleared_order = self.node_cleared_order[:] profile.node_cleared_order = self.node_cleared_order[:]
if hasattr(self, 'dependencies'):
profile.dependencies = self.dependencies.copy()
# clear the timer info out of the buffers # clear the timer info out of the buffers
for i in xrange(len(self.call_times)): for i in xrange(len(self.call_times)):
self.call_times[i] = 0.0 self.call_times[i] = 0.0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论