提交 b823ef7f authored 作者: Frederic's avatar Frederic

Compute total memory size in the storage_map

上级 0e94f1d7
...@@ -185,13 +185,15 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None): ...@@ -185,13 +185,15 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
item for item in node.fgraph.inputs item for item in node.fgraph.inputs
if not isinstance(item, theano.compile.SharedVariable)] if not isinstance(item, theano.compile.SharedVariable)]
storage_map_list = [] storage_map_list = []
total_size = 0
total_size_inputs = 0
for k in storage_map.keys(): for k in storage_map.keys():
storage_map_item = [] storage_map_item = []
# storage_map_item[0] # storage_map_item[0]: the variable
storage_map_item.append(str(k)) storage_map_item.append(str(k))
# storage_map_item[1] # storage_map_item[1]: the shape
shapeinfo = None shapeinfo = None
if hasattr(storage_map[k][0], 'shape'): if hasattr(storage_map[k][0], 'shape'):
shapeinfo = storage_map[k][0].shape shapeinfo = storage_map[k][0].shape
...@@ -202,15 +204,41 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None): ...@@ -202,15 +204,41 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
else: else:
storage_map_item.append(None) storage_map_item.append(None)
# storage_map_item[2] # storage_map_item[2]: itemsize
# storage_map_item[3] # storage_map_item[3]: bytes
if hasattr(storage_map[k][0], 'dtype'): if hasattr(storage_map[k][0], 'dtype'):
dtype = storage_map[k][0].dtype dtype = storage_map[k][0].dtype
storage_map_item.append(numpy.dtype(dtype).itemsize) storage_map_item.append(numpy.dtype(dtype).itemsize)
if shapeinfo is None: if shapeinfo is None:
storage_map_item.append(None) storage_map_item.append(None)
else: else:
storage_map_item.append(numpy.dtype(dtype).itemsize * numpy.prod(shapeinfo)) sz = numpy.dtype(dtype).itemsize * numpy.prod(shapeinfo)
storage_map_item.append(sz)
total_size += sz
if not k.owner:
total_size_inputs += sz
else:
# If it is a view, don't count it twice.
if getattr(k.owner.op, 'view_map', None):
vmap = k.owner.op.view_map
out_idx = k.owner.outputs.index(k)
data = storage_map[k][0]
if out_idx in vmap:
assert len(vmap[out_idx]) == 1
input_data = storage_map[k.owner.inputs[vmap[out_idx][0]]][0]
if k.type.may_share_memory(data, input_data):
total_size -= sz
# If it is a destroyed input, the input shouldn't be in the storage_map anymore
# except if there is a special flag used. So we still must check it.
if getattr(k.owner.op, 'destroy_map', None):
vmap = k.owner.op.destroy_map
out_idx = k.owner.outputs.index(k)
data = storage_map[k][0]
if out_idx in vmap:
assert len(vmap[out_idx]) == 1
input_data = storage_map[k.owner.inputs[vmap[out_idx][0]]][0]
if k.type.may_share_memory(data, input_data):
total_size -= sz
else: else:
bytes = getsizeof(storage_map[k][0]) bytes = getsizeof(storage_map[k][0])
storage_map_item.append(bytes) storage_map_item.append(bytes)
...@@ -243,6 +271,10 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None): ...@@ -243,6 +271,10 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
detailed_err_msg += ", TotalSize: %s Byte(s)\n" % storage_map_item[3] detailed_err_msg += ", TotalSize: %s Byte(s)\n" % storage_map_item[3]
else: else:
detailed_err_msg += "\n" detailed_err_msg += "\n"
detailed_err_msg += " TotalSize: %s Byte(s) %.3f GB\n" % (
total_size, total_size/1024./1024/1024)
detailed_err_msg += " TotalSize inputs: %s Byte(s) %.3f BG\n" % (
total_size_inputs, total_size_inputs/1024./1024/1024)
else: else:
hints.append( hints.append(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论