提交 cfd4b870 authored 作者: Frederic's avatar Frederic

Make the memory profiler print the shape.

上级 6e6d00ba
...@@ -145,6 +145,9 @@ class ProfileStats(object): ...@@ -145,6 +145,9 @@ class ProfileStats(object):
optimizer_profile = None optimizer_profile = None
# None or tuple (the optimizer, the profile it returned) # None or tuple (the optimizer, the profile it returned)
memory_size_map = {"nt8": 1, "t16": 2, "t32": 4,
"t64": 8, "128": 16}
# param is called flag_time_thunks because most other attributes with time # param is called flag_time_thunks because most other attributes with time
# in the name are times *of* something, rather than configuration flags. # in the name are times *of* something, rather than configuration flags.
def __init__(self, atexit_print=True, flag_time_thunks=None, **kwargs): def __init__(self, atexit_print=True, flag_time_thunks=None, **kwargs):
...@@ -557,14 +560,26 @@ class ProfileStats(object): ...@@ -557,14 +560,26 @@ class ProfileStats(object):
def summary_memory(self, file, N=None): def summary_memory(self, file, N=None):
fct_memory = {} # fgraph->dict(node->(outputs size)) fct_memory = {} # fgraph->dict(node->(outputs size))
fct_shapes = {} # fgraph->dict(node->[outputs shapes]))
var_mem = {} var_mem = {}
for node, val in self.outputs_size.items(): for node, shapes in self.outputs_size.items():
fct_memory.setdefault(node.fgraph, {}) fct_memory.setdefault(node.fgraph, {})
fct_memory[node.fgraph][node] = val fct_memory[node.fgraph].setdefault(node, [])
for out, v in zip(node.outputs, val): fct_shapes.setdefault(node.fgraph, {})
fct_shapes[node.fgraph].setdefault(node, [])
for out, sh in zip(node.outputs, shapes):
v = numpy.prod(sh)
dtype = str(out.dtype)
v *= self.memory_size_map[dtype[-3:]]
var_mem[out] = v var_mem[out] = v
fct_memory[node.fgraph][node].append(v)
fct_shapes[node.fgraph][node].append(sh)
assert len(fct_memory) == 1
print print
print "Profile of Theano functions memory:" print " Memory Profile"
for fgraph, nodes_mem in fct_memory.iteritems(): for fgraph, nodes_mem in fct_memory.iteritems():
size_sum = sum([sum(val) size_sum = sum([sum(val)
...@@ -622,7 +637,7 @@ class ProfileStats(object): ...@@ -622,7 +637,7 @@ class ProfileStats(object):
node_memory_size - running_max_memory_size) / 1024 node_memory_size - running_max_memory_size) / 1024
print print
print " <Sum apply outputs (bytes)> <Apply outputs memory size(bytes)> <created/inplace/view> <Apply node>" print " <Sum apply outputs (bytes)> <Apply outputs shape> <created/inplace/view> <Apply node>"
print " <created/inplace/view> is taked from the op declaration." print " <created/inplace/view> is taked from the op declaration."
print " Use DebugMode for warnings about inplace/view declaration being respected." print " Use DebugMode for warnings about inplace/view declaration being respected."
print print
...@@ -632,10 +647,11 @@ class ProfileStats(object): ...@@ -632,10 +647,11 @@ class ProfileStats(object):
code[out] = "i" code[out] = "i"
for out, inp in getattr(key.op, 'view_map', {}).iteritems(): for out, inp in getattr(key.op, 'view_map', {}).iteritems():
code[out] = "v" code[out] = "v"
print ' %9dB %s %s %s' % (sum(val), str(val), shapes = str(fct_shapes[fgraph][key])
print ' %9dB %s %s %s' % (sum(val), shapes,
' '.join(code), key) ' '.join(code), key)
sum_remaining = sum(sum(val) for key, val in items[N:]) sum_remaining = sum(sum(shapes) for key, shapes in items[N:])
print (' ... (remaining %i Apply account for %.2f%%(%.2fs) of' print (' ... (remaining %i Apply account for %.2f%%(%.2fs) of'
' the runtime)') % (max(0, len(nodes_mem) - N), ' the runtime)') % (max(0, len(nodes_mem) - N),
sum_remaining, sum_remaining,
......
...@@ -365,16 +365,10 @@ class Stack(VM): ...@@ -365,16 +365,10 @@ class Stack(VM):
for (idx, o) in enumerate( for (idx, o) in enumerate(
thunks[self.node_idx[ thunks[self.node_idx[
current_apply]].outputs): current_apply]].outputs):
if not hasattr(o[0], 'size'): if not hasattr(o[0], 'shape'):
size.append(-1) size.append('no shape')
continue continue
s = o[0].size size.append(o[0].shape)
dtype = str(o[0].dtype)
dtype2 = dtype[-3:]
# KeyError here: couldn't determine
# the dtype memory size
s *= self.memory_size_map[dtype2]
size.append(s)
self.outputs_size[current_apply] = size self.outputs_size[current_apply] = size
except Exception: except Exception:
raise_with_op(current_apply) raise_with_op(current_apply)
...@@ -448,16 +442,10 @@ class Stack(VM): ...@@ -448,16 +442,10 @@ class Stack(VM):
size = [] size = []
for (idx, o) in enumerate(thunks[ for (idx, o) in enumerate(thunks[
self.node_idx[current_apply]].outputs): self.node_idx[current_apply]].outputs):
if not hasattr(o[0], 'size'): if not hasattr(o[0], 'shape'):
size.append(-1) size.append('no shape')
continue continue
s = o[0].size size.append(o[0].shape)
dtype = str(o[0].dtype)
dtype2 = dtype[-2:]
# KeyError here: couldn't determine the
# dtype memory size
s *= self.memory_size_map[dtype2]
size.append(s)
self.outputs_size[current_apply] = size self.outputs_size[current_apply] = size
if self.allow_gc: if self.allow_gc:
for i in current_apply.inputs: for i in current_apply.inputs:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论