提交 033d98c0 authored 作者: Roy Xue's avatar Roy Xue

viewed_by and view_of reset

上级 fc62ba11
......@@ -731,7 +731,7 @@ class ProfileStats(object):
return [node_memory_size, running_memory_size, running_max_memory_size, node_memory_saved_by_inplace, node_memory_saved_by_view]
def count_minimum_peak(node_list, fgraph, nodes_mem):
global mem_count, mem_bound, max_mem_count, viewed_by, view_of
global mem_count, mem_bound, max_mem_count
node_list = list(node_list)
mem_count = 0
max_mem_count = 0
......@@ -779,14 +779,14 @@ class ProfileStats(object):
view_of = {}# {var1: original var viewed by var1}
# "Original" means that we don't keep track of all the intermediate relationships in the view.
def min_memory_generator(executable_nodes):
def min_memory_generator(executable_nodes, viewed_by, view_of):
"""
Generate all valid node order from node_list
and compute its memory peak
:param executable_nodes: Set of executable nodes
"""
global mem_count, mem_bound, max_mem_count, viewed_by, view_of
global mem_count, mem_bound, max_mem_count
for node in executable_nodes:
new_exec_nodes = executable_nodes.copy()
......@@ -796,6 +796,9 @@ class ProfileStats(object):
if max_mem_count > mem_bound:
continue
viewed_by_temp = viewed_by.copy()
view_of_temp = view_of.copy()
for var in node.outputs:
compute_map[var][0] = 1
......@@ -814,8 +817,8 @@ class ProfileStats(object):
# So the output could be different than the input.
for ins in node.inputs:
assert isinstance(ins, theano.Variable)
view_of[out] = view_of.get(ins, ins)# This get make that we keep trac of view only again the original
viewed_by[ins].append(out)
view_of_temp[out] = view_of_temp.get(ins, ins)# This get make that we keep trac of view only again the original
viewed_by_temp[ins].append(out)
else:
mem_created += var_mem[out]
idx += 1
......@@ -825,19 +828,19 @@ class ProfileStats(object):
# Mimic the combination of Theano and Python gc.
for ins in node.inputs:
assert not (ins in view_of and viewed_by[ins])
assert not (ins in view_of_temp and viewed_by_temp[ins])
# we keep track of the original var, so this shouldn't happen
if dependencies[ins] and ins not in fgraph.outputs and ins.owner:
if all(compute_map[v] for v in dependencies[ins]):
if ins not in view_of and not viewed_by.get(ins, []):
if ins not in view_of_temp and not viewed_by_temp.get(ins, []):
mem_freed += var_mem[ins]
elif ins in view_of:
origin = view_of[ins]
viewed_by[origin].remove(ins)
if not viewed_by[origin] and origin not in fgraph.inputs:
elif ins in view_of_temp:
origin = view_of_temp[ins]
viewed_by_temp[origin].remove(ins)
if not viewed_by_temp[origin] and origin not in fgraph.inputs:
mem_freed += var_mem[origin]
else:
# ins is viewed_by something else, so its memory isn't freed
# ins is still viewed by something else, so its memory isn't freed
pass
mem_count -= mem_freed
......@@ -853,7 +856,7 @@ class ProfileStats(object):
if max_mem_count < mem_bound:
mem_bound = max_mem_count
else:
for p in min_memory_generator(new_exec_nodes):
for p in min_memory_generator(new_exec_nodes, viewed_by_temp, view_of_temp):
yield [node]+p
# Reset track variables
......@@ -864,7 +867,7 @@ class ProfileStats(object):
compute_map[var][0] = 0
# Loop all valid orders and find min peak(store in mem_bound)
for order in min_memory_generator(executable_nodes):
for order in min_memory_generator(executable_nodes, viewed_by, view_of):
continue
return mem_bound
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论