提交 7d044df1 authored 作者: Frederic's avatar Frederic

Track the shape of all variables in the graph.

Change the tracking mechanism to record shapes per variable in variable_shape instead of per node in outputs_size.
上级 93f1f189
...@@ -59,7 +59,7 @@ def _atexit_print_fn(): ...@@ -59,7 +59,7 @@ def _atexit_print_fn():
#merge dictonary #merge dictonary
for attr in ["apply_time", "apply_callcount", for attr in ["apply_time", "apply_callcount",
"apply_cimpl", "outputs_size"]: "apply_cimpl", "variable_shape"]:
cum_attr = getattr(cum, attr) cum_attr = getattr(cum, attr)
for key, val in getattr(ps, attr).iteritems(): for key, val in getattr(ps, attr).iteritems():
assert key not in cum_attr assert key not in cum_attr
...@@ -125,8 +125,8 @@ class ProfileStats(object): ...@@ -125,8 +125,8 @@ class ProfileStats(object):
# pretty string to print in summary, to identify this output # pretty string to print in summary, to identify this output
# #
outputs_size = None variable_shape = {}
# node -> size of allocated output # Variable -> shapes
# #
optimizer_time = 0.0 optimizer_time = 0.0
...@@ -161,7 +161,7 @@ class ProfileStats(object): ...@@ -161,7 +161,7 @@ class ProfileStats(object):
self.output_size = {} self.output_size = {}
self.apply_time = {} self.apply_time = {}
self.apply_cimpl = {} self.apply_cimpl = {}
self.outputs_size = {} self.variable_shape = {}
if flag_time_thunks is None: if flag_time_thunks is None:
self.flag_time_thunks = config.profiling.time_thunks self.flag_time_thunks = config.profiling.time_thunks
else: else:
...@@ -562,13 +562,15 @@ class ProfileStats(object): ...@@ -562,13 +562,15 @@ class ProfileStats(object):
fct_memory = {} # fgraph->dict(node->(outputs size)) fct_memory = {} # fgraph->dict(node->(outputs size))
fct_shapes = {} # fgraph->dict(node->[outputs shapes])) fct_shapes = {} # fgraph->dict(node->[outputs shapes]))
var_mem = {} var_mem = {}
for node, shapes in self.outputs_size.items():
for node in self.apply_callcount.keys():
fct_memory.setdefault(node.fgraph, {}) fct_memory.setdefault(node.fgraph, {})
fct_memory[node.fgraph].setdefault(node, []) fct_memory[node.fgraph].setdefault(node, [])
fct_shapes.setdefault(node.fgraph, {}) fct_shapes.setdefault(node.fgraph, {})
fct_shapes[node.fgraph].setdefault(node, []) fct_shapes[node.fgraph].setdefault(node, [])
for out, sh in zip(node.outputs, shapes): for out in node.outputs:
sh = self.variable_shape[out]
v = numpy.prod(sh) v = numpy.prod(sh)
dtype = str(out.dtype) dtype = str(out.dtype)
v *= self.memory_size_map[dtype[-3:]] v *= self.memory_size_map[dtype[-3:]]
...@@ -668,7 +670,7 @@ class ProfileStats(object): ...@@ -668,7 +670,7 @@ class ProfileStats(object):
elif self.fct_callcount > 0: elif self.fct_callcount > 0:
print >> file, (" No node time accumulated " print >> file, (" No node time accumulated "
"(hint: try config profiling.time_thunks=1)") "(hint: try config profiling.time_thunks=1)")
if self.outputs_size: if self.variable_shape:
self.summary_memory(file, n_ops_to_print) self.summary_memory(file, n_ops_to_print)
if self.optimizer_profile: if self.optimizer_profile:
print "Optimizer Profile" print "Optimizer Profile"
......
...@@ -130,7 +130,7 @@ class VM(object): ...@@ -130,7 +130,7 @@ class VM(object):
profile.apply_cimpl[node] = hasattr(thunk, 'cthunk') profile.apply_cimpl[node] = hasattr(thunk, 'cthunk')
profile.outputs_size[node] = self.outputs_size[node] profile.variable_shape = self.variable_shape.copy()
# clear the timer info out of the buffers # clear the timer info out of the buffers
for i in xrange(len(self.call_times)): for i in xrange(len(self.call_times)):
...@@ -246,7 +246,7 @@ class Stack(VM): ...@@ -246,7 +246,7 @@ class Stack(VM):
self.base_apply_stack = [o.owner for o in fgraph.outputs if o.owner] self.base_apply_stack = [o.owner for o in fgraph.outputs if o.owner]
self.outputs = fgraph.outputs self.outputs = fgraph.outputs
self.storage_map = storage_map self.storage_map = storage_map
self.outputs_size = {} self.variable_shape = {} # Variable -> shape
self.compute_map = compute_map self.compute_map = compute_map
self.node_idx = node_idx = {} self.node_idx = node_idx = {}
self.callback = callback self.callback = callback
...@@ -255,7 +255,7 @@ class Stack(VM): ...@@ -255,7 +255,7 @@ class Stack(VM):
for i, node in enumerate(self.nodes): for i, node in enumerate(self.nodes):
node_idx[node] = i node_idx[node] = i
self.outputs_size[node] = []
# XXX: inconsistent style - why modify node here rather # XXX: inconsistent style - why modify node here rather
# than track destroy_dependencies with dictionary like # than track destroy_dependencies with dictionary like
# storage_map? # storage_map?
...@@ -319,6 +319,17 @@ class Stack(VM): ...@@ -319,6 +319,17 @@ class Stack(VM):
apply_stack = list(self.base_apply_stack) apply_stack = list(self.base_apply_stack)
last_apply_stack_len = -1 last_apply_stack_len = -1
ls = [] ls = []
# This records the shapes of all function inputs, shared variables and constants
for var, data in self.storage_map.iteritems():
if data[0] is None:
continue
if not hasattr(data[0], 'shape'):
sh = 'input no shape'
else:
sh = data[0].shape
self.variable_shape[var] = sh
while apply_stack: while apply_stack:
# Make sure something happened last time round. This is # Make sure something happened last time round. This is
# just a safety check to make sure the op is written # just a safety check to make sure the op is written
...@@ -355,21 +366,21 @@ class Stack(VM): ...@@ -355,21 +366,21 @@ class Stack(VM):
_, dt = self.run_thunk_of_node(current_apply) _, dt = self.run_thunk_of_node(current_apply)
del _ del _
if config.profile: if config.profile:
nodes_idx = self.nodes.index(current_apply) current_idx = self.node_idx[current_apply]
self.call_counts[nodes_idx] += 1 self.call_counts[current_idx] += 1
self.call_times[nodes_idx] += dt self.call_times[current_idx] += dt
## Computing the memory footprint of the the op ## Computing the memory footprint of the the op
# ?? What about inplace .. if the op is inplace # ?? What about inplace .. if the op is inplace
# you don't actually ask for more memory! # you don't actually ask for more memory!
size = []
for (idx, o) in enumerate( for (idx, o) in enumerate(
thunks[self.node_idx[ thunks[self.node_idx[
current_apply]].outputs): current_apply]].outputs):
if not hasattr(o[0], 'shape'): if not hasattr(o[0], 'shape'):
size.append('no shape') sh = 'no shape'
continue else:
size.append(o[0].shape) sh = o[0].shape
self.outputs_size[current_apply] = size var = self.nodes[current_idx].outputs[idx]
self.variable_shape[var] = sh
except Exception: except Exception:
raise_with_op(current_apply) raise_with_op(current_apply)
for o in current_apply.outputs: for o in current_apply.outputs:
...@@ -424,9 +435,9 @@ class Stack(VM): ...@@ -424,9 +435,9 @@ class Stack(VM):
try: try:
requires, dt = self.run_thunk_of_node(current_apply) requires, dt = self.run_thunk_of_node(current_apply)
nodes_idx = self.nodes.index(current_apply) current_idx = self.node_idx[current_apply]
self.call_counts[nodes_idx] += 1 self.call_counts[current_idx] += 1
self.call_times[nodes_idx] += dt self.call_times[current_idx] += dt
except Exception: except Exception:
raise_with_op(current_apply) raise_with_op(current_apply)
...@@ -441,14 +452,14 @@ class Stack(VM): ...@@ -441,14 +452,14 @@ class Stack(VM):
apply_stack.append(current_apply.inputs[r].owner) apply_stack.append(current_apply.inputs[r].owner)
else: else:
if config.profile: if config.profile:
size = []
for (idx, o) in enumerate(thunks[ for (idx, o) in enumerate(thunks[
self.node_idx[current_apply]].outputs): self.node_idx[current_apply]].outputs):
if not hasattr(o[0], 'shape'): if not hasattr(o[0], 'shape'):
size.append('no shape') sh = 'no shape'
continue else:
size.append(o[0].shape) sh = o[0].shape
self.outputs_size[current_apply] = size var = self.nodes[self.node_idx[current_apply]].outputs[idx]
self.variable_shape[var] = sh
if self.allow_gc: if self.allow_gc:
for i in current_apply.inputs: for i in current_apply.inputs:
if (dependencies[i] and i.owner and if (dependencies[i] and i.owner and
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论