提交 ae591039 authored 作者: Frederic's avatar Frederic

Bugfix. We should only check the input that is viewed by a given output, not all inputs.

上级 c1ed99b1
...@@ -699,12 +699,21 @@ class ProfileStats(object): ...@@ -699,12 +699,21 @@ class ProfileStats(object):
# allocated by the node # allocated by the node
idx2 = 0 idx2 = 0
for out in node.outputs: for out in node.outputs:
if (dmap and idx2 in dmap) or (vmap and idx2 in vmap): ins = None
if dmap and idx2 in dmap:
vidx = dmap[idx2]
assert len(vidx) == 1, "Here we only support the possibility to destroy one input"
ins = node.inputs[vidx[0]]
if vmap and idx2 in vmap:
assert ins is None
vidx = vmap[idx2]
assert len(vidx) == 1, "Here we only support the possibility to view one input"
ins = node.inputs[vidx[0]]
if ins is not None:
# This is needed for destroy_map in case it # This is needed for destroy_map in case it
# return a partial view that is destroyed. So # return a partial view that is destroyed. So
# the output could be different then the # the output could be different then the
# input. # input.
for ins in node.inputs:
assert isinstance(ins, theano.Variable) assert isinstance(ins, theano.Variable)
# we keep trac of view only again the original # we keep trac of view only again the original
view_of[out] = view_of.get(ins, ins) view_of[out] = view_of.get(ins, ins)
...@@ -826,12 +835,21 @@ class ProfileStats(object): ...@@ -826,12 +835,21 @@ class ProfileStats(object):
# Update the Python emulating dicts and add the # Update the Python emulating dicts and add the
# memory allocated by the node # memory allocated by the node
for out in node.outputs: for out in node.outputs:
if (dmap and idx in dmap) or (vmap and idx in vmap): ins = None
if dmap and idx in dmap:
vidx = dmap[idx]
assert len(vidx) == 1, "Here we only support the possibility to destroy one input"
ins = node.inputs[vidx[0]]
if vmap and idx in vmap:
assert ins is None, "Here we only support the possibility to view one input"
vidx = vmap[idx]
assert len(vidx) == 1
ins = node.inputs[vidx[0]]
if ins is not None:
# This is needed for destroy_map in case it # This is needed for destroy_map in case it
# return a partial view that is destroyed. So # return a partial view that is destroyed. So
# the output could be different then the # the output could be different then the
# input. # input.
for ins in node.inputs:
assert isinstance(ins, theano.Variable) assert isinstance(ins, theano.Variable)
# We keep trac of view only again the original # We keep trac of view only again the original
view_of_temp[out] = view_of_temp.get(ins, ins) view_of_temp[out] = view_of_temp.get(ins, ins)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论