提交 033d98c0 authored 作者: Roy Xue's avatar Roy Xue

viewed_by and view_of reset

上级 fc62ba11
...@@ -731,7 +731,7 @@ class ProfileStats(object): ...@@ -731,7 +731,7 @@ class ProfileStats(object):
return [node_memory_size, running_memory_size, running_max_memory_size, node_memory_saved_by_inplace, node_memory_saved_by_view] return [node_memory_size, running_memory_size, running_max_memory_size, node_memory_saved_by_inplace, node_memory_saved_by_view]
def count_minimum_peak(node_list, fgraph, nodes_mem): def count_minimum_peak(node_list, fgraph, nodes_mem):
global mem_count, mem_bound, max_mem_count, viewed_by, view_of global mem_count, mem_bound, max_mem_count
node_list = list(node_list) node_list = list(node_list)
mem_count = 0 mem_count = 0
max_mem_count = 0 max_mem_count = 0
...@@ -779,14 +779,14 @@ class ProfileStats(object): ...@@ -779,14 +779,14 @@ class ProfileStats(object):
view_of = {}# {var1: original var viewed by var1} view_of = {}# {var1: original var viewed by var1}
# The orignal mean that we don't keep trac of all the intermediate relationship in the view. # The orignal mean that we don't keep trac of all the intermediate relationship in the view.
def min_memory_generator(executable_nodes): def min_memory_generator(executable_nodes, viewed_by, view_of):
""" """
Generate all valid node order from node_list Generate all valid node order from node_list
and compute its memory peak and compute its memory peak
:param executable_nodes: Set of executable nodes :param executable_nodes: Set of executable nodes
""" """
global mem_count, mem_bound, max_mem_count, viewed_by, view_of global mem_count, mem_bound, max_mem_count
for node in executable_nodes: for node in executable_nodes:
new_exec_nodes = executable_nodes.copy() new_exec_nodes = executable_nodes.copy()
...@@ -796,6 +796,9 @@ class ProfileStats(object): ...@@ -796,6 +796,9 @@ class ProfileStats(object):
if max_mem_count > mem_bound: if max_mem_count > mem_bound:
continue continue
viewed_by_temp = viewed_by.copy()
view_of_temp = view_of.copy()
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 1 compute_map[var][0] = 1
...@@ -814,8 +817,8 @@ class ProfileStats(object): ...@@ -814,8 +817,8 @@ class ProfileStats(object):
# So the output could be different then the input. # So the output could be different then the input.
for ins in node.inputs: for ins in node.inputs:
assert isinstance(ins, theano.Variable) assert isinstance(ins, theano.Variable)
view_of[out] = view_of.get(ins, ins)# This get make that we keep trac of view only again the original view_of_temp[out] = view_of_temp.get(ins, ins)# This get make that we keep trac of view only again the original
viewed_by[ins].append(out) viewed_by_temp[ins].append(out)
else: else:
mem_created += var_mem[out] mem_created += var_mem[out]
idx += 1 idx += 1
...@@ -825,19 +828,19 @@ class ProfileStats(object): ...@@ -825,19 +828,19 @@ class ProfileStats(object):
# Mimic the combination of Theano and Python gc. # Mimic the combination of Theano and Python gc.
for ins in node.inputs: for ins in node.inputs:
assert not (ins in view_of and viewed_by[ins]) assert not (ins in view_of_temp and viewed_by_temp[ins])
# we keep track of the original var, so this shouldn't happen # we keep track of the original var, so this shouldn't happen
if dependencies[ins] and ins not in fgraph.outputs and ins.owner: if dependencies[ins] and ins not in fgraph.outputs and ins.owner:
if all(compute_map[v] for v in dependencies[ins]): if all(compute_map[v] for v in dependencies[ins]):
if ins not in view_of and not viewed_by.get(ins, []): if ins not in view_of_temp and not viewed_by_temp.get(ins, []):
mem_freed += var_mem[ins] mem_freed += var_mem[ins]
elif ins in view_of: elif ins in view_of_temp:
origin = view_of[ins] origin = view_of_temp[ins]
viewed_by[origin].remove(ins) viewed_by_temp[origin].remove(ins)
if not viewed_by[origin] and origin not in fgraph.inputs: if not viewed_by_temp[origin] and origin not in fgraph.inputs:
mem_freed += var_mem[origin] mem_freed += var_mem[origin]
else: else:
# ins is viewed_by something else, so its memory isn't freed # ins is viewed_by_temp something else, so its memory isn't freed
pass pass
mem_count -= mem_freed mem_count -= mem_freed
...@@ -853,7 +856,7 @@ class ProfileStats(object): ...@@ -853,7 +856,7 @@ class ProfileStats(object):
if max_mem_count < mem_bound: if max_mem_count < mem_bound:
mem_bound = max_mem_count mem_bound = max_mem_count
else: else:
for p in min_memory_generator(new_exec_nodes): for p in min_memory_generator(new_exec_nodes, viewed_by_temp, view_of_temp):
yield [node]+p yield [node]+p
# Reset track variables # Reset track variables
...@@ -864,7 +867,7 @@ class ProfileStats(object): ...@@ -864,7 +867,7 @@ class ProfileStats(object):
compute_map[var][0] = 0 compute_map[var][0] = 0
# Loop all valid orders and find min peak(store in mem_bound) # Loop all valid orders and find min peak(store in mem_bound)
for order in min_memory_generator(executable_nodes): for order in min_memory_generator(executable_nodes, viewed_by, view_of):
continue continue
return mem_bound return mem_bound
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论