提交 954f5fce authored 作者: Lijun Xue's avatar Lijun Xue

wrap code into a function

上级 c41c4307
...@@ -900,74 +900,57 @@ class VM_Linker(link.LocalLinker): ...@@ -900,74 +900,57 @@ class VM_Linker(link.LocalLinker):
for var in fgraph.inputs: for var in fgraph.inputs:
compute_map_re[var][0] = 1 compute_map_re[var][0] = 1
reallocated_info = {}
if getattr(fgraph.profile, 'dependencies', None): if getattr(fgraph.profile, 'dependencies', None):
dependencies = getattr(fgraph.profile, 'dependencies') dependencies = getattr(fgraph.profile, 'dependencies')
else: else:
dependencies = self.compute_gc_dependencies(storage_map) dependencies = self.compute_gc_dependencies(storage_map)
viewed_by = {} def calculate_reallocate_info(order, fgraph, dependencies):
for var in fgraph.variables: reallocated_info = {}
viewed_by[var] = [] viewed_by = {}
view_of = {} for var in fgraph.variables:
pre_allocated = set([]) viewed_by[var] = []
allocated = set([]) view_of = {}
for idx in range(len(order)): pre_allocated = set([])
node = order[idx] allocated = set([])
dmap = getattr(node.op, 'destroy_map', None) for idx in range(len(order)):
vmap = getattr(node.op, 'view_map', None) node = order[idx]
dmap = getattr(node.op, 'destroy_map', None)
idx_o = 0 vmap = getattr(node.op, 'view_map', None)
for out in node.outputs:
for var in node.outputs: idx_o = 0
compute_map_re[var][0] = 1 for out in node.outputs:
ins = None for var in node.outputs:
if dmap and idx_o in dmap: compute_map_re[var][0] = 1
idx_v = dmap[idx_o] ins = None
assert len( if dmap and idx_o in dmap:
idx_v) == 1, "Here we only support the possibility to destroy one input" idx_v = dmap[idx_o]
ins = node.inputs[idx_v[0]] assert len(
if vmap and idx_o in vmap: idx_v) == 1, "Here we only support the possibility to destroy one input"
assert ins is None ins = node.inputs[idx_v[0]]
idx_v = vmap[idx_o] if vmap and idx_o in vmap:
assert len( assert ins is None
idx_v) == 1, "Here we only support the possibility to view one input" idx_v = vmap[idx_o]
ins = node.inputs[idx_v[0]] assert len(
if ins is not None: idx_v) == 1, "Here we only support the possibility to view one input"
assert isinstance(ins, theano.Variable) ins = node.inputs[idx_v[0]]
origin = view_of.get(ins, ins) if ins is not None:
view_of[out] = origin assert isinstance(ins, theano.Variable)
viewed_by[origin].append(out) origin = view_of.get(ins, ins)
idx_o += 1 view_of[out] = origin
viewed_by[origin].append(out)
for ins in node.inputs: idx_o += 1
assert not (ins in view_of and viewed_by[ins])
if (getattr(ins, 'ndim', None) == 0 and not storage_map[ins][0] for ins in node.inputs:
and ins not in fgraph.outputs and ins.owner assert not (ins in view_of and viewed_by[ins])
and all([compute_map_re[v][0] for v in dependencies.get(ins, [])]) if (getattr(ins, 'ndim', None) == 0 and not storage_map[ins][0]
and ins not in allocated): and ins not in fgraph.outputs and ins.owner
# Constant Memory cannot be changed and all([compute_map_re[v][0] for v in dependencies.get(ins, [])])
# Constant and shared variables' storage_map value is not empty and ins not in allocated):
reuse_out = None # Constant Memory cannot be changed
if ins not in view_of and not viewed_by.get(ins, []): # Constant and shared variables' storage_map value is not empty
# where gc reuse_out = None
for i in range(idx + 1, len(order)): if ins not in view_of and not viewed_by.get(ins, []):
if reuse_out:
break
for out in order[i].outputs:
if (getattr(out, 'ndim', None) == 0 and out not in pre_allocated
and ins.type == out.type):
reuse_out = out
pre_allocated.add(out)
allocated.add(ins)
elif ins in view_of:
origin = view_of[ins]
if ins in viewed_by[origin]:
viewed_by[origin].remove(ins)
if (not viewed_by[origin] and
origin not in fgraph.inputs and
not isinstance(origin, theano.Constant)):
# where gc # where gc
for i in range(idx + 1, len(order)): for i in range(idx + 1, len(order)):
if reuse_out: if reuse_out:
...@@ -978,9 +961,30 @@ class VM_Linker(link.LocalLinker): ...@@ -978,9 +961,30 @@ class VM_Linker(link.LocalLinker):
reuse_out = out reuse_out = out
pre_allocated.add(out) pre_allocated.add(out)
allocated.add(ins) allocated.add(ins)
elif ins in view_of:
origin = view_of[ins]
if ins in viewed_by[origin]:
viewed_by[origin].remove(ins)
if (not viewed_by[origin] and
origin not in fgraph.inputs and
not isinstance(origin, theano.Constant)):
# where gc
for i in range(idx + 1, len(order)):
if reuse_out:
break
for out in order[i].outputs:
if (getattr(out, 'ndim', None) == 0 and out not in pre_allocated
and ins.type == out.type):
reuse_out = out
pre_allocated.add(out)
allocated.add(ins)
if reuse_out:
reallocated_info[ins] = [ins, reuse_out]
return reallocated_info
if reuse_out: reallocated_info = calculate_reallocate_info(order, fgraph, dependencies)
reallocated_info[ins] = [ins, reuse_out]
for node in order: for node in order:
try: try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论