提交 ae591039 authored 作者: Frederic's avatar Frederic

Bugfix. We should only check the input that is viewed by a given output, not all inputs.

上级 c1ed99b1
......@@ -699,16 +699,25 @@ class ProfileStats(object):
# allocated by the node
idx2 = 0
for out in node.outputs:
if (dmap and idx2 in dmap) or (vmap and idx2 in vmap):
ins = None
if dmap and idx2 in dmap:
vidx = dmap[idx2]
assert len(vidx) == 1, "Here we only support the possibility to destroy one input"
ins = node.inputs[vidx[0]]
if vmap and idx2 in vmap:
assert ins is None
vidx = vmap[idx2]
assert len(vidx) == 1, "Here we only support the possibility to view one input"
ins = node.inputs[vidx[0]]
if ins is not None:
# This is needed for destroy_map in case it
# return a partial view that is destroyed. So
# the output could be different then the
# input.
for ins in node.inputs:
assert isinstance(ins, theano.Variable)
# we keep trac of view only again the original
view_of[out] = view_of.get(ins, ins)
viewed_by[ins].append(out)
assert isinstance(ins, theano.Variable)
# we keep trac of view only again the original
view_of[out] = view_of.get(ins, ins)
viewed_by[ins].append(out)
else:
running_memory_size += var_mem[out]
node_memory_size += var_mem[out]
......@@ -826,16 +835,25 @@ class ProfileStats(object):
# Update the Python emulating dicts and add the
# memory allocated by the node
for out in node.outputs:
if (dmap and idx in dmap) or (vmap and idx in vmap):
ins = None
if dmap and idx in dmap:
vidx = dmap[idx]
assert len(vidx) == 1, "Here we only support the possibility to destroy one input"
ins = node.inputs[vidx[0]]
if vmap and idx in vmap:
assert ins is None, "Here we only support the possibility to view one input"
vidx = vmap[idx]
assert len(vidx) == 1
ins = node.inputs[vidx[0]]
if ins is not None:
# This is needed for destroy_map in case it
# return a partial view that is destroyed. So
# the output could be different then the
# input.
for ins in node.inputs:
assert isinstance(ins, theano.Variable)
# We keep trac of view only again the original
view_of_temp[out] = view_of_temp.get(ins, ins)
viewed_by_temp[ins].append(out)
assert isinstance(ins, theano.Variable)
# We keep trac of view only again the original
view_of_temp[out] = view_of_temp.get(ins, ins)
viewed_by_temp[ins].append(out)
else:
mem_created += var_mem[out]
idx += 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论