提交 ae591039 authored 作者: Frederic's avatar Frederic

Bugfix. We should only check the input that is viewed by a given output, not all inputs.

上级 c1ed99b1
...@@ -699,16 +699,25 @@ class ProfileStats(object): ...@@ -699,16 +699,25 @@ class ProfileStats(object):
# allocated by the node # allocated by the node
idx2 = 0 idx2 = 0
for out in node.outputs: for out in node.outputs:
if (dmap and idx2 in dmap) or (vmap and idx2 in vmap): ins = None
if dmap and idx2 in dmap:
vidx = dmap[idx2]
assert len(vidx) == 1, "Here we only support the possibility to destroy one input"
ins = node.inputs[vidx[0]]
if vmap and idx2 in vmap:
assert ins is None
vidx = vmap[idx2]
assert len(vidx) == 1, "Here we only support the possibility to view one input"
ins = node.inputs[vidx[0]]
if ins is not None:
# This is needed for destroy_map in case it # This is needed for destroy_map in case it
# return a partial view that is destroyed. So # return a partial view that is destroyed. So
# the output could be different then the # the output could be different then the
# input. # input.
for ins in node.inputs: assert isinstance(ins, theano.Variable)
assert isinstance(ins, theano.Variable) # we keep trac of view only again the original
# we keep trac of view only again the original view_of[out] = view_of.get(ins, ins)
view_of[out] = view_of.get(ins, ins) viewed_by[ins].append(out)
viewed_by[ins].append(out)
else: else:
running_memory_size += var_mem[out] running_memory_size += var_mem[out]
node_memory_size += var_mem[out] node_memory_size += var_mem[out]
...@@ -826,16 +835,25 @@ class ProfileStats(object): ...@@ -826,16 +835,25 @@ class ProfileStats(object):
# Update the Python emulating dicts and add the # Update the Python emulating dicts and add the
# memory allocated by the node # memory allocated by the node
for out in node.outputs: for out in node.outputs:
if (dmap and idx in dmap) or (vmap and idx in vmap): ins = None
if dmap and idx in dmap:
vidx = dmap[idx]
assert len(vidx) == 1, "Here we only support the possibility to destroy one input"
ins = node.inputs[vidx[0]]
if vmap and idx in vmap:
assert ins is None, "Here we only support the possibility to view one input"
vidx = vmap[idx]
assert len(vidx) == 1
ins = node.inputs[vidx[0]]
if ins is not None:
# This is needed for destroy_map in case it # This is needed for destroy_map in case it
# return a partial view that is destroyed. So # return a partial view that is destroyed. So
# the output could be different then the # the output could be different then the
# input. # input.
for ins in node.inputs: assert isinstance(ins, theano.Variable)
assert isinstance(ins, theano.Variable) # We keep trac of view only again the original
# We keep trac of view only again the original view_of_temp[out] = view_of_temp.get(ins, ins)
view_of_temp[out] = view_of_temp.get(ins, ins) viewed_by_temp[ins].append(out)
viewed_by_temp[ins].append(out)
else: else:
mem_created += var_mem[out] mem_created += var_mem[out]
idx += 1 idx += 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论