提交 43abc77b 作者: Frédéric Bastien

Merge pull request #2226 from RoyXue/faster_algo

Using DONE_SET method to make algorithm faster
...@@ -797,6 +797,8 @@ class ProfileStats(object): ...@@ -797,6 +797,8 @@ class ProfileStats(object):
mem_bound = numpy.inf mem_bound = numpy.inf
# This take only the inputs/outputs dependencies. # This take only the inputs/outputs dependencies.
dependencies = fgraph.profile.dependencies dependencies = fgraph.profile.dependencies
done_set = set([])
done_dict = {}
# Initial compute_map which is used to check if a node is valid # Initial compute_map which is used to check if a node is valid
compute_map = defaultdict(lambda: [0]) compute_map = defaultdict(lambda: [0])
...@@ -909,6 +911,11 @@ class ProfileStats(object): ...@@ -909,6 +911,11 @@ class ProfileStats(object):
mem_count -= mem_freed mem_count -= mem_freed
done_set.add(node)
frozen_set = frozenset(done_set)
if done_dict.get(frozen_set, max_mem_count+1) > max_mem_count:
done_dict[frozen_set] = max_mem_count
for var in node.outputs: for var in node.outputs:
for c, _ in var.clients: for c, _ in var.clients:
if c != "output": if c != "output":
...@@ -924,6 +931,7 @@ class ProfileStats(object): ...@@ -924,6 +931,7 @@ class ProfileStats(object):
min_memory_generator(new_exec_nodes, viewed_by, view_of) min_memory_generator(new_exec_nodes, viewed_by, view_of)
# Reset track variables # Reset track variables
done_set.remove(node)
mem_count -= mem_created mem_count -= mem_created
max_mem_count = max_storage max_mem_count = max_storage
mem_count += mem_freed mem_count += mem_freed
...@@ -1025,7 +1033,7 @@ class ProfileStats(object): ...@@ -1025,7 +1033,7 @@ class ProfileStats(object):
new_max_running_max_memory_size / 1024.)), int(round( new_max_running_max_memory_size / 1024.)), int(round(
max_running_max_memory_size / 1024.))) max_running_max_memory_size / 1024.)))
if min_max_peak: if min_max_peak:
print >> file, " Minimum peak from all valid apply node order is %dKB(took %f.2s to compute)" % (int(round( print >> file, " Minimum peak from all valid apply node order is %dKB(took %.3fs to compute)" % (int(round(
min_max_peak / 1024.)), min_peak_time) min_max_peak / 1024.)), min_peak_time)
print >> file, " Memory saved if views are used: %dKB (%dKB)" % (int( print >> file, " Memory saved if views are used: %dKB (%dKB)" % (int(
round(new_max_node_memory_saved_by_view / 1024.)), int( round(new_max_node_memory_saved_by_view / 1024.)), int(
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论