提交 19223fd6,作者:Lijun Xue

reallocation in Loop/LoopGC

上级 3ddc19f5
......@@ -155,7 +155,6 @@ class VM(object):
else:
profile.dependencies = {}
# clear the timer info out of the buffers
for i in xrange(len(self.call_times)):
self.call_times[i] = 0.0
......@@ -169,6 +168,10 @@ class Loop(VM):
No garbage collection is allowed on intermediate results.
"""
def __init__(self, nodes, thunks, pre_call_clear, reallocated_info):
    """Set up a Loop VM that also tracks storage reallocation info.

    :param nodes: apply nodes, forwarded unchanged to ``VM.__init__``.
    :param thunks: thunks, forwarded unchanged to ``VM.__init__``.
    :param pre_call_clear: forwarded unchanged to ``VM.__init__``.
    :param reallocated_info: mapping looked up with
        ``reallocated_info.get(ins, None)`` for every input of a node
        after its thunk has run.
        NOTE(review): the linker appears to store
        ``[storage_map[ins], reuse_outs]`` per variable -- confirm
        against the caller.
    """
    # This method belongs to ``Loop``; naming ``LoopGC`` in ``super()``
    # raises ``TypeError`` because a ``Loop`` instance is not an
    # instance of ``LoopGC``.
    super(Loop, self).__init__(nodes, thunks, pre_call_clear)
    self.reallocated_info = reallocated_info
def __call__(self):
if self.time_thunks:
for cont in self.pre_call_clear:
......@@ -181,6 +184,13 @@ class Loop(VM):
t1 = time.time()
self.call_counts[i] += 1
self.call_times[i] += t1 - t0
for ins in node.inputs:
info = self.reallocated_info.get(ins, None)
if info:
for outs in info[1]:
if not outs[0]:
outs[0] = info[1][0]
break
except:
link.raise_with_op(node, thunk)
else:
......@@ -189,6 +199,13 @@ class Loop(VM):
try:
for thunk, node in zip(self.thunks, self.nodes):
thunk()
for ins in node.inputs:
info = self.reallocated_info.get(ins, None)
if info:
for outs in info[1]:
if not outs[0]:
outs[0] = info[1][0]
break
except:
link.raise_with_op(node, thunk)
......@@ -200,9 +217,10 @@ class LoopGC(VM):
Garbage collection is possible on intermediate results.
"""
def __init__(self, nodes, thunks, pre_call_clear, post_thunk_clear):
def __init__(self, nodes, thunks, pre_call_clear, post_thunk_clear, reallocated_info):
super(LoopGC, self).__init__(nodes, thunks, pre_call_clear)
self.post_thunk_clear = post_thunk_clear
self.reallocated_info = reallocated_info
if not (len(nodes) == len(thunks) == len(post_thunk_clear)):
raise ValueError()
......@@ -222,6 +240,13 @@ class LoopGC(VM):
self.call_times[i] += t1 - t0
for old_s in old_storage:
old_s[0] = None
for ins in node.inputs:
info = self.reallocated_info.get(ins, None)
if info:
for outs in info[1]:
if not outs[0]:
outs[0] = info[1][0]
break
i += 1
except:
link.raise_with_op(node, thunk)
......@@ -234,6 +259,13 @@ class LoopGC(VM):
thunk()
for old_s in old_storage:
old_s[0] = None
for ins in node.inputs:
info = self.reallocated_info.get(ins, None)
if info:
for outs in info[1]:
if not outs[0]:
outs[0] = info[1][0]
break
except:
link.raise_with_op(node, thunk)
......@@ -719,7 +751,8 @@ class VM_Linker(link.LocalLinker):
post_thunk_clear,
computed,
compute_map,
updated_vars
updated_vars,
reallocated_info
):
pre_call_clear = [storage_map[v] for v in self.no_recycling]
......@@ -863,12 +896,16 @@ class VM_Linker(link.LocalLinker):
nodes,
thunks,
pre_call_clear,
post_thunk_clear)
post_thunk_clear,
reallocated_info
)
else:
vm = Loop(
nodes,
thunks,
pre_call_clear)
pre_call_clear,
reallocated_info
)
else:
# Needed when allow_gc=True and profiling
deps = self.compute_gc_dependencies(storage_map)
......@@ -895,70 +932,72 @@ class VM_Linker(link.LocalLinker):
thunks = []
if self.allow_gc:
viewed_by = {}
for var in fgraph.variables:
viewed_by[var] = []
view_of = {}
reallocated_info = {}
for idx in range(len(order)):
node = order[idx]
dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None)
reallocated_ins = set([])
idx_o = 0
for out in node.outputs:
ins = None
if dmap and idx_o in dmap:
idx_v = dmap[idx_o]
assert len(idx_v) == 1, "Here we only support the possibility to destroy one input"
ins = node.inputs[idx_v[0]]
if vmap and idx_o in vmap:
assert ins is None
idx_v = vmap[idx_o]
assert len(idx_v) == 1, "Here we only support the possibility to view one input"
ins = node.inputs[idx_v[0]]
if ins:
assert isinstance(ins, theano.Variable)
origin = view_of.get(ins, ins)
view_of[out] = origin
viewed_by[origin].append(out)
idx_o += 1
for ins in node.inputs:
assert not (ins in view_of and viewed_by[ins])
if (ins.ndim == 0 and not storage_map[ins][0] and ins not in fgraph.outputs and ins.owner):
# Constant Memory cannot be changed, Constant storage_map has a value here
reuse_outs = []
if ins not in view_of and not viewed_by.get(ins, []):
# Collect Reallocation Info
viewed_by = {}
for var in fgraph.variables:
viewed_by[var] = []
view_of = {}
reallocated_info = {}
for idx in range(len(order)):
node = order[idx]
dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None)
reallocated_ins = set([])
idx_o = 0
for out in node.outputs:
ins = None
if dmap and idx_o in dmap:
idx_v = dmap[idx_o]
assert len(
idx_v) == 1, "Here we only support the possibility to destroy one input"
ins = node.inputs[idx_v[0]]
if vmap and idx_o in vmap:
assert ins is None
idx_v = vmap[idx_o]
assert len(
idx_v) == 1, "Here we only support the possibility to view one input"
ins = node.inputs[idx_v[0]]
if ins:
assert isinstance(ins, theano.Variable)
origin = view_of.get(ins, ins)
view_of[out] = origin
viewed_by[origin].append(out)
idx_o += 1
for ins in node.inputs:
assert not (ins in view_of and viewed_by[ins])
if (ins.ndim == 0 and not storage_map[ins][0] and ins not in fgraph.outputs and ins.owner):
# Constant Memory cannot be changed, Constant storage_map
# has a value here
reuse_outs = []
if ins not in view_of and not viewed_by.get(ins, []):
# where gc
for i in range(idx + 1, len(order)):
if reuse_outs:
break
for outs in order[i].outputs:
if outs.ndim == 0 and outs not in viewed_by.get(ins, []):
reuse_outs.append(storage_map[outs])
elif ins in view_of:
origin = view_of[ins]
viewed_by[origin].remove(ins)
if (not viewed_by[origin] and
origin not in fgraph.inputs and
not isinstance(origin, theano.Constant)):
# where gc
for i in range(idx + 1, len(order)):
if reuse_outs:
break
for outs in order[i].outputs:
if outs.ndim == 0 and outs not in viewed_by.get(ins, []):
reuse_outs.append(outs)
elif ins in view_of:
origin = view_of[ins]
viewed_by[origin].remove(ins)
if (not viewed_by[origin] and
origin not in fgraph.inputs and
not isinstance(origin, theano.Constant)):
#where gc
for i in range(idx + 1, len(order)):
if reuse_outs:
break
for outs in order[i].outputs:
if outs.ndim == 0 and outs not in viewed_by.get(ins, []):
reuse_outs.append(outs)
reuse_outs.append(storage_map[outs])
if reuse_outs:
# if reusable output variable exists
reallocated_ins.add(ins)
reallocated_info[ins] = reuse_outs
if reuse_outs:
# if reusable output variable exists
reallocated_ins.add(ins)
reallocated_info[ins] = [storage_map[ins], reuse_outs]
for node in order:
try:
......@@ -987,7 +1026,8 @@ class VM_Linker(link.LocalLinker):
for input in node.inputs:
if ((input in computed)
and (input not in fgraph.outputs)
and (node == last_user[input])):
and (node == last_user[input])
and input not in reallocated_info.keys()):
clear_after_this_thunk.append(storage_map[input])
post_thunk_clear.append(clear_after_this_thunk)
else:
......@@ -998,7 +1038,8 @@ class VM_Linker(link.LocalLinker):
post_thunk_clear,
computed,
compute_map,
self.updated_vars
self.updated_vars,
reallocated_info
)
vm.storage_map = storage_map
......
Markdown 格式
0%
您已添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论