提交 19223fd6 authored 作者: Lijun Xue's avatar Lijun Xue

reallocation in Loop/LoopGC

上级 3ddc19f5
...@@ -155,7 +155,6 @@ class VM(object): ...@@ -155,7 +155,6 @@ class VM(object):
else: else:
profile.dependencies = {} profile.dependencies = {}
# clear the timer info out of the buffers # clear the timer info out of the buffers
for i in xrange(len(self.call_times)): for i in xrange(len(self.call_times)):
self.call_times[i] = 0.0 self.call_times[i] = 0.0
...@@ -169,6 +168,10 @@ class Loop(VM): ...@@ -169,6 +168,10 @@ class Loop(VM):
No garbage collection is allowed on intermediate results. No garbage collection is allowed on intermediate results.
""" """
def __init__(self, nodes, thunks, pre_call_clear, reallocated_info):
    """Set up the loop VM and remember the storage-reallocation table.

    nodes, thunks and pre_call_clear are forwarded unchanged to
    VM.__init__.

    reallocated_info is looked up with ``.get(variable)`` after each thunk
    runs; the linker fills each entry as
    ``[storage_map[variable], reusable_output_storages]``.
    """
    # This is Loop's constructor, so super() must be given Loop.  Passing
    # LoopGC (a separate subclass of VM) makes the two-argument super()
    # raise TypeError for every Loop instance.
    super(Loop, self).__init__(nodes, thunks, pre_call_clear)
    self.reallocated_info = reallocated_info
