提交 3ddc19f5 authored 作者: Lijun Xue's avatar Lijun Xue

dependencies check and also speed up

上级 9f85a80e
......@@ -150,7 +150,11 @@ class VM(object):
profile.node_cleared_order = self.node_cleared_order[:]
if hasattr(self, 'dependencies'):
if self.dependencies:
profile.dependencies = self.dependencies.copy()
else:
profile.dependencies = {}
# clear the timer info out of the buffers
for i in xrange(len(self.call_times)):
......@@ -897,6 +901,7 @@ class VM_Linker(link.LocalLinker):
for var in fgraph.variables:
viewed_by[var] = []
view_of = {}
reallocated_info = {}
for idx in range(len(order)):
node = order[idx]
......@@ -925,13 +930,16 @@ class VM_Linker(link.LocalLinker):
for ins in node.inputs:
assert not (ins in view_of and viewed_by[ins])
if (ins.ndim == 0 and storage_map[ins][0] and ins not in fgraph.outputs and ins.owner):
if (ins.ndim == 0 and not storage_map[ins][0] and ins not in fgraph.outputs and ins.owner):
# Constant Memory cannot be changed, Constant storage_map has a value here
reuse_outs = []
if ins not in view_of and not viewed_by.get(ins, []):
# where gc
for i in range(idx + 1, len(order)):
if reuse_outs:
break
for outs in order[i].outputs:
if outs.ndim == 0 and out not in viewed_by.values():
if outs.ndim == 0 and outs not in viewed_by.get(ins, []):
reuse_outs.append(outs)
elif ins in view_of:
origin = view_of[ins]
......@@ -941,13 +949,16 @@ class VM_Linker(link.LocalLinker):
not isinstance(origin, theano.Constant)):
#where gc
for i in range(idx + 1, len(order)):
if reuse_outs:
break
for outs in order[i].outputs:
if outs.ndim == 0 and out not in viewed_by.values():
if outs.ndim == 0 and outs not in viewed_by.get(ins, []):
reuse_outs.append(outs)
if reuse_outs:
# if reusable output variable exists
reallocated_ins.add(ins)
reallocated_info[ins] = reuse_outs
for node in order:
try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论