提交 f42af175 authored 作者: Lijun Xue's avatar Lijun Xue

Some fixes

上级 e8097a09
...@@ -923,9 +923,9 @@ class VM_Linker(link.LocalLinker): ...@@ -923,9 +923,9 @@ class VM_Linker(link.LocalLinker):
thunks = [] thunks = []
# Collect Reallocation Info # Collect Reallocation Info
compute_map = defaultdict(lambda: [0]) compute_map_re = defaultdict(lambda: [0])
for var in fgraph.inputs: for var in fgraph.inputs:
compute_map[var][0] = 1 compute_map_re[var][0] = 1
viewed_by = {} viewed_by = {}
for var in fgraph.variables: for var in fgraph.variables:
...@@ -943,7 +943,7 @@ class VM_Linker(link.LocalLinker): ...@@ -943,7 +943,7 @@ class VM_Linker(link.LocalLinker):
idx_o = 0 idx_o = 0
for out in node.outputs: for out in node.outputs:
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 1 compute_map_re[var][0] = 1
ins = None ins = None
if dmap and idx_o in dmap: if dmap and idx_o in dmap:
idx_v = dmap[idx_o] idx_v = dmap[idx_o]
...@@ -968,7 +968,7 @@ class VM_Linker(link.LocalLinker): ...@@ -968,7 +968,7 @@ class VM_Linker(link.LocalLinker):
if (getattr(ins, 'ndim', None) == 0 and not storage_map[ins][0] if (getattr(ins, 'ndim', None) == 0 and not storage_map[ins][0]
and ins not in fgraph.outputs and ins.owner and ins not in fgraph.outputs and ins.owner
and dependencies.get(ins, None) and dependencies.get(ins, None)
and all([compute_map[v][0] for v in dependencies[ins]])): and all([compute_map_re[v][0] for v in dependencies.get(ins, []])])):
# Constant Memory cannot be changed, Constant storage_map # Constant Memory cannot be changed, Constant storage_map
# has a value here # has a value here
reuse_out = None reuse_out = None
...@@ -978,7 +978,7 @@ class VM_Linker(link.LocalLinker): ...@@ -978,7 +978,7 @@ class VM_Linker(link.LocalLinker):
if reuse_out: if reuse_out:
break break
for out in order[i].outputs: for out in order[i].outputs:
if out.ndim == 0 and out not in pre_allocated: if getattr(out, 'ndim', None) == 0 and out not in pre_allocated:
reuse_out = out reuse_out = out
pre_allocated.add(out) pre_allocated.add(out)
elif ins in view_of: elif ins in view_of:
...@@ -992,7 +992,7 @@ class VM_Linker(link.LocalLinker): ...@@ -992,7 +992,7 @@ class VM_Linker(link.LocalLinker):
if reuse_out: if reuse_out:
break break
for out in order[i].outputs: for out in order[i].outputs:
if out.ndim == 0 and out not in pre_allocated: if getattr(out, 'ndim', None) == 0 and out not in pre_allocated:
reuse_out = out reuse_out = out
pre_allocated.add(out) pre_allocated.add(out)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论