提交 d9c3b034 authored 作者: Roy Xue's avatar Roy Xue

Add ins type assertion

上级 01f1a753
...@@ -703,7 +703,7 @@ class ProfileStats(object): ...@@ -703,7 +703,7 @@ class ProfileStats(object):
# This is needed for destroy_map in case it return a partial view that is destroyed. # This is needed for destroy_map in case it return a partial view that is destroyed.
# So the output could be different then the input. # So the output could be different then the input.
for ins in node.inputs: for ins in node.inputs:
# assert len[ins] == 1 assert isinstance(ins, theano.Variable)
view_of[out] = view_of.get(ins, ins)# This get make that we keep trac of view only again the original view_of[out] = view_of.get(ins, ins)# This get make that we keep trac of view only again the original
viewed_by[ins].append(out) viewed_by[ins].append(out)
else: else:
...@@ -717,7 +717,7 @@ class ProfileStats(object): ...@@ -717,7 +717,7 @@ class ProfileStats(object):
for ins in node.inputs: for ins in node.inputs:
assert not (ins in view_of and ins in viewed_by) assert not (ins in view_of and ins in viewed_by)
# we keep trac of the original var, so this shouldn't happen # we keep trac of the original var, so this shouldn't happen
if dependencies[ins] and ins not in fgraph.outputs: if dependencies[ins] and ins not in fgraph.outputs and ins.owner:
if ins not in view_of and not viewed_by.get(ins, []): if ins not in view_of and not viewed_by.get(ins, []):
running_memory_size -= var_mem[ins] running_memory_size -= var_mem[ins]
elif ins in view_of: elif ins in view_of:
...@@ -800,8 +800,6 @@ class ProfileStats(object): ...@@ -800,8 +800,6 @@ class ProfileStats(object):
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 1 compute_map[var][0] = 1
print view_of, viewed_by
mem_created = 0 mem_created = 0
mem_freed = 0 mem_freed = 0
max_storage = max_mem_count max_storage = max_mem_count
...@@ -816,6 +814,7 @@ class ProfileStats(object): ...@@ -816,6 +814,7 @@ class ProfileStats(object):
# This is needed for destroy_map in case it return a partial view that is destroyed. # This is needed for destroy_map in case it return a partial view that is destroyed.
# So the output could be different then the input. # So the output could be different then the input.
for ins in node.inputs: for ins in node.inputs:
assert isinstance(ins, theano.Variable)
view_of[out] = view_of.get(ins, ins)# This get make that we keep trac of view only again the original view_of[out] = view_of.get(ins, ins)# This get make that we keep trac of view only again the original
viewed_by[ins].append(out) viewed_by[ins].append(out)
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论