提交 7c7ce9a2 authored 作者: Roy Xue's avatar Roy Xue

View_of new method

上级 66aa3617
......@@ -815,10 +815,11 @@ class ProfileStats(object):
if max_mem_count > mem_bound:
continue
view_of_temp = view_of.copy()
# view_of_temp = view_of.copy()
viewof_change = []
change_track_add = defaultdict(lambda: [])
change_track_remove = defaultdict(lambda: [])
viewedby_add = defaultdict(lambda: [])
viewedby_remove = defaultdict(lambda: [])
# Use to track viewed_by changes
for var in node.outputs:
......@@ -852,10 +853,11 @@ class ProfileStats(object):
# input.
assert isinstance(ins, theano.Variable)
# We keep trac of view only again the original
origin = view_of_temp.get(ins, ins)
view_of_temp[out] = origin
origin = view_of.get(ins, ins)
view_of[out] = origin
viewof_change.append(out)
viewed_by[origin].append(out)
change_track_add[origin].append(out)
viewedby_add[origin].append(out)
else:
mem_created += var_mem[out]
idx += 1
......@@ -865,19 +867,19 @@ class ProfileStats(object):
# Mimic the combination of Theano and Python gc.
for ins in node.inputs:
assert not (ins in view_of_temp and
assert not (ins in view_of and
viewed_by[ins])
# We track of the original var, so this shouldn't happen
if (dependencies[ins] and
ins not in fgraph.outputs and
ins.owner and
all([compute_map[v][0] for v in dependencies[ins]])):
if ins not in view_of_temp and not viewed_by.get(ins, []):
if ins not in view_of and not viewed_by.get(ins, []):
mem_freed += var_mem[ins]
elif ins in view_of_temp:
origin = view_of_temp[ins]
elif ins in view_of:
origin = view_of[ins]
viewed_by[origin].remove(ins)
change_track_remove[origin].append(ins)
viewedby_remove[origin].append(ins)
if (not viewed_by[origin] and
origin not in fgraph.inputs and
not isinstance(origin, theano.Constant)):
......@@ -899,7 +901,7 @@ class ProfileStats(object):
if max_mem_count < mem_bound:
mem_bound = max_mem_count
else:
min_memory_generator(new_exec_nodes, viewed_by, view_of_temp)
min_memory_generator(new_exec_nodes, viewed_by, view_of)
# Reset track variables
mem_count -= mem_created
......@@ -908,14 +910,17 @@ class ProfileStats(object):
for var in node.outputs:
compute_map[var][0] = 0
for k_remove, v_remove in change_track_remove.iteritems():
for k_remove, v_remove in viewedby_remove.iteritems():
for i in v_remove:
viewed_by[k_remove].append(i)
for k_add, v_add in change_track_add.iteritems():
for k_add, v_add in viewedby_add.iteritems():
for i in v_add:
viewed_by[k_add].remove(i)
for k in viewof_change:
del view_of[k]
# two data structure used to mimic Python gc
viewed_by = {} # {var1: [vars that view var1]}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论