提交 43abc77b 作者: Frédéric Bastien

Merge pull request #2226 from RoyXue/faster_algo

Use the done_set method to make the algorithm faster
...@@ -797,6 +797,8 @@ class ProfileStats(object): ...@@ -797,6 +797,8 @@ class ProfileStats(object):
mem_bound = numpy.inf mem_bound = numpy.inf
# This takes only the inputs/outputs dependencies. # This takes only the inputs/outputs dependencies.
dependencies = fgraph.profile.dependencies dependencies = fgraph.profile.dependencies
done_set = set([])
done_dict = {}
# Initialize compute_map, which is used to check if a node is valid # Initialize compute_map, which is used to check if a node is valid
compute_map = defaultdict(lambda: [0]) compute_map = defaultdict(lambda: [0])
...@@ -909,21 +911,27 @@ class ProfileStats(object): ...@@ -909,21 +911,27 @@ class ProfileStats(object):
mem_count -= mem_freed mem_count -= mem_freed
for var in node.outputs: done_set.add(node)
for c, _ in var.clients: frozen_set = frozenset(done_set)
if c != "output": if done_dict.get(frozen_set, max_mem_count+1) > max_mem_count:
deps = c.inputs + c.destroy_dependencies done_dict[frozen_set] = max_mem_count
if all(compute_map[v][0] for v in deps):
new_exec_nodes.add(c) for var in node.outputs:
for c, _ in var.clients:
if not new_exec_nodes: if c != "output":
# Check and Update mem_bound deps = c.inputs + c.destroy_dependencies
if max_mem_count < mem_bound: if all(compute_map[v][0] for v in deps):
mem_bound = max_mem_count new_exec_nodes.add(c)
else:
min_memory_generator(new_exec_nodes, viewed_by, view_of) if not new_exec_nodes:
# Check and Update mem_bound
if max_mem_count < mem_bound:
mem_bound = max_mem_count
else:
min_memory_generator(new_exec_nodes, viewed_by, view_of)
# Reset track variables # Reset track variables
done_set.remove(node)
mem_count -= mem_created mem_count -= mem_created
max_mem_count = max_storage max_mem_count = max_storage
mem_count += mem_freed mem_count += mem_freed
...@@ -1025,7 +1033,7 @@ class ProfileStats(object): ...@@ -1025,7 +1033,7 @@ class ProfileStats(object):
new_max_running_max_memory_size / 1024.)), int(round( new_max_running_max_memory_size / 1024.)), int(round(
max_running_max_memory_size / 1024.))) max_running_max_memory_size / 1024.)))
if min_max_peak: if min_max_peak:
print >> file, " Minimum peak from all valid apply node order is %dKB(took %f.2s to compute)" % (int(round( print >> file, " Minimum peak from all valid apply node order is %dKB(took %.3fs to compute)" % (int(round(
min_max_peak / 1024.)), min_peak_time) min_max_peak / 1024.)), min_peak_time)
print >> file, " Memory saved if views are used: %dKB (%dKB)" % (int( print >> file, " Memory saved if views are used: %dKB (%dKB)" % (int(
round(new_max_node_memory_saved_by_view / 1024.)), int( round(new_max_node_memory_saved_by_view / 1024.)), int(
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论