提交 3f2c9223 authored 作者: Roy Xue's avatar Roy Xue 提交者: Lijun Xue

add basic logic.

上级 ae450fb0
...@@ -892,18 +892,60 @@ class VM_Linker(link.LocalLinker): ...@@ -892,18 +892,60 @@ class VM_Linker(link.LocalLinker):
thunks = [] thunks = []
if self.allow_gc: if self.allow_gc:
viewed_by = {}
for var in fgraph.variables:
viewed_by[var] = []
view_of = {}
dependencies = fgraph.profile.dependencies
for idx in range(len(order)): for idx in range(len(order)):
node = order[idx] node = order[idx]
dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None)
reallocated_ins = set([])
idx_o = 0
for out in node.output:
if dmap and idx_o in dmap:
idx_v = dmap[idx_o]
assert len(idx_v) == 1, "Here we only support the possibility to destroy one input"
ins = node.inputs[idx_v[0]]
if vmap and idx_o in vmap:
assert ins is None
idx_v = vmap[idx_o]
assert len(idx_v) == 1, "Here we only support the possibility to view one input"
ins = node.inputs[idx_v[0]]
if ins:
assert isinstance(ins, theano.Variable)
origin = view_of.get(ins, ins)
view_of[out] = origin
viewed_by[origin].append(out)
idx_o += 1
for ins in node.inputs: for ins in node.inputs:
if ins.ndim == 0 and storage_map[ins][0]: assert not (ins in view_of and viewed_by[ins])
# check if input variable ndim = 0 if (ins.ndim == 0 and storage_map[ins][0] and ins not in fgraph.outputs and ins.owner and dependencies[ins]):
reuse_outs = []
if ins not in view_of and not viewed_by.get(ins, []):
# where gc
for i in range(idx + 1, len(order)): for i in range(idx + 1, len(order)):
for outs in order[i].outputs: for outs in order[i].outputs:
if outs.ndim == 0: if outs.ndim == 0:
storage_map[outs] = storage_map[ins] reuse_list.append(outs)
break reallocated_ins.add(ins)
else:continue elif ins in view_of:
break origin = view_of[ins]
viewed_by[origin].remove(ins)
if (not viewed_by[origin] and
origin not in fgraph.inputs and
not isinstance(origin, theano.Constant)):
#where gc
for i in range(idx + 1, len(order)):
for outs in order[i].outputs:
if outs.ndim == 0:
reuse_list.append(outs)
reallocated_ins.add(ins)
for node in order: for node in order:
try: try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论