提交 7c7a5f6d authored 作者: Roy Xue's avatar Roy Xue

Updates

上级 695db885
......@@ -25,7 +25,6 @@ import numpy
import theano
from theano.gof import graph
from theano.gof import vm
from theano.configparser import AddConfigVar, BoolParam, IntParam
......@@ -663,7 +662,6 @@ class ProfileStats(object):
node_memory_saved_by_view = 0
node_memory_saved_by_inplace = 0
node_idx = 0
for node in order:
val = nodes_mem[node]
dmap = getattr(node.op, 'destroy_map', None)
......@@ -696,12 +694,19 @@ class ProfileStats(object):
max_minimum_peak = 0
def count_minimum_peak(node_list, fgraph, nodes_mem):
global maybe_executed, mem_count, mem_bound
global maybe_executed, mem_count, mem_bound, last_user, freed
order = []
min_order = []
node_list = list(node_list)
min_mem = sys.maxint
current_mem = 0
mem_count = 0
mem_bound = numpy.inf
computed = set()
last_user = {}
for node in node_list:
for ins in node.inputs:
last_user[ins] = 0
def check_node_state(node):
"""
......@@ -733,89 +738,78 @@ class ProfileStats(object):
if c != "output" and check_node_state(c):
executables_nodes.add(c)
mem_count = 0
mem_bound = numpy.inf
# dependencies = {}
dependencies = fgraph.profile.dependencies
print dependencies
compute = set()
last_user = {}
for node in node_list:
for ins in node.inputs:
last_user[ins] = node
last_user[ins] += 1
for outs in node.outputs:
computed.add(outs)
# for node in node_list[0].inputs[0]:
# dependencies[node] = []
# if val.owner and val.clients:
# ls = []
# for c in val.clients:
# if c[0] is not 'output':
# ls += c[0].outputs
# dependencies[val] += ls
def min_memory_generator(executables_nodes):
global mem_count, mem_bound
global mem_count, mem_bound, last_user, freed
for node in executables_nodes:
new_exec_nodes = executables_nodes.copy()
new_exec_nodes.remove(node)
mem_created = 0
mem_freed = 0
for var in node.outputs:
compute_map[var][0] = 1
idx = 0
dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None)
mem_created = sum(nodes_mem[node])
# for i in nodes_mem[node]:
# if (dmap and idx in dmap) or (vmap and idx in vmap):
# continue
# elif not isinstance(i, str):
# mem_created += i
# idx += 1
mem_count += mem_created
# Add memory created
mem_freed = 0
for ins in node.inputs:
if (ins in computed) and (ins not in fgraph.outputs) and (node == last_user[ins]):
mem_freed += var_mem[ins]
# for val in node.outputs:
# if (dependencies[val] and val.owner and val not in fgraph.outputs):
# # print dependencies[val]
# # print compute_map
# for v in dependencies[val]:
# if compute_map[v] == 1:
# print "yes"
# mem_freed += var_mem[val]
last_user[ins] -= 1
# add mem_create
for i in nodes_mem[node]:
if (dmap and idx in dmap) or (vmap and idx in vmap):
continue
elif not isinstance(i, str):
mem_created += i
idx += 1
#add mem_freed, this part is not working well
for ins in node.inputs:
if (ins in computed) and (ins not in fgraph.outputs) and (last_user[ins] == 0):
if not isinstance(var_mem[ins], str):
mem_freed += var_mem[ins]
mem_count += mem_created
mem_count -= mem_freed
# Reduce memory freed
# check if cut path now
if mem_count > mem_bound:
mem_count -= mem_created
mem_count += mem_freed
for ins in node.inputs:
last_user[ins] += 1
continue
for var in node.outputs:
compute_map[var][0] = 1
for var in node.outputs:
for c, _ in var.clients:
if c != "output" and check_node_state(c):
new_exec_nodes.add(c)
if not new_exec_nodes:
yield [node]
#update mem_bound
if mem_count < mem_bound:
# update the mem_bound
mem_bound = mem_count
else:
for p in min_memory_generator(new_exec_nodes):
yield [node]+p
# resetting part
mem_count -= mem_created
mem_count += mem_freed
# reset
for ins in node.inputs:
last_user[ins] += 1
for var in node.outputs:
compute_map[var][0] = 0
for order in min_memory_generator(executables_nodes):
continue
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论