提交 1644bd5f 作者: Frederic

In the memory profile, also collect and print the strides information.

上级 5d252340
...@@ -59,7 +59,7 @@ def _atexit_print_fn(): ...@@ -59,7 +59,7 @@ def _atexit_print_fn():
#merge dictonary #merge dictonary
for attr in ["apply_time", "apply_callcount", for attr in ["apply_time", "apply_callcount",
"apply_cimpl", "variable_shape"]: "apply_cimpl", "variable_shape", "variable_strides"]:
cum_attr = getattr(cum, attr) cum_attr = getattr(cum, attr)
for key, val in getattr(ps, attr).iteritems(): for key, val in getattr(ps, attr).iteritems():
assert key not in cum_attr assert key not in cum_attr
...@@ -129,6 +129,10 @@ class ProfileStats(object): ...@@ -129,6 +129,10 @@ class ProfileStats(object):
# Variable -> shapes # Variable -> shapes
# #
variable_strides = {}
# Variable -> strides
#
optimizer_time = 0.0 optimizer_time = 0.0
# time spent optimizing graph (FunctionMaker.__init__) # time spent optimizing graph (FunctionMaker.__init__)
...@@ -162,6 +166,7 @@ class ProfileStats(object): ...@@ -162,6 +166,7 @@ class ProfileStats(object):
self.apply_time = {} self.apply_time = {}
self.apply_cimpl = {} self.apply_cimpl = {}
self.variable_shape = {} self.variable_shape = {}
self.variable_strides = {}
if flag_time_thunks is None: if flag_time_thunks is None:
self.flag_time_thunks = config.profiling.time_thunks self.flag_time_thunks = config.profiling.time_thunks
else: else:
...@@ -527,12 +532,16 @@ class ProfileStats(object): ...@@ -527,12 +532,16 @@ class ProfileStats(object):
continue continue
for idx, var in enumerate(a.inputs): for idx, var in enumerate(a.inputs):
sh = self.variable_shape.get(var, 'no shape') sh = self.variable_shape.get(var, 'no shape')
st = self.variable_strides.get(var, 'no strides')
dtype = getattr(var, 'dtype', 'no dtype') dtype = getattr(var, 'dtype', 'no dtype')
print " input %d: dtype=%s, shape=%s " % (idx, dtype, sh) print " input %d: dtype=%s, shape=%s, strides=%s " % (
idx, dtype, sh, st)
for idx, var in enumerate(a.outputs): for idx, var in enumerate(a.outputs):
sh = self.variable_shape.get(var, 'no shape') sh = self.variable_shape.get(var, 'no shape')
st = self.variable_strides.get(var, 'no strides')
dtype = getattr(var, 'dtype', 'no dtype') dtype = getattr(var, 'dtype', 'no dtype')
print " output %d: dtype=%s, shape=%s " % (idx, dtype, sh) print " output %d: dtype=%s, shape=%s, strides=%s " % (
idx, dtype, sh, st)
# Same as before, this I've sacrificied some information making # Same as before, this I've sacrificied some information making
# the output more readable # the output more readable
#print >> file, ' %4.1f%% %5.1f%% %5.3fs %5.3fs %.2es %i %s'%( #print >> file, ' %4.1f%% %5.1f%% %5.3fs %5.3fs %.2es %i %s'%(
......
...@@ -134,6 +134,7 @@ class VM(object): ...@@ -134,6 +134,7 @@ class VM(object):
profile.apply_cimpl[node] = hasattr(thunk, 'cthunk') profile.apply_cimpl[node] = hasattr(thunk, 'cthunk')
profile.variable_shape = self.variable_shape.copy() profile.variable_shape = self.variable_shape.copy()
profile.variable_strides = self.variable_strides.copy()
# clear the timer info out of the buffers # clear the timer info out of the buffers
for i in xrange(len(self.call_times)): for i in xrange(len(self.call_times)):
...@@ -250,6 +251,7 @@ class Stack(VM): ...@@ -250,6 +251,7 @@ class Stack(VM):
self.outputs = fgraph.outputs self.outputs = fgraph.outputs
self.storage_map = storage_map self.storage_map = storage_map
self.variable_shape = {} # Variable -> shape self.variable_shape = {} # Variable -> shape
self.variable_strides = {} # Variable -> strides
self.compute_map = compute_map self.compute_map = compute_map
self.node_idx = node_idx = {} self.node_idx = node_idx = {}
self.callback = callback self.callback = callback
...@@ -327,11 +329,12 @@ class Stack(VM): ...@@ -327,11 +329,12 @@ class Stack(VM):
for var, data in self.storage_map.iteritems(): for var, data in self.storage_map.iteritems():
if data[0] is None: if data[0] is None:
continue continue
if not hasattr(data[0], 'shape'): sh = getattr(data[0], 'shape', 'input no shape')
sh = 'input no shape'
else:
sh = data[0].shape
self.variable_shape[var] = sh self.variable_shape[var] = sh
st = getattr(data[0], 'strides', 'input no strides')
if getattr(data[0], 'flags', False) and data[0].flags.c_contiguous:
st = 'c'
self.variable_strides[var] = st
while apply_stack: while apply_stack:
# Make sure something happened last time round. This is # Make sure something happened last time round. This is
...@@ -378,12 +381,15 @@ class Stack(VM): ...@@ -378,12 +381,15 @@ class Stack(VM):
for (idx, o) in enumerate( for (idx, o) in enumerate(
thunks[self.node_idx[ thunks[self.node_idx[
current_apply]].outputs): current_apply]].outputs):
if not hasattr(o[0], 'shape'):
sh = 'no shape'
else:
sh = o[0].shape
var = self.nodes[current_idx].outputs[idx] var = self.nodes[current_idx].outputs[idx]
sh = getattr(o[0], 'shape', 'input no shape')
self.variable_shape[var] = sh self.variable_shape[var] = sh
st = getattr(o[0], 'strides',
'input no strides')
if (getattr(o[0], 'flags', False) and
o[0].flags.c_contiguous):
st = 'c'
self.variable_strides[var] = st
except Exception: except Exception:
raise_with_op(current_apply) raise_with_op(current_apply)
for o in current_apply.outputs: for o in current_apply.outputs:
...@@ -457,12 +463,16 @@ class Stack(VM): ...@@ -457,12 +463,16 @@ class Stack(VM):
if config.profile: if config.profile:
for (idx, o) in enumerate(thunks[ for (idx, o) in enumerate(thunks[
self.node_idx[current_apply]].outputs): self.node_idx[current_apply]].outputs):
if not hasattr(o[0], 'shape'): var = self.nodes[
sh = 'no shape' self.node_idx[current_apply]].outputs[idx]
else: sh = getattr(o[0], 'shape', 'input no shape')
sh = o[0].shape
var = self.nodes[self.node_idx[current_apply]].outputs[idx]
self.variable_shape[var] = sh self.variable_shape[var] = sh
st = getattr(o[0], 'strides', 'input no strides')
if (getattr(o[0], 'flags', False) and
o[0].flags.c_contiguous):
st = 'c'
self.variable_strides[var] = st
if self.allow_gc: if self.allow_gc:
for i in current_apply.inputs: for i in current_apply.inputs:
if (dependencies[i] and i.owner and if (dependencies[i] and i.owner and
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论