def __call__(self): def __call__(self):
if self.time_thunks: if self.time_thunks:
for cont in self.pre_call_clear: for cont in self.pre_call_clear:
...@@ -181,6 +184,13 @@ class Loop(VM): ...@@ -181,6 +184,13 @@ class Loop(VM):
t1 = time.time() t1 = time.time()
self.call_counts[i] += 1 self.call_counts[i] += 1
self.call_times[i] += t1 - t0 self.call_times[i] += t1 - t0
for ins in node.inputs:
info = self.reallocated_info.get(ins, None)
if info:
for outs in info[1]:
if not outs[0]:
outs[0] = info[1][0]
break
except: except:
link.raise_with_op(node, thunk) link.raise_with_op(node, thunk)
else: else:
...@@ -189,6 +199,13 @@ class Loop(VM): ...@@ -189,6 +199,13 @@ class Loop(VM):
try: try:
for thunk, node in zip(self.thunks, self.nodes): for thunk, node in zip(self.thunks, self.nodes):
thunk() thunk()
for ins in node.inputs:
info = self.reallocated_info.get(ins, None)
if info:
for outs in info[1]:
if not outs[0]:
outs[0] = info[1][0]
break
except: except:
link.raise_with_op(node, thunk) link.raise_with_op(node, thunk)
...@@ -200,9 +217,10 @@ class LoopGC(VM): ...@@ -200,9 +217,10 @@ class LoopGC(VM):
Garbage collection is possible on intermediate results. Garbage collection is possible on intermediate results.
""" """
def __init__(self, nodes, thunks, pre_call_clear, post_thunk_clear): def __init__(self, nodes, thunks, pre_call_clear, post_thunk_clear, reallocated_info):
super(LoopGC, self).__init__(nodes, thunks, pre_call_clear) super(LoopGC, self).__init__(nodes, thunks, pre_call_clear)
self.post_thunk_clear = post_thunk_clear self.post_thunk_clear = post_thunk_clear
self.reallocated_info = reallocated_info
if not (len(nodes) == len(thunks) == len(post_thunk_clear)): if not (len(nodes) == len(thunks) == len(post_thunk_clear)):
raise ValueError() raise ValueError()
...@@ -222,6 +240,13 @@ class LoopGC(VM): ...@@ -222,6 +240,13 @@ class LoopGC(VM):
self.call_times[i] += t1 - t0 self.call_times[i] += t1 - t0
for old_s in old_storage: for old_s in old_storage:
old_s[0] = None old_s[0] = None
for ins in node.inputs:
info = self.reallocated_info.get(ins, None)
if info:
for outs in info[1]:
if not outs[0]:
outs[0] = info[1][0]
break
i += 1 i += 1
except: except:
link.raise_with_op(node, thunk) link.raise_with_op(node, thunk)
...@@ -234,6 +259,13 @@ class LoopGC(VM): ...@@ -234,6 +259,13 @@ class LoopGC(VM):
thunk() thunk()
for old_s in old_storage: for old_s in old_storage:
old_s[0] = None old_s[0] = None
for ins in node.inputs:
info = self.reallocated_info.get(ins, None)
if info:
for outs in info[1]:
if not outs[0]:
outs[0] = info[1][0]
break
except: except:
link.raise_with_op(node, thunk) link.raise_with_op(node, thunk)
...@@ -719,7 +751,8 @@ class VM_Linker(link.LocalLinker): ...@@ -719,7 +751,8 @@ class VM_Linker(link.LocalLinker):
post_thunk_clear, post_thunk_clear,
computed, computed,
compute_map, compute_map,
updated_vars updated_vars,
reallocated_info
): ):
pre_call_clear = [storage_map[v] for v in self.no_recycling] pre_call_clear = [storage_map[v] for v in self.no_recycling]
...@@ -863,12 +896,16 @@ class VM_Linker(link.LocalLinker): ...@@ -863,12 +896,16 @@ class VM_Linker(link.LocalLinker):
nodes, nodes,
thunks, thunks,
pre_call_clear, pre_call_clear,
post_thunk_clear) post_thunk_clear,
reallocated_info
)
else: else:
vm = Loop( vm = Loop(
nodes, nodes,
thunks, thunks,
pre_call_clear) pre_call_clear,
reallocated_info
)
else: else:
# Needed when allow_gc=True and profiling # Needed when allow_gc=True and profiling
deps = self.compute_gc_dependencies(storage_map) deps = self.compute_gc_dependencies(storage_map)
...@@ -895,8 +932,7 @@ class VM_Linker(link.LocalLinker): ...@@ -895,8 +932,7 @@ class VM_Linker(link.LocalLinker):
thunks = [] thunks = []
if self.allow_gc: # Collect Reallocation Info
viewed_by = {} viewed_by = {}
for var in fgraph.variables: for var in fgraph.variables:
viewed_by[var] = [] viewed_by[var] = []
...@@ -914,12 +950,14 @@ class VM_Linker(link.LocalLinker): ...@@ -914,12 +950,14 @@ class VM_Linker(link.LocalLinker):
ins = None ins = None
if dmap and idx_o in dmap: if dmap and idx_o in dmap:
idx_v = dmap[idx_o] idx_v = dmap[idx_o]
assert len(idx_v) == 1, "Here we only support the possibility to destroy one input" assert len(
idx_v) == 1, "Here we only support the possibility to destroy one input"
ins = node.inputs[idx_v[0]] ins = node.inputs[idx_v[0]]
if vmap and idx_o in vmap: if vmap and idx_o in vmap:
assert ins is None assert ins is None
idx_v = vmap[idx_o] idx_v = vmap[idx_o]
assert len(idx_v) == 1, "Here we only support the possibility to view one input" assert len(
idx_v) == 1, "Here we only support the possibility to view one input"
ins = node.inputs[idx_v[0]] ins = node.inputs[idx_v[0]]
if ins: if ins:
assert isinstance(ins, theano.Variable) assert isinstance(ins, theano.Variable)
...@@ -931,7 +969,8 @@ class VM_Linker(link.LocalLinker): ...@@ -931,7 +969,8 @@ class VM_Linker(link.LocalLinker):
for ins in node.inputs: for ins in node.inputs:
assert not (ins in view_of and viewed_by[ins]) assert not (ins in view_of and viewed_by[ins])
if (ins.ndim == 0 and not storage_map[ins][0] and ins not in fgraph.outputs and ins.owner): if (ins.ndim == 0 and not storage_map[ins][0] and ins not in fgraph.outputs and ins.owner):
# Constant Memory cannot be changed, Constant storage_map has a value here # Constant Memory cannot be changed, Constant storage_map
# has a value here
reuse_outs = [] reuse_outs = []
if ins not in view_of and not viewed_by.get(ins, []): if ins not in view_of and not viewed_by.get(ins, []):
# where gc # where gc
...@@ -940,25 +979,25 @@ class VM_Linker(link.LocalLinker): ...@@ -940,25 +979,25 @@ class VM_Linker(link.LocalLinker):
break break
for outs in order[i].outputs: for outs in order[i].outputs:
if outs.ndim == 0 and outs not in viewed_by.get(ins, []): if outs.ndim == 0 and outs not in viewed_by.get(ins, []):
reuse_outs.append(outs) reuse_outs.append(storage_map[outs])
elif ins in view_of: elif ins in view_of:
origin = view_of[ins] origin = view_of[ins]
viewed_by[origin].remove(ins) viewed_by[origin].remove(ins)
if (not viewed_by[origin] and if (not viewed_by[origin] and
origin not in fgraph.inputs and origin not in fgraph.inputs and
not isinstance(origin, theano.Constant)): not isinstance(origin, theano.Constant)):
#where gc # where gc
for i in range(idx + 1, len(order)): for i in range(idx + 1, len(order)):
if reuse_outs: if reuse_outs:
break break
for outs in order[i].outputs: for outs in order[i].outputs:
if outs.ndim == 0 and outs not in viewed_by.get(ins, []): if outs.ndim == 0 and outs not in viewed_by.get(ins, []):
reuse_outs.append(outs) reuse_outs.append(storage_map[outs])
if reuse_outs: if reuse_outs:
# if reusable output variable exists # if reusable output variable exists
reallocated_ins.add(ins) reallocated_ins.add(ins)
reallocated_info[ins] = reuse_outs reallocated_info[ins] = [storage_map[ins], reuse_outs]
for node in order: for node in order:
try: try:
...@@ -987,7 +1026,8 @@ class VM_Linker(link.LocalLinker): ...@@ -987,7 +1026,8 @@ class VM_Linker(link.LocalLinker):
for input in node.inputs: for input in node.inputs:
if ((input in computed) if ((input in computed)
and (input not in fgraph.outputs) and (input not in fgraph.outputs)
and (node == last_user[input])): and (node == last_user[input])
and input not in reallocated_info.keys()):
clear_after_this_thunk.append(storage_map[input]) clear_after_this_thunk.append(storage_map[input])
post_thunk_clear.append(clear_after_this_thunk) post_thunk_clear.append(clear_after_this_thunk)
else: else:
...@@ -998,7 +1038,8 @@ class VM_Linker(link.LocalLinker): ...@@ -998,7 +1038,8 @@ class VM_Linker(link.LocalLinker):
post_thunk_clear, post_thunk_clear,
computed, computed,
compute_map, compute_map,
self.updated_vars self.updated_vars,
reallocated_info
) )
vm.storage_map = storage_map vm.storage_map = storage_map
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论