提交 c3b121e0 authored 作者: FanZiye(t13m)'s avatar FanZiye(t13m)

Remove a useless `if` in link.py; indicate in the storage-map footprint whether a tensor is a shared variable.

上级 ff085efc
......@@ -172,38 +172,55 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
# Prints output_map
if storage_map is not None:
detailed_err_msg += "\nStorage map footprint:\n"
shared_input_list = [
item for item in node.fgraph.inputs
if isinstance(item, theano.compile.SharedVariable)]
storage_map_list = []
for k in storage_map.keys():
if storage_map[k][0] is not None:
storage_map_item = []
storage_map_item.append(str(k))
shapeinfo = None
if hasattr(storage_map[k][0], 'shape'):
shapeinfo = storage_map[k][0].shape
if len(shapeinfo) != 0:
storage_map_item.append(shapeinfo)
else:
storage_map_item.append((1,))
else:
storage_map_item.append(None)
storage_map_item = []
if hasattr(storage_map[k][0], 'dtype'):
dtype = storage_map[k][0].dtype
storage_map_item.append(numpy.dtype(dtype).itemsize)
if shapeinfo is None:
storage_map_item.append(None)
else:
storage_map_item.append(numpy.dtype(dtype).itemsize * numpy.prod(shapeinfo))
# storage_map_item[0]
storage_map_item.append(str(k))
# storage_map_item[1]
shapeinfo = None
if hasattr(storage_map[k][0], 'shape'):
shapeinfo = storage_map[k][0].shape
if len(shapeinfo) != 0:
storage_map_item.append(shapeinfo)
else:
bytes = getsizeof(storage_map[k][0])
storage_map_item.append(bytes)
storage_map_item.append(tuple())
else:
storage_map_item.append(None)
# storage_map_item[2]
# storage_map_item[3]
if hasattr(storage_map[k][0], 'dtype'):
dtype = storage_map[k][0].dtype
storage_map_item.append(numpy.dtype(dtype).itemsize)
if shapeinfo is None:
storage_map_item.append(None)
storage_map_list.append(storage_map_item)
else:
storage_map_item.append(numpy.dtype(dtype).itemsize * numpy.prod(shapeinfo))
else:
bytes = getsizeof(storage_map[k][0])
storage_map_item.append(bytes)
storage_map_item.append(None)
# Flag of shared val
# storage_map_item[4]
if k in shared_input_list:
storage_map_item.append(True)
else:
storage_map_item.append(False)
storage_map_list.append(storage_map_item)
from operator import itemgetter
storage_map_list.sort(key=itemgetter(3), reverse=True)
for storage_map_item in storage_map_list:
detailed_err_msg += " - " + storage_map_item[0] + ", "
if storage_map_item[4] is True:
detailed_err_msg += "Shared, "
if storage_map_item[1] is not None:
detailed_err_msg += "Shape: %s, " % str(storage_map_item[1])
detailed_err_msg += "ElemSize: %s Byte(s)" % storage_map_item[2]
......@@ -218,7 +235,6 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
" for a debugprint and storage map footprint of this apply node.")
exc_value = exc_type(str(exc_value) + detailed_err_msg +
'\n' + '\n'.join(hints))
raise exc_type, exc_value, exc_trace
......
Markdown 格式
0%
您在此讨论中添加了 0 人。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论