提交 6879fb89 authored 作者: Roy Xue's avatar Roy Xue

updates

上级 fdb93264
...@@ -637,7 +637,6 @@ class ProfileStats(object): ...@@ -637,7 +637,6 @@ class ProfileStats(object):
new_max_node_memory_saved_by_view = 0 new_max_node_memory_saved_by_view = 0
new_max_node_memory_saved_by_inplace = 0 new_max_node_memory_saved_by_inplace = 0
def count_running_memory(order, thunk_old_storage, nodes_mem): def count_running_memory(order, thunk_old_storage, nodes_mem):
""" """
Calculate memory with specific node order Calculate memory with specific node order
...@@ -696,19 +695,13 @@ class ProfileStats(object): ...@@ -696,19 +695,13 @@ class ProfileStats(object):
def count_minimum_peak(node_list, fgraph, nodes_mem): def count_minimum_peak(node_list, fgraph, nodes_mem):
global maybe_executed global maybe_executed
mem_list = []
order_index = 0
order = [] order = []
min_order = []
node_list = list(node_list) node_list = list(node_list)
min_mem = sys.maxint min_mem = sys.maxint
current_mem = 0 current_mem = 0
check_len = len(node_list) check_len = len(node_list)
compute_map = defaultdict(lambda: [0])
# compute_map use to check if a node is valid
for node in fgraph.inputs:
compute_map[node][0] = 1
def check_node_state(node): def check_node_state(node):
""" """
check if an Apply node is valid(has inputs but no outputs). check if an Apply node is valid(has inputs but no outputs).
...@@ -730,49 +723,10 @@ class ProfileStats(object): ...@@ -730,49 +723,10 @@ class ProfileStats(object):
else: else:
return False return False
# maybe_executed = set() compute_map = defaultdict(lambda: [0])
# for var in fgraph.inputs: # compute_map use to check if a node is valid
# for c, _ in var.clients:
# maybe_executed.add(c)
# current = 0
# def min_memory_generator(node_list):
# '''
# enumerate all valid orders for the list of nodes in node_list
# compute the peak of all order and keep the order with the minimum peak.
# return minimum memory usage
# :param node_list: a list of apply nodes
# :param compute_map: simulate the node execution steps to update compute_map
# '''
# global maybe_executed, current
# for i in range(len(node_list)):
# v = node_list[i:i+1]
# if v[0] in maybe_executed:
# if check_node_state(v[0]):
# maybe_executed.remove(v[0])
# for node in v[0].outputs:
# compute_map[node][0] = 1
# for c, _ in node.clients:
# if c != "output":
# maybe_executed.add(c)
# if len(node_list) == 1:
# yield v
# else:
# rest = node_list[ :i] + node_list[i+1: ]
# for p in min_memory_generator(rest):
# yield v+p
# for node in v[0].outputs:
# compute_map[node][0] = 0
# maybe_executed.add(v[0])
executables_nodes = set() executables_nodes = set()
# compute_map use to check if a node is valid
for var in fgraph.inputs: for var in fgraph.inputs:
compute_map[var][0] = 1 compute_map[var][0] = 1
for var in fgraph.inputs: for var in fgraph.inputs:
...@@ -780,9 +734,6 @@ class ProfileStats(object): ...@@ -780,9 +734,6 @@ class ProfileStats(object):
if c != "output" and check_node_state(c): if c != "output" and check_node_state(c):
executables_nodes.add(c) executables_nodes.add(c)
def min_memory_generator(executables_nodes): def min_memory_generator(executables_nodes):
for node in executables_nodes: for node in executables_nodes:
new_exec_nodes = executables_nodes.copy() new_exec_nodes = executables_nodes.copy()
...@@ -801,12 +752,6 @@ class ProfileStats(object): ...@@ -801,12 +752,6 @@ class ProfileStats(object):
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 0 compute_map[var][0] = 0
min_order = []
# for count memory
# I tested 2 way
# 1. create a new simple method
# 2. using sum(nodes_mem[v[0]])
def count_min_memory(order, thunk_old_storage, nodes_mem): def count_min_memory(order, thunk_old_storage, nodes_mem):
running_memory_size = 0 running_memory_size = 0
running_max_memory_size = 0 running_max_memory_size = 0
...@@ -832,12 +777,7 @@ class ProfileStats(object): ...@@ -832,12 +777,7 @@ class ProfileStats(object):
return running_max_memory_size return running_max_memory_size
i = 0
for order in min_memory_generator(executables_nodes): for order in min_memory_generator(executables_nodes):
# print i
i += 1
post_thunk_old_storage = [] post_thunk_old_storage = []
computed, last_user = theano.gof.link.gc_helper(order) computed, last_user = theano.gof.link.gc_helper(order)
for node in order: for node in order:
......
...@@ -42,8 +42,8 @@ def test_profiling(): ...@@ -42,8 +42,8 @@ def test_profiling():
# regression testing for future algo speed up # regression testing for future algo speed up
the_string = buf.getvalue() the_string = buf.getvalue()
assert "Max if linker=cvm(default): 8224KB (16408KB)" in the_string assert "Max if linker=cvm(default): 8208KB (16400KB)" in the_string
assert "Minimum peak from all valid apply node order is 8208KB" in the_string assert "Minimum peak from all valid apply node order is 8192KB" in the_string
finally: finally:
theano.config.profile = old1 theano.config.profile = old1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论