提交 7c7a5f6d authored 作者: Roy Xue's avatar Roy Xue

Updates

上级 695db885
...@@ -25,7 +25,6 @@ import numpy ...@@ -25,7 +25,6 @@ import numpy
import theano import theano
from theano.gof import graph from theano.gof import graph
from theano.gof import vm
from theano.configparser import AddConfigVar, BoolParam, IntParam from theano.configparser import AddConfigVar, BoolParam, IntParam
...@@ -663,7 +662,6 @@ class ProfileStats(object): ...@@ -663,7 +662,6 @@ class ProfileStats(object):
node_memory_saved_by_view = 0 node_memory_saved_by_view = 0
node_memory_saved_by_inplace = 0 node_memory_saved_by_inplace = 0
node_idx = 0 node_idx = 0
for node in order: for node in order:
val = nodes_mem[node] val = nodes_mem[node]
dmap = getattr(node.op, 'destroy_map', None) dmap = getattr(node.op, 'destroy_map', None)
...@@ -696,12 +694,19 @@ class ProfileStats(object): ...@@ -696,12 +694,19 @@ class ProfileStats(object):
max_minimum_peak = 0 max_minimum_peak = 0
def count_minimum_peak(node_list, fgraph, nodes_mem): def count_minimum_peak(node_list, fgraph, nodes_mem):
global maybe_executed, mem_count, mem_bound global maybe_executed, mem_count, mem_bound, last_user, freed
order = [] order = []
min_order = [] min_order = []
node_list = list(node_list) node_list = list(node_list)
min_mem = sys.maxint min_mem = sys.maxint
current_mem = 0 current_mem = 0
mem_count = 0
mem_bound = numpy.inf
computed = set()
last_user = {}
for node in node_list:
for ins in node.inputs:
last_user[ins] = 0
def check_node_state(node): def check_node_state(node):
""" """
...@@ -733,89 +738,78 @@ class ProfileStats(object): ...@@ -733,89 +738,78 @@ class ProfileStats(object):
if c != "output" and check_node_state(c): if c != "output" and check_node_state(c):
executables_nodes.add(c) executables_nodes.add(c)
mem_count = 0
mem_bound = numpy.inf
# dependencies = {}
dependencies = fgraph.profile.dependencies
print dependencies
compute = set()
last_user = {}
for node in node_list: for node in node_list:
for ins in node.inputs: for ins in node.inputs:
last_user[ins] = node last_user[ins] += 1
for outs in node.outputs: for outs in node.outputs:
computed.add(outs) computed.add(outs)
# for node in node_list[0].inputs[0]:
# dependencies[node] = []
# if val.owner and val.clients:
# ls = []
# for c in val.clients:
# if c[0] is not 'output':
# ls += c[0].outputs
# dependencies[val] += ls
def min_memory_generator(executables_nodes): def min_memory_generator(executables_nodes):
global mem_count, mem_bound global mem_count, mem_bound, last_user, freed
for node in executables_nodes: for node in executables_nodes:
new_exec_nodes = executables_nodes.copy() new_exec_nodes = executables_nodes.copy()
new_exec_nodes.remove(node) new_exec_nodes.remove(node)
mem_created = 0 mem_created = 0
mem_freed = 0
for var in node.outputs:
compute_map[var][0] = 1
idx = 0 idx = 0
dmap = getattr(node.op, 'destroy_map', None) dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None) vmap = getattr(node.op, 'view_map', None)
mem_created = sum(nodes_mem[node])
# for i in nodes_mem[node]:
# if (dmap and idx in dmap) or (vmap and idx in vmap):
# continue
# elif not isinstance(i, str):
# mem_created += i
# idx += 1
mem_count += mem_created
# Add memory created
mem_freed = 0
for ins in node.inputs: for ins in node.inputs:
if (ins in computed) and (ins not in fgraph.outputs) and (node == last_user[ins]): last_user[ins] -= 1
mem_freed += var_mem[ins]
# for val in node.outputs: # add mem_create
# if (dependencies[val] and val.owner and val not in fgraph.outputs): for i in nodes_mem[node]:
# # print dependencies[val] if (dmap and idx in dmap) or (vmap and idx in vmap):
# # print compute_map continue
# for v in dependencies[val]: elif not isinstance(i, str):
# if compute_map[v] == 1: mem_created += i
# print "yes" idx += 1
# mem_freed += var_mem[val]
#add mem_freed, this part is not working well
for ins in node.inputs:
if (ins in computed) and (ins not in fgraph.outputs) and (last_user[ins] == 0):
if not isinstance(var_mem[ins], str):
mem_freed += var_mem[ins]
mem_count += mem_created
mem_count -= mem_freed mem_count -= mem_freed
# Reduce memory freed
# check if cut path now
if mem_count > mem_bound: if mem_count > mem_bound:
mem_count -= mem_created mem_count -= mem_created
mem_count += mem_freed mem_count += mem_freed
for ins in node.inputs:
last_user[ins] += 1
continue continue
for var in node.outputs:
compute_map[var][0] = 1
for var in node.outputs: for var in node.outputs:
for c, _ in var.clients: for c, _ in var.clients:
if c != "output" and check_node_state(c): if c != "output" and check_node_state(c):
new_exec_nodes.add(c) new_exec_nodes.add(c)
if not new_exec_nodes: if not new_exec_nodes:
yield [node] yield [node]
#update mem_bound
if mem_count < mem_bound: if mem_count < mem_bound:
# update the mem_bound
mem_bound = mem_count mem_bound = mem_count
else: else:
for p in min_memory_generator(new_exec_nodes): for p in min_memory_generator(new_exec_nodes):
yield [node]+p yield [node]+p
# resetting part
mem_count -= mem_created mem_count -= mem_created
mem_count += mem_freed mem_count += mem_freed
# reset for ins in node.inputs:
last_user[ins] += 1
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 0 compute_map[var][0] = 0
for order in min_memory_generator(executables_nodes): for order in min_memory_generator(executables_nodes):
continue continue
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论