提交 01f1a753 authored 作者: Roy Xue's avatar Roy Xue

modify view_of and viewed_by place.

上级 f5802f01
...@@ -672,13 +672,14 @@ class ProfileStats(object): ...@@ -672,13 +672,14 @@ class ProfileStats(object):
dependencies = fgraph.profile.dependencies dependencies = fgraph.profile.dependencies
# two data structure used to mimic Python gc # two data structure used to mimic Python gc
view_of = {} # {var1: original var viewed by var1}
# The orignal mean that we don't keep trac of all the intermediate relationship in the view.
viewed_by = {}# {var1: [vars that view var1]} viewed_by = {}# {var1: [vars that view var1]}
# The len of the list is the value of python ref count. But we use a list, not just the ref count value. # The len of the list is the value of python ref count. But we use a list, not just the ref count value.
# This is more safe to help detect potential bug in the algo # This is more safe to help detect potential bug in the algo
for var in fgraph.variable for node in fgraph.apply_nodes:
for var in node.inputs:
viewed_by[var] = []
view_of = {} # {var1: original var viewed by var1}
# The orignal mean that we don't keep trac of all the intermediate relationship in the view.
for node in order: for node in order:
idx = 0 idx = 0
...@@ -702,7 +703,7 @@ class ProfileStats(object): ...@@ -702,7 +703,7 @@ class ProfileStats(object):
# This is needed for destroy_map in case it return a partial view that is destroyed. # This is needed for destroy_map in case it return a partial view that is destroyed.
# So the output could be different then the input. # So the output could be different then the input.
for ins in node.inputs: for ins in node.inputs:
assert len[ins] == 1 # assert len[ins] == 1
view_of[out] = view_of.get(ins, ins)# This get make that we keep trac of view only again the original view_of[out] = view_of.get(ins, ins)# This get make that we keep trac of view only again the original
viewed_by[ins].append(out) viewed_by[ins].append(out)
else: else:
...@@ -714,7 +715,7 @@ class ProfileStats(object): ...@@ -714,7 +715,7 @@ class ProfileStats(object):
# Mimic the combination of Theano and Python gc # Mimic the combination of Theano and Python gc
for ins in node.inputs: for ins in node.inputs:
assert not (ins in view_of and i in viewed_by) assert not (ins in view_of and ins in viewed_by)
# we keep trac of the original var, so this shouldn't happen # we keep trac of the original var, so this shouldn't happen
if dependencies[ins] and ins not in fgraph.outputs: if dependencies[ins] and ins not in fgraph.outputs:
if ins not in view_of and not viewed_by.get(ins, []): if ins not in view_of and not viewed_by.get(ins, []):
...@@ -731,7 +732,7 @@ class ProfileStats(object): ...@@ -731,7 +732,7 @@ class ProfileStats(object):
return [node_memory_size, running_memory_size, running_max_memory_size, node_memory_saved_by_inplace, node_memory_saved_by_view] return [node_memory_size, running_memory_size, running_max_memory_size, node_memory_saved_by_inplace, node_memory_saved_by_view]
def count_minimum_peak(node_list, fgraph, nodes_mem): def count_minimum_peak(node_list, fgraph, nodes_mem):
global maybe_executed, mem_count, mem_bound, max_mem_count global mem_count, mem_bound, max_mem_count, viewed_by, view_of
node_list = list(node_list) node_list = list(node_list)
mem_count = 0 mem_count = 0
max_mem_count = 0 max_mem_count = 0
...@@ -769,11 +770,13 @@ class ProfileStats(object): ...@@ -769,11 +770,13 @@ class ProfileStats(object):
if c != "output" and check_node_state(c): if c != "output" and check_node_state(c):
executable_nodes.add(c) executable_nodes.add(c)
viewed_by = {}# {var1:[vars that view var1]} # two data structure used to mimic Python gc
# The len of the list is the value of python ref count. But we use a list, not just the ref count value. viewed_by = {}# {var1: [vars that view var1]}
# The len of the list is the value of python ref count. But we use a list, not just the ref count value.
# This is more safe to help detect potential bug in the algo # This is more safe to help detect potential bug in the algo
for i in node.inputs: for node in fgraph.apply_nodes:
viewed_by[i] = [] for var in node.inputs:
viewed_by[var] = []
view_of = {}# {var1: original var viewed by var1} view_of = {}# {var1: original var viewed by var1}
# The orignal mean that we don't keep trac of all the intermediate relationship in the view. # The orignal mean that we don't keep trac of all the intermediate relationship in the view.
...@@ -784,7 +787,7 @@ class ProfileStats(object): ...@@ -784,7 +787,7 @@ class ProfileStats(object):
:param executable_nodes: Set of executable nodes :param executable_nodes: Set of executable nodes
""" """
global mem_count, mem_bound, max_mem_count global mem_count, mem_bound, max_mem_count, viewed_by, view_of
for node in executable_nodes: for node in executable_nodes:
new_exec_nodes = executable_nodes.copy() new_exec_nodes = executable_nodes.copy()
...@@ -797,6 +800,8 @@ class ProfileStats(object): ...@@ -797,6 +800,8 @@ class ProfileStats(object):
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 1 compute_map[var][0] = 1
print view_of, viewed_by
mem_created = 0 mem_created = 0
mem_freed = 0 mem_freed = 0
max_storage = max_mem_count max_storage = max_mem_count
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论