提交 ff41346c authored 作者: Roy Xue's avatar Roy Xue

fix bugs

1. remove the best order, cuz we won't use it now. 2. move the count_minimum_peak to the right place
上级 9a9a3b29
...@@ -740,7 +740,8 @@ class ProfileStats(object): ...@@ -740,7 +740,8 @@ class ProfileStats(object):
node_list = fgraph.nodes node_list = fgraph.nodes
best_order, minimum_peak = count_minimum_peak(node_list, fgraph) _, minimum_peak = count_minimum_peak(node_list, fgraph)
# for the best order, we dont use it now
max_minimum_peak = max(max_minimum_peak, minimum_peak) max_minimum_peak = max(max_minimum_peak, minimum_peak)
...@@ -865,6 +866,84 @@ class ProfileStats(object): ...@@ -865,6 +866,84 @@ class ProfileStats(object):
self.optimizer_profile[0].print_profile(file, self.optimizer_profile[0].print_profile(file,
self.optimizer_profile[1]) self.optimizer_profile[1])
def count_minimum_peak(node_list, fgraph):
mem_list = []
current_mem = 0
order_index = 0
min_mem = 0
order = []
compute_map = fgraph.profile.compute_map
# compute_map use to check if a node is valid
node_mem = {}
for node in self.apply_callcount.keys():
sum_dense = 0
for out in node.outputs:
sh = self.variable_shape[out]
if hasattr(out.type, 'get_size'):
v = out.type.get_size(sh)
sum_dense += v
node_mem[node] = sum_dense
# node_mem use to calculate the node memory usage
def check_node_state(node):
"""
check if an Apply node is valid(has inputs but no outputs).
:param node: apply node
"""
inputs = node.inputs
outputs = node.outputs
deps = inputs + node.destroy_dependencies
computed_ins = all(compute_map[v][0] for v in deps)
computed_outs = all(compute_map[v][0] for v in outputs)
# check if there could be a compute_map
if computed_ins and not computed_outs:
return True
else:
return False
def min_memory_generator(node_list, b=False):
global mem_list, current, order_index, min_mem
'''
enumerate all valid order( node with inputs in its compute_map)
compute the peak of all order and keep the order with the minimum peak.
return an order with minimum memory usage
:param node_list: a list of apply nodes
'''
for i in range(len(node_list)):
v = node_list[i:i+1]
if check_state(v[0]):
if len(node_list) == 1:
yield v
current_mem += node_mem[v[0]]
b = True
else:
b = False
rest = node_list[ :i] + node_list[i+1: ]
for p in count_min_memory_peak(rest):
yield v+p
current_mem += node_mem[v[0]]
if b:
mem_list.append(current_mem)
if not min_mem:
min_mem = current_mem
if current_mem < min_mem:
min_mem = current_mem
order_index = mem_list.index(current_mem)
current_mem = 0
gen = min_memory_generator(node_list)
for i in range(0, (order_index+1)):
order = gen.next()
return order, min_mem
if 0: # old code still to be ported from ProfileMode if 0: # old code still to be ported from ProfileMode
def long_print(self, file=sys.stderr, fct_name=None, message=None, def long_print(self, file=sys.stderr, fct_name=None, message=None,
...@@ -1166,84 +1245,6 @@ if 0: # old code still to be ported from ProfileMode ...@@ -1166,84 +1245,6 @@ if 0: # old code still to be ported from ProfileMode
n_ops_to_print=n_ops_to_print, print_apply=False) n_ops_to_print=n_ops_to_print, print_apply=False)
def count_minimum_peak(node_list, fgraph):
mem_list = []
current_mem = 0
order_index = 0
min_mem = 0
order = []
compute_map = fgraph.profile.compute_map
# compute_map use to check if a node is valid
node_mem = {}
for node in self.apply_callcount.keys():
sum_dense = 0
for out in node.outputs:
sh = self.variable_shape[out]
if hasattr(out.type, 'get_size'):
v = out.type.get_size(sh)
sum_dense += v
node_mem[node] = sum_dense
# node_mem use to calculate the node memory usage
def check_node_state(node):
"""
check if an Apply node is valid(has inputs but no outputs).
:param node: apply node
"""
inputs = node.inputs
outputs = node.outputs
deps = inputs + node.destroy_dependencies
computed_ins = all(compute_map[v][0] for v in deps)
computed_outs = all(compute_map[v][0] for v in outputs)
# check if there could be a compute_map
if computed_ins and not computed_outs:
return True
else:
return False
def min_memory_generator(node_list, b=False):
global mem_list, current, order_index, min_mem
'''
enumerate all valid order( node with inputs in its compute_map)
compute the peak of all order and keep the order with the minimum peak.
return an order with minimum memory usage
:param node_list: a list of apply nodes
'''
for i in range(len(node_list)):
v = node_list[i:i+1]
if check_state(v[0]):
if len(node_list) == 1:
yield v
current_mem += node_mem[v[0]]
b = True
else:
b = False
rest = node_list[ :i] + node_list[i+1: ]
for p in count_min_memory_peak(rest):
yield v+p
current_mem += node_mem[v[0]]
if b:
mem_list.append(current_mem)
if not min_mem:
min_mem = current_mem
if current_mem < min_mem:
min_mem = current_mem
order_index = mem_list.index(current_mem)
current_mem = 0
gen = min_memory_generator(node_list)
for i in range(0, (order_index+1)):
order = gen.next()
return order, min_mem
class ScanProfileStats(ProfileStats): class ScanProfileStats(ProfileStats):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论