提交 e8097a09 authored 作者: Lijun Xue's avatar Lijun Xue

Fix Error and make some changes

上级 c9c8a418
...@@ -17,6 +17,8 @@ from theano.configparser import (config, AddConfigVar, ...@@ -17,6 +17,8 @@ from theano.configparser import (config, AddConfigVar,
import theano.gof.cmodule import theano.gof.cmodule
from theano.compat.python2x import defaultdict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
AddConfigVar('profile', AddConfigVar('profile',
...@@ -187,10 +189,7 @@ class Loop(VM): ...@@ -187,10 +189,7 @@ class Loop(VM):
for ins in node.inputs: for ins in node.inputs:
info = self.reallocated_info.get(ins, None) info = self.reallocated_info.get(ins, None)
if info: if info:
for outs in info[1]: info[1][0] = info[0][0]
if not outs[0]:
outs[0] = info[1][0]
break
except: except:
link.raise_with_op(node, thunk) link.raise_with_op(node, thunk)
else: else:
...@@ -202,10 +201,7 @@ class Loop(VM): ...@@ -202,10 +201,7 @@ class Loop(VM):
for ins in node.inputs: for ins in node.inputs:
info = self.reallocated_info.get(ins, None) info = self.reallocated_info.get(ins, None)
if info: if info:
for outs in info[1]: info[1][0] = info[0][0]
if not outs[0]:
outs[0] = info[1][0]
break
except: except:
link.raise_with_op(node, thunk) link.raise_with_op(node, thunk)
...@@ -243,10 +239,7 @@ class LoopGC(VM): ...@@ -243,10 +239,7 @@ class LoopGC(VM):
for ins in node.inputs: for ins in node.inputs:
info = self.reallocated_info.get(ins, None) info = self.reallocated_info.get(ins, None)
if info: if info:
for outs in info[1]: info[1][0] = info[0][0]
if not outs[0]:
outs[0] = info[1][0]
break
i += 1 i += 1
except: except:
link.raise_with_op(node, thunk) link.raise_with_op(node, thunk)
...@@ -262,10 +255,7 @@ class LoopGC(VM): ...@@ -262,10 +255,7 @@ class LoopGC(VM):
for ins in node.inputs: for ins in node.inputs:
info = self.reallocated_info.get(ins, None) info = self.reallocated_info.get(ins, None)
if info: if info:
for outs in info[1]: info[1][0] = info[0][0]
if not outs[0]:
outs[0] = info[1][0]
break
except: except:
link.raise_with_op(node, thunk) link.raise_with_op(node, thunk)
...@@ -933,21 +923,27 @@ class VM_Linker(link.LocalLinker): ...@@ -933,21 +923,27 @@ class VM_Linker(link.LocalLinker):
thunks = [] thunks = []
# Collect Reallocation Info # Collect Reallocation Info
compute_map = defaultdict(lambda: [0])
for var in fgraph.inputs:
compute_map[var][0] = 1
viewed_by = {} viewed_by = {}
for var in fgraph.variables: for var in fgraph.variables:
viewed_by[var] = [] viewed_by[var] = []
view_of = {} view_of = {}
reallocated_info = {} reallocated_info = {}
dependencies = getattr(fgraph.profile, 'dependencies', {}) dependencies = getattr(fgraph.profile, 'dependencies', {})
pre_allocated = set([])
for idx in range(len(order)): for idx in range(len(order)):
node = order[idx] node = order[idx]
dmap = getattr(node.op, 'destroy_map', None) dmap = getattr(node.op, 'destroy_map', None)
vmap = getattr(node.op, 'view_map', None) vmap = getattr(node.op, 'view_map', None)
reallocated_ins = set([])
idx_o = 0 idx_o = 0
for out in node.outputs: for out in node.outputs:
for var in node.outputs:
compute_map[var][0] = 1
ins = None ins = None
if dmap and idx_o in dmap: if dmap and idx_o in dmap:
idx_v = dmap[idx_o] idx_v = dmap[idx_o]
...@@ -960,7 +956,7 @@ class VM_Linker(link.LocalLinker): ...@@ -960,7 +956,7 @@ class VM_Linker(link.LocalLinker):
assert len( assert len(
idx_v) == 1, "Here we only support the possibility to view one input" idx_v) == 1, "Here we only support the possibility to view one input"
ins = node.inputs[idx_v[0]] ins = node.inputs[idx_v[0]]
if ins: if ins is not None:
assert isinstance(ins, theano.Variable) assert isinstance(ins, theano.Variable)
origin = view_of.get(ins, ins) origin = view_of.get(ins, ins)
view_of[out] = origin view_of[out] = origin
...@@ -971,18 +967,20 @@ class VM_Linker(link.LocalLinker): ...@@ -971,18 +967,20 @@ class VM_Linker(link.LocalLinker):
assert not (ins in view_of and viewed_by[ins]) assert not (ins in view_of and viewed_by[ins])
if (getattr(ins, 'ndim', None) == 0 and not storage_map[ins][0] if (getattr(ins, 'ndim', None) == 0 and not storage_map[ins][0]
and ins not in fgraph.outputs and ins.owner and ins not in fgraph.outputs and ins.owner
and dependencies.get(ins, None)): and dependencies.get(ins, None)
and all([compute_map[v][0] for v in dependencies[ins]])):
# Constant Memory cannot be changed, Constant storage_map # Constant Memory cannot be changed, Constant storage_map
# has a value here # has a value here
reuse_outs = [] reuse_out = None
if ins not in view_of and not viewed_by.get(ins, []): if ins not in view_of and not viewed_by.get(ins, []):
# where gc # where gc
for i in range(idx + 1, len(order)): for i in range(idx + 1, len(order)):
if reuse_outs: if reuse_out:
break break
for outs in order[i].outputs: for out in order[i].outputs:
if outs.ndim == 0 and outs not in viewed_by.get(ins, []): if out.ndim == 0 and out not in pre_allocated:
reuse_outs.append(storage_map[outs]) reuse_out = out
pre_allocated.add(out)
elif ins in view_of: elif ins in view_of:
origin = view_of[ins] origin = view_of[ins]
viewed_by[origin].remove(ins) viewed_by[origin].remove(ins)
...@@ -991,16 +989,15 @@ class VM_Linker(link.LocalLinker): ...@@ -991,16 +989,15 @@ class VM_Linker(link.LocalLinker):
not isinstance(origin, theano.Constant)): not isinstance(origin, theano.Constant)):
# where gc # where gc
for i in range(idx + 1, len(order)): for i in range(idx + 1, len(order)):
if reuse_outs: if reuse_out:
break break
for outs in order[i].outputs: for out in order[i].outputs:
if outs.ndim == 0 and outs not in viewed_by.get(ins, []): if out.ndim == 0 and out not in pre_allocated:
reuse_outs.append(storage_map[outs]) reuse_out = out
pre_allocated.add(out)
if reuse_outs: if reuse_outs:
# if reusable output variable exists reallocated_info[ins] = [storage_map[ins], storage_map[reuse_out]]
reallocated_ins.add(ins)
reallocated_info[ins] = [storage_map[ins], reuse_outs]
for node in order: for node in order:
try: try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论