提交 150605d8 authored 作者: James Bergstra's avatar James Bergstra

refactored Stack VM to connect thunk execution with the callback

上级 a79e9eeb
...@@ -228,6 +228,24 @@ class Stack(VM): ...@@ -228,6 +228,24 @@ class Stack(VM):
self.memory_size_map = {"nt8": 1, "t16": 2, "t32": 4, "t64": 8, "128": 16} self.memory_size_map = {"nt8": 1, "t16": 2, "t32": 4, "t64": 8, "128": 16}
atexit.register(self.atexit_print_all) atexit.register(self.atexit_print_all)
def run_thunk_of_node(self, node):
"""Run the thunk corresponding to Apply instance `node`
Calls self.callback if it is defined.
"""
idx = self.node_idx[node]
t0 = time.time()
rval = self.thunks[idx]()
dt = max(time.time() - t0, 1e-10)
if self.callback is not None:
self.callback(
node,
thunk=self.thunks[idx],
storage_map=self.storage_map,
compute_map=self.compute_map,
)
return rval, dt
def __call__(self): def __call__(self):
storage_map = self.storage_map storage_map = self.storage_map
compute_map = self.compute_map compute_map = self.compute_map
...@@ -278,17 +296,9 @@ class Stack(VM): ...@@ -278,17 +296,9 @@ class Stack(VM):
if computed_ins and not computed_outs: if computed_ins and not computed_outs:
try: try:
t0 = time.time() _, dt = self.run_thunk_of_node(current_apply)
thunks[self.node_idx[current_apply]]() del _
if self.callback:
self.callback(
current_apply,
thunk=thunks[self.node_idx[current_apply]],
storage_map=storage_map,
compute_map=compute_map,
)
if config.profile: if config.profile:
dt = time.time() - t0
self.apply_time[current_apply] += dt self.apply_time[current_apply] += dt
## Computing the memory footprint of the the op ## Computing the memory footprint of the the op
# ?? What about inplace .. if the op is inplace # ?? What about inplace .. if the op is inplace
...@@ -330,16 +340,7 @@ class Stack(VM): ...@@ -330,16 +340,7 @@ class Stack(VM):
elif not computed_outs: elif not computed_outs:
# Try and run it to see if it works # Try and run it to see if it works
try: try:
t0 = time.time() requires, dt = self.run_thunk_of_node(current_apply)
requires = thunks[self.node_idx[current_apply]]()
dt = time.time() - t0
if self.callback:
self.callback(
current_apply,
thunk=thunks[self.node_idx[current_apply]],
storage_map=storage_map,
compute_map=compute_map,
)
self.apply_time[current_apply] += dt self.apply_time[current_apply] += dt
except Exception: except Exception:
...@@ -352,13 +353,11 @@ class Stack(VM): ...@@ -352,13 +353,11 @@ class Stack(VM):
apply_stack.append(current_apply) apply_stack.append(current_apply)
if current_apply.inputs[r].owner: if current_apply.inputs[r].owner:
apply_stack.append(current_apply.inputs[r].owner) apply_stack.append(current_apply.inputs[r].owner)
else: else:
if config.profile: if config.profile:
size = [] size = []
for (idx,o) in enumerate(thunks[self.node_idx[current_apply]].outputs): for (idx,o) in enumerate(thunks[self.node_idx[current_apply]].outputs):
if not hasattr(o[0],'size'): if not hasattr(o[0], 'size'):
size.append(-1) size.append(-1)
continue continue
s=o[0].size s=o[0].size
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论