提交 6d5e4ff2 作者: Roy Xue

Updates and fix algo

上级 ca78868f
...@@ -740,7 +740,7 @@ class ProfileStats(object): ...@@ -740,7 +740,7 @@ class ProfileStats(object):
Generate all valid node order from node_list Generate all valid node order from node_list
and compute its memory peak and compute its memory peak
""" """
global mem_count, mem_bound, max_mem_count, max_mem_count global mem_count, mem_bound, max_mem_count
for node in executable_nodes: for node in executable_nodes:
new_exec_nodes = executable_nodes.copy() new_exec_nodes = executable_nodes.copy()
...@@ -749,6 +749,8 @@ class ProfileStats(object): ...@@ -749,6 +749,8 @@ class ProfileStats(object):
mem_created = 0 mem_created = 0
mem_freed = 0 mem_freed = 0
max_storage = max_mem_count
# check if we cut path now # check if we cut path now
if max_mem_count > mem_bound: if max_mem_count > mem_bound:
continue continue
...@@ -776,8 +778,7 @@ class ProfileStats(object): ...@@ -776,8 +778,7 @@ class ProfileStats(object):
for val in node.inputs: for val in node.inputs:
if (dependencies[val] and val.owner and val not in fgraph.outputs): if (dependencies[val] and val.owner and val not in fgraph.outputs):
if all(compute_map[v] for v in dependencies[val]): if all(compute_map[v] for v in dependencies[val]):
if not dmap and not vmap: mem_freed += var_mem[val]
mem_freed += var_mem[val]
mem_count -= mem_freed mem_count -= mem_freed
...@@ -797,7 +798,7 @@ class ProfileStats(object): ...@@ -797,7 +798,7 @@ class ProfileStats(object):
# Reset track variables # Reset track variables
mem_count -= mem_created mem_count -= mem_created
max_mem_count -= mem_created max_mem_count = max_storage
mem_count += mem_freed mem_count += mem_freed
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 0 compute_map[var][0] = 0
......
...@@ -44,7 +44,7 @@ def test_profiling(): ...@@ -44,7 +44,7 @@ def test_profiling():
the_string = buf.getvalue() the_string = buf.getvalue()
assert "Max if linker=cvm(default): 8208KB (16400KB)" in the_string assert "Max if linker=cvm(default): 8208KB (16400KB)" in the_string
assert "Minimum peak from all valid apply node order is 8192KB" in the_string assert "Minimum peak from all valid apply node order is 8192KB" in the_string
finally: finally:
theano.config.profile = old1 theano.config.profile = old1
theano.config.profile_memory = old2 theano.config.profile_memory = old2
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请注册或登录后发表评论