提交 b823ef7f authored 作者: Frederic's avatar Frederic

Compute total memory size in the storage_map

上级 0e94f1d7
......@@ -185,13 +185,15 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
item for item in node.fgraph.inputs
if not isinstance(item, theano.compile.SharedVariable)]
storage_map_list = []
total_size = 0
total_size_inputs = 0
for k in storage_map.keys():
storage_map_item = []
# storage_map_item[0]
# storage_map_item[0]: the variable
storage_map_item.append(str(k))
# storage_map_item[1]
# storage_map_item[1]: the shape
shapeinfo = None
if hasattr(storage_map[k][0], 'shape'):
shapeinfo = storage_map[k][0].shape
......@@ -202,15 +204,41 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
else:
storage_map_item.append(None)
# storage_map_item[2]
# storage_map_item[3]
# storage_map_item[2]: itemsize
# storage_map_item[3]: bytes
if hasattr(storage_map[k][0], 'dtype'):
dtype = storage_map[k][0].dtype
storage_map_item.append(numpy.dtype(dtype).itemsize)
if shapeinfo is None:
storage_map_item.append(None)
else:
storage_map_item.append(numpy.dtype(dtype).itemsize * numpy.prod(shapeinfo))
sz = numpy.dtype(dtype).itemsize * numpy.prod(shapeinfo)
storage_map_item.append(sz)
total_size += sz
if not k.owner:
total_size_inputs += sz
else:
# If it is a view, don't count it twice.
if getattr(k.owner.op, 'view_map', None):
vmap = k.owner.op.view_map
out_idx = k.owner.outputs.index(k)
data = storage_map[k][0]
if out_idx in vmap:
assert len(vmap[out_idx]) == 1
input_data = storage_map[k.owner.inputs[vmap[out_idx][0]]][0]
if k.type.may_share_memory(data, input_data):
total_size -= sz
# If it is a destroyed input, the input shouldn't be in the storage_map anymore
# except if there is a special flag used. So we still must check it.
if getattr(k.owner.op, 'destroy_map', None):
vmap = k.owner.op.destroy_map
out_idx = k.owner.outputs.index(k)
data = storage_map[k][0]
if out_idx in vmap:
assert len(vmap[out_idx]) == 1
input_data = storage_map[k.owner.inputs[vmap[out_idx][0]]][0]
if k.type.may_share_memory(data, input_data):
total_size -= sz
else:
bytes = getsizeof(storage_map[k][0])
storage_map_item.append(bytes)
......@@ -243,6 +271,10 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
detailed_err_msg += ", TotalSize: %s Byte(s)\n" % storage_map_item[3]
else:
detailed_err_msg += "\n"
detailed_err_msg += " TotalSize: %s Byte(s) %.3f GB\n" % (
total_size, total_size/1024./1024/1024)
detailed_err_msg += " TotalSize inputs: %s Byte(s) %.3f BG\n" % (
total_size_inputs, total_size_inputs/1024./1024/1024)
else:
hints.append(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论