提交 9a217c5b authored 作者: Roy Xue's avatar Roy Xue

undone version

上级 be7f4870
......@@ -728,46 +728,116 @@ class ProfileStats(object):
else:
return False
maybe_executed = set()
# maybe_executed = set()
# for var in fgraph.inputs:
# for c, _ in var.clients:
# maybe_executed.add(c)
# current = 0
# def min_memory_generator(node_list):
# '''
# enumerate all valid orders for the list of nodes in node_list
# compute the peak of all order and keep the order with the minimum peak.
# return minimum memory usage
# :param node_list: a list of apply nodes
# :param compute_map: simulate the node execution steps to update compute_map
# '''
# global maybe_executed, current
# for i in range(len(node_list)):
# v = node_list[i:i+1]
# if v[0] in maybe_executed:
# if check_node_state(v[0]):
# maybe_executed.remove(v[0])
# for node in v[0].outputs:
# compute_map[node][0] = 1
# for c, _ in node.clients:
# if c != "output":
# maybe_executed.add(c)
# if len(node_list) == 1:
# yield v
# else:
# rest = node_list[ :i] + node_list[i+1: ]
# for p in min_memory_generator(rest):
# yield v+p
# for node in v[0].outputs:
# compute_map[node][0] = 0
# maybe_executed.add(v[0])
executables_nodes = set()
compute_map = defaultdict(lambda: [0])
# compute_map use to check if a node is valid
for var in fgraph.inputs:
compute_map[var][0] = 1
for var in fgraph.inputs:
for c, _ in var.clients:
maybe_executed.add(c)
if c != "output" and check_node_state(c):
executables_nodes.add(c)
def min_memory_generator(executables_nodes):
    '''
    Enumerate the execution orders reachable from a set of ready nodes.

    Each ready node is tried in turn as the next node to run: its outputs
    are flagged as computed in the enclosing ``compute_map``, the clients
    it unblocks join the ready set, and the remaining order is produced
    recursively. The flags are cleared again before the next candidate is
    tried, so every branch starts from the same simulated state.

    :param executables_nodes: set of apply nodes that can run right now.
        It is not modified; each branch works on its own copy.
    :return: generator yielding lists of apply nodes, one list per order.
        Nothing is yielded when ``executables_nodes`` is empty.

    NOTE(review): relies on ``check_node_state(node)`` from the enclosing
    scope, presumably true when every input of ``node`` is flagged in
    ``compute_map`` -- confirm against its definition.
    '''
    # Iterate over a snapshot so the caller's set can be read safely even
    # if the caller mutates it between two yields.
    for node in list(executables_nodes):
        # Simulate running `node`: its outputs become available.
        for var in node.outputs:
            compute_map[var][0] = 1

        # Ready set for the next step: everything that was ready except
        # `node`, plus the clients that `node` has just unblocked.
        # Building a fresh set (instead of mutating the argument in
        # place) keeps sibling branches independent, so nothing has to
        # be undone on the set when backtracking.
        remaining = set(executables_nodes)
        remaining.remove(node)
        for var in node.outputs:
            for c, _ in var.clients:
                # "output" is the marker client of graph outputs, not an
                # apply node, so it is never scheduled.
                if c != "output" and check_node_state(c):
                    remaining.add(c)

        if not remaining:
            # Nothing left to run: `node` ends this order.
            yield [node]
        else:
            for tail in min_memory_generator(remaining):
                yield [node] + tail

        # Backtrack: pretend `node` was never run before trying the next
        # candidate.
        for var in node.outputs:
            compute_map[var][0] = 0
min_order = []
def min_memory_generator(node_list):
'''
enumerate all valid orders for the list of nodes in node_list
compute the peak of all order and keep the order with the minimum peak.
return minimum memory usage
# for count memory
# I tested 2 way
# 1. create a new simple method
# 2. using sum(nodes_mem[v[0]])
def count_min_memory(order, thunk_old_storage, nodes_mem):
running_memory_size = 0
running_max_memory_size = 0
:param node_list: a list of apply nodes
:param compute_map: simulate the node execution steps to update compute_map
'''
global maybe_executed
for i in range(len(node_list)):
v = node_list[i:i+1]
if v[0] in maybe_executed:
if check_node_state(v[0]):
maybe_executed.remove(v[0])
for node in v[0].outputs:
compute_map[node][0] = 1
for c, _ in node.clients:
if c != "output":
maybe_executed.add(c)
if len(node_list) == 1:
yield v
else:
rest = node_list[ :i] + node_list[i+1: ]
for p in min_memory_generator(rest):
yield v+p
for node in v[0].outputs:
compute_map[node][0] = 0
maybe_executed.add(v[0])
for node in order:
val = nodes_mem[node]
dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None)
min_order = []
for idx, v in enumerate(val):
# TODO check the op returned a view
if idx not in dmap and idx not in vmap and not isinstance(v, str):
running_memory_size += v
if running_memory_size > running_max_memory_size:
running_max_memory_size = running_memory_size
old_storage = thunk_old_storage[order.index(node)]
for old_s in old_storage:
old_v = var_mem[node.inputs[old_s]]
if not isinstance(old_v, str):
running_memory_size -= old_v
return running_max_memory_size
i = 0
for order in min_memory_generator(node_list):
# print i
i += 1
post_thunk_old_storage = []
computed, last_user = theano.gof.link.gc_helper(order)
for node in order:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论