提交 ca78868f 作者: Roy Xue

Updates:)

上级 a97ad2b6
...@@ -702,12 +702,18 @@ class ProfileStats(object): ...@@ -702,12 +702,18 @@ class ProfileStats(object):
mem_count = 0 mem_count = 0
max_mem_count = 0 max_mem_count = 0
mem_bound = numpy.inf mem_bound = numpy.inf
dependencies = fgraph.profile.dependencies
# Initial compute_map which is used to check if a node is valid
compute_map = defaultdict(lambda: [0])
for var in fgraph.inputs:
compute_map[var][0] = 1
def check_node_state(node): def check_node_state(node):
""" """
check if an Apply node is valid(has inputs but no outputs). check if an Apply node is valid(has inputs).
:param node: apply node :param node: Apply Node
""" """
inputs = node.inputs inputs = node.inputs
outputs = node.outputs outputs = node.outputs
...@@ -722,39 +728,29 @@ class ProfileStats(object): ...@@ -722,39 +728,29 @@ class ProfileStats(object):
else: else:
return False return False
compute_map = defaultdict(lambda: [0]) # Initial executable_nodes
# compute_map use to check if a node is valid executable_nodes = set()
executables_nodes = set()
for var in fgraph.inputs:
compute_map[var][0] = 1
for var in fgraph.inputs: for var in fgraph.inputs:
for c, _ in var.clients: for c, _ in var.clients:
if c != "output" and check_node_state(c): if c != "output" and check_node_state(c):
executables_nodes.add(c) executable_nodes.add(c)
dependencies = fgraph.profile.dependencies def min_memory_generator(executable_nodes):
# for node in node_list[0].inputs[0]: """
# dependencies[node] = [] Generate all valid node order from node_list
# if val.owner and val.clients: and compute its memory peak
# ls = [] """
# for c in val.clients:
# if c[0] is not 'output':
# ls += c[0].outputs
# dependencies[val] += ls
def min_memory_generator(executables_nodes):
global mem_count, mem_bound, max_mem_count, max_mem_count global mem_count, mem_bound, max_mem_count, max_mem_count
for node in executables_nodes:
new_exec_nodes = executables_nodes.copy() for node in executable_nodes:
new_exec_nodes = executable_nodes.copy()
new_exec_nodes.remove(node) new_exec_nodes.remove(node)
mem_created = 0 mem_created = 0
mem_freed = 0 mem_freed = 0
if mem_count > mem_bound: # check if we cut path now
mem_count -= mem_created if max_mem_count > mem_bound:
mem_count += mem_freed
continue continue
for var in node.outputs: for var in node.outputs:
...@@ -764,7 +760,7 @@ class ProfileStats(object): ...@@ -764,7 +760,7 @@ class ProfileStats(object):
dmap = getattr(node.op, 'destroy_map', None) dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None) vmap = getattr(node.op, 'view_map', None)
# add mem_create # Compute mem_create
for i in nodes_mem[node]: for i in nodes_mem[node]:
if (dmap and idx in dmap) or (vmap and idx in vmap): if (dmap and idx in dmap) or (vmap and idx in vmap):
continue continue
...@@ -776,7 +772,7 @@ class ProfileStats(object): ...@@ -776,7 +772,7 @@ class ProfileStats(object):
if mem_count > max_mem_count: if mem_count > max_mem_count:
max_mem_count = mem_count max_mem_count = mem_count
#add mem_freed, this part is not working well #Compute mem_freed
for val in node.inputs: for val in node.inputs:
if (dependencies[val] and val.owner and val not in fgraph.outputs): if (dependencies[val] and val.owner and val not in fgraph.outputs):
if all(compute_map[v] for v in dependencies[val]): if all(compute_map[v] for v in dependencies[val]):
...@@ -785,8 +781,6 @@ class ProfileStats(object): ...@@ -785,8 +781,6 @@ class ProfileStats(object):
mem_count -= mem_freed mem_count -= mem_freed
# check if cut path now
for var in node.outputs: for var in node.outputs:
for c, _ in var.clients: for c, _ in var.clients:
if c != "output" and check_node_state(c): if c != "output" and check_node_state(c):
...@@ -794,22 +788,21 @@ class ProfileStats(object): ...@@ -794,22 +788,21 @@ class ProfileStats(object):
if not new_exec_nodes: if not new_exec_nodes:
yield [node] yield [node]
#update mem_bound #Check and Update mem_bound
if max_mem_count < mem_bound: if max_mem_count < mem_bound:
mem_bound = max_mem_count mem_bound = max_mem_count
else: else:
for p in min_memory_generator(new_exec_nodes): for p in min_memory_generator(new_exec_nodes):
yield [node]+p yield [node]+p
# resetting part # Reset track variables
mem_count -= mem_created mem_count -= mem_created
max_mem_count -= mem_created max_mem_count -= mem_created
mem_count += mem_freed mem_count += mem_freed
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 0 compute_map[var][0] = 0
for order in min_memory_generator(executable_nodes):
for order in min_memory_generator(executables_nodes):
continue continue
return mem_bound return mem_bound
......
...@@ -42,9 +42,8 @@ def test_profiling(): ...@@ -42,9 +42,8 @@ def test_profiling():
# regression testing for future algo speed up # regression testing for future algo speed up
the_string = buf.getvalue() the_string = buf.getvalue()
# assert "Max if linker=cvm(default): 8208KB (16400KB)" in the_string assert "Max if linker=cvm(default): 8208KB (16400KB)" in the_string
# assert "Minimum peak from all valid apply node order is 8192KB" in the_string assert "Minimum peak from all valid apply node order is 8192KB" in the_string
print the_string
finally: finally:
theano.config.profile = old1 theano.config.profile = old1
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论