提交 e8097a09 authored 作者: Lijun Xue's avatar Lijun Xue

Fix Error and make some changes

上级 c9c8a418
......@@ -17,6 +17,8 @@ from theano.configparser import (config, AddConfigVar,
import theano.gof.cmodule
from theano.compat.python2x import defaultdict
logger = logging.getLogger(__name__)
AddConfigVar('profile',
......@@ -187,10 +189,7 @@ class Loop(VM):
for ins in node.inputs:
info = self.reallocated_info.get(ins, None)
if info:
for outs in info[1]:
if not outs[0]:
outs[0] = info[1][0]
break
info[1][0] = info[0][0]
except:
link.raise_with_op(node, thunk)
else:
......@@ -202,10 +201,7 @@ class Loop(VM):
for ins in node.inputs:
info = self.reallocated_info.get(ins, None)
if info:
for outs in info[1]:
if not outs[0]:
outs[0] = info[1][0]
break
info[1][0] = info[0][0]
except:
link.raise_with_op(node, thunk)
......@@ -243,10 +239,7 @@ class LoopGC(VM):
for ins in node.inputs:
info = self.reallocated_info.get(ins, None)
if info:
for outs in info[1]:
if not outs[0]:
outs[0] = info[1][0]
break
info[1][0] = info[0][0]
i += 1
except:
link.raise_with_op(node, thunk)
......@@ -262,10 +255,7 @@ class LoopGC(VM):
for ins in node.inputs:
info = self.reallocated_info.get(ins, None)
if info:
for outs in info[1]:
if not outs[0]:
outs[0] = info[1][0]
break
info[1][0] = info[0][0]
except:
link.raise_with_op(node, thunk)
......@@ -933,21 +923,27 @@ class VM_Linker(link.LocalLinker):
thunks = []
# Collect Reallocation Info
compute_map = defaultdict(lambda: [0])
for var in fgraph.inputs:
compute_map[var][0] = 1
viewed_by = {}
for var in fgraph.variables:
viewed_by[var] = []
view_of = {}
reallocated_info = {}
dependencies = getattr(fgraph.profile, 'dependencies', {})
pre_allocated = set([])
for idx in range(len(order)):
node = order[idx]
dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None)
reallocated_ins = set([])
idx_o = 0
for out in node.outputs:
for var in node.outputs:
compute_map[var][0] = 1
ins = None
if dmap and idx_o in dmap:
idx_v = dmap[idx_o]
......@@ -960,7 +956,7 @@ class VM_Linker(link.LocalLinker):
assert len(
idx_v) == 1, "Here we only support the possibility to view one input"
ins = node.inputs[idx_v[0]]
if ins:
if ins is not None:
assert isinstance(ins, theano.Variable)
origin = view_of.get(ins, ins)
view_of[out] = origin
......@@ -971,18 +967,20 @@ class VM_Linker(link.LocalLinker):
assert not (ins in view_of and viewed_by[ins])
if (getattr(ins, 'ndim', None) == 0 and not storage_map[ins][0]
and ins not in fgraph.outputs and ins.owner
and dependencies.get(ins, None)):
and dependencies.get(ins, None)
and all([compute_map[v][0] for v in dependencies[ins]])):
# Constant Memory cannot be changed, Constant storage_map
# has a value here
reuse_outs = []
reuse_out = None
if ins not in view_of and not viewed_by.get(ins, []):
# where gc
for i in range(idx + 1, len(order)):
if reuse_outs:
if reuse_out:
break
for outs in order[i].outputs:
if outs.ndim == 0 and outs not in viewed_by.get(ins, []):
reuse_outs.append(storage_map[outs])
for out in order[i].outputs:
if out.ndim == 0 and out not in pre_allocated:
reuse_out = out
pre_allocated.add(out)
elif ins in view_of:
origin = view_of[ins]
viewed_by[origin].remove(ins)
......@@ -991,16 +989,15 @@ class VM_Linker(link.LocalLinker):
not isinstance(origin, theano.Constant)):
# where gc
for i in range(idx + 1, len(order)):
if reuse_outs:
if reuse_out:
break
for outs in order[i].outputs:
if outs.ndim == 0 and outs not in viewed_by.get(ins, []):
reuse_outs.append(storage_map[outs])
for out in order[i].outputs:
if out.ndim == 0 and out not in pre_allocated:
reuse_out = out
pre_allocated.add(out)
if reuse_outs:
# if reusable output variable exists
reallocated_ins.add(ins)
reallocated_info[ins] = [storage_map[ins], reuse_outs]
reallocated_info[ins] = [storage_map[ins], storage_map[reuse_out]]
for node in order:
try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论