提交 5ee08ced authored 作者: Roy Xue's avatar Roy Xue

Something wrong with the mem_freed

上级 adc26d0e
......@@ -25,7 +25,7 @@ import numpy
import theano
from theano.gof import graph
from theano.gof.vm import compute_gc_dependencies
from theano.gof import vm
from theano.configparser import AddConfigVar, BoolParam, IntParam
......@@ -692,19 +692,16 @@ class ProfileStats(object):
return [node_memory_size, running_memory_size, running_max_memory_size, node_memory_saved_by_inplace, node_memory_saved_by_view]
# count the minimum peak
best_order = []
minimum_peak = 0
max_minimum_peak = 0
def count_minimum_peak(node_list, fgraph, nodes_mem):
global maybe_executed
global mem_count, mem_bound
global maybe_executed, mem_count, mem_bound
order = []
min_order = []
node_list = list(node_list)
min_mem = sys.maxint
current_mem = 0
check_len = len(node_list)
def check_node_state(node):
"""
......@@ -737,20 +734,47 @@ class ProfileStats(object):
executables_nodes.add(c)
mem_count = 0
mem_bound = 0
mem_bound = numpy.inf
# dependencies = {}
dependencies = fgraph.profile.dependencies
print dependencies
# for node in node_list[0].inputs[0]:
# dependencies[node] = []
# if val.owner and val.clients:
# ls = []
# for c in val.clients:
# if c[0] is not 'output':
# ls += c[0].outputs
# dependencies[val] += ls
def min_memory_generator(executables_nodes):
global mem_count, mem_bound
for node in executables_nodes:
new_exec_nodes = executables_nodes.copy()
new_exec_nodes.remove(node)
mem = count_node_memory(node)
mem_count += mem
if mem_bound:
# check if at this time, mem_current and mem_bound
# if higher use 'continue'
mem_created = sum(nodes_mem[node])
mem_count += mem_created
# Add memory created
# dependencies = {}
mem_freed = 0
for val in node.outputs:
if (dependencies[val] and val.owner and
val not in fgraph.outputs):
if all(compute_map[v][0]
for v in dependencies[val]):
mem_freed += var_mem[val]
print mem_freed
mem_count -= mem_freed
# Reduce memory freed
if mem_count > mem_bound:
mem_count -= mem
mem_count -= mem_created
mem_count += mem_freed
continue
for var in node.outputs:
compute_map[var][0] = 1
......@@ -760,77 +784,22 @@ class ProfileStats(object):
new_exec_nodes.add(c)
if not new_exec_nodes:
yield [node]
if not mem_bound:
# initial the mem_bound
mem_bound = mem_count
elif mem_current < mem_bound:
if mem_count < mem_bound:
# update the mem_bound
mem_bound = mem_current
mem_bound = mem_count
else:
for p in min_memory_generator(new_exec_nodes):
yield [node]+p
mem_count -= mem
mem_count -= mem_created
mem_count += mem_freed
# reset
for var in node.outputs:
compute_map[var][0] = 0
def count_node_memory(node):
dependencies = compute_gc_dependencies(node)
mem = 0
for val in node.inputs:
if (dependencies[val]
and val.owner
and val not in fgraph.outputs):
mem += node.inputs.index(val)
return mem
def count_min_memory(order, thunk_old_storage, nodes_mem):
running_memory_size = 0
running_max_memory_size = 0
node_idx = 0
for node in order:
val = nodes_mem[node]
dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None)
idx = 0
for v in val:
# TODO check the op returned a view
if dmap and idx in dmap:
continue
elif vmap and idx in vmap:
continue
elif not isinstance(v, str):
running_memory_size += v
idx += 1
if running_memory_size > running_max_memory_size:
running_max_memory_size = running_memory_size
old_storage = thunk_old_storage[node_idx]
for old_s in old_storage:
old_v = var_mem[node.inputs[old_s]]
if not isinstance(old_v, str):
running_memory_size -= old_v
node_idx += 1
return running_max_memory_size
for order in min_memory_generator(executables_nodes):
post_thunk_old_storage = []
computed, last_user = theano.gof.link.gc_helper(order)
for node in order:
post_thunk_old_storage.append([
input_idx
for input_idx, input in enumerate(node.inputs)
if (input in computed) and
(input not in fgraph.outputs) and
node == last_user[input]])
current_mem = count_min_memory(order, post_thunk_old_storage, nodes_mem)
if current_mem < min_mem:
min_mem = current_mem
min_order = order
continue
return min_order, mem_bound
return mem_bound
for fgraph, nodes_mem in fct_memory.iteritems():
# Sum of the size of all variables in bytes
......@@ -882,7 +851,7 @@ class ProfileStats(object):
node_list = fgraph.apply_nodes
_, minimum_peak = count_minimum_peak(node_list, fgraph, nodes_mem)
minimum_peak = count_minimum_peak(node_list, fgraph, nodes_mem)
# for the best order, we dont use it now
max_minimum_peak = max(max_minimum_peak, minimum_peak)
......
......@@ -147,6 +147,9 @@ class VM(object):
if hasattr(self, 'node_cleared_order'):
profile.node_cleared_order = self.node_cleared_order[:]
if hasattr(self, 'dependencies'):
profile.dependencies = self.dependencies.copy()
# clear the timer info out of the buffers
for i in xrange(len(self.call_times)):
self.call_times[i] = 0.0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论