提交 f50aa790 authored 作者: Lijun Xue's avatar Lijun Xue

Update Logic

上级 255792af
...@@ -170,10 +170,6 @@ class Loop(VM): ...@@ -170,10 +170,6 @@ class Loop(VM):
No garbage collection is allowed on intermediate results. No garbage collection is allowed on intermediate results.
""" """
def __init__(self, nodes, thunks, pre_call_clear, reallocated_info):
super(Loop, self).__init__(nodes, thunks, pre_call_clear)
self.reallocated_info = reallocated_info
def __call__(self): def __call__(self):
if self.time_thunks: if self.time_thunks:
for cont in self.pre_call_clear: for cont in self.pre_call_clear:
...@@ -186,10 +182,6 @@ class Loop(VM): ...@@ -186,10 +182,6 @@ class Loop(VM):
t1 = time.time() t1 = time.time()
self.call_counts[i] += 1 self.call_counts[i] += 1
self.call_times[i] += t1 - t0 self.call_times[i] += t1 - t0
for ins in node.inputs:
info = self.reallocated_info.get(ins, None)
if info:
info[1][0] = info[0][0]
except: except:
link.raise_with_op(node, thunk) link.raise_with_op(node, thunk)
else: else:
...@@ -198,10 +190,6 @@ class Loop(VM): ...@@ -198,10 +190,6 @@ class Loop(VM):
try: try:
for thunk, node in zip(self.thunks, self.nodes): for thunk, node in zip(self.thunks, self.nodes):
thunk() thunk()
for ins in node.inputs:
info = self.reallocated_info.get(ins, None)
if info:
info[1][0] = info[0][0]
except: except:
link.raise_with_op(node, thunk) link.raise_with_op(node, thunk)
...@@ -213,10 +201,9 @@ class LoopGC(VM): ...@@ -213,10 +201,9 @@ class LoopGC(VM):
Garbage collection is possible on intermediate results. Garbage collection is possible on intermediate results.
""" """
def __init__(self, nodes, thunks, pre_call_clear, post_thunk_clear, reallocated_info): def __init__(self, nodes, thunks, pre_call_clear, post_thunk_clear):
super(LoopGC, self).__init__(nodes, thunks, pre_call_clear) super(LoopGC, self).__init__(nodes, thunks, pre_call_clear)
self.post_thunk_clear = post_thunk_clear self.post_thunk_clear = post_thunk_clear
self.reallocated_info = reallocated_info
if not (len(nodes) == len(thunks) == len(post_thunk_clear)): if not (len(nodes) == len(thunks) == len(post_thunk_clear)):
raise ValueError() raise ValueError()
...@@ -236,10 +223,6 @@ class LoopGC(VM): ...@@ -236,10 +223,6 @@ class LoopGC(VM):
self.call_times[i] += t1 - t0 self.call_times[i] += t1 - t0
for old_s in old_storage: for old_s in old_storage:
old_s[0] = None old_s[0] = None
for ins in node.inputs:
info = self.reallocated_info.get(ins, None)
if info:
info[1][0] = info[0][0]
i += 1 i += 1
except: except:
link.raise_with_op(node, thunk) link.raise_with_op(node, thunk)
...@@ -252,10 +235,6 @@ class LoopGC(VM): ...@@ -252,10 +235,6 @@ class LoopGC(VM):
thunk() thunk()
for old_s in old_storage: for old_s in old_storage:
old_s[0] = None old_s[0] = None
for ins in node.inputs:
info = self.reallocated_info.get(ins, None)
if info:
info[1][0] = info[0][0]
except: except:
link.raise_with_op(node, thunk) link.raise_with_op(node, thunk)
...@@ -742,7 +721,6 @@ class VM_Linker(link.LocalLinker): ...@@ -742,7 +721,6 @@ class VM_Linker(link.LocalLinker):
computed, computed,
compute_map, compute_map,
updated_vars, updated_vars,
reallocated_info
): ):
pre_call_clear = [storage_map[v] for v in self.no_recycling] pre_call_clear = [storage_map[v] for v in self.no_recycling]
...@@ -887,14 +865,12 @@ class VM_Linker(link.LocalLinker): ...@@ -887,14 +865,12 @@ class VM_Linker(link.LocalLinker):
thunks, thunks,
pre_call_clear, pre_call_clear,
post_thunk_clear, post_thunk_clear,
reallocated_info
) )
else: else:
vm = Loop( vm = Loop(
nodes, nodes,
thunks, thunks,
pre_call_clear, pre_call_clear,
reallocated_info
) )
else: else:
# Needed when allow_gc=True and profiling # Needed when allow_gc=True and profiling
...@@ -997,10 +973,10 @@ class VM_Linker(link.LocalLinker): ...@@ -997,10 +973,10 @@ class VM_Linker(link.LocalLinker):
pre_allocated.add(out) pre_allocated.add(out)
if reuse_out: if reuse_out:
reallocated_info[ins] = [storage_map[ins], storage_map[reuse_out]] reallocated_info[ins] = [ins, reuse_out]
for pair in reallocated_info.values(): for pair in reallocated_info.values():
pair[1] = pair[0] storage_map[pair[1]] = storage_map[pair[0]]
for node in order: for node in order:
try: try:
...@@ -1042,7 +1018,6 @@ class VM_Linker(link.LocalLinker): ...@@ -1042,7 +1018,6 @@ class VM_Linker(link.LocalLinker):
computed, computed,
compute_map, compute_map,
self.updated_vars, self.updated_vars,
reallocated_info
) )
vm.storage_map = storage_map vm.storage_map = storage_map
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论