Commit 74724279, authored by Olivier Delalleau

Merge pull request #111 from jaberg/vm_callback

adding callback argument to VM_Linker
import gc
import sys
import time
import unittest
try:
import line_profiler
except ImportError:
......@@ -8,13 +9,50 @@ except ImportError:
import numpy
from theano import function
from theano.gof import vm,link, OpWiseCLinker
from theano.gof import vm
from theano.gof import link
from theano.gof import OpWiseCLinker
from theano.compile import Mode
from theano import tensor
from theano.lazycond import ifelse
import theano
class TestCallbacks(unittest.TestCase):
    """
    Exercise the `callback` argument of VM_Linker.

    The callback registered here simply counts, per op, how many times the
    virtual machine reported running a thunk.
    """

    def setUp(self):
        # Maps `node.op` -> number of times the callback fired for it.
        self.n_callbacks = {}

    def callback(self, node, thunk, storage_map, compute_map):
        """Record one more callback invocation for `node.op`.

        The parameter names match the keyword arguments the VM passes
        ('node', 'thunk', 'storage_map', 'compute_map'); only `node` is used.
        """
        op = node.op
        self.n_callbacks[op] = self.n_callbacks.get(op, 0) + 1

    def test_callback(self):
        """The callback fires once per node of the graph on every call."""
        x, y, z = tensor.scalars('abc')
        counting_mode = Mode(
            optimizer=None,
            linker=vm.VM_Linker(callback=self.callback))
        fn = function([x, y, z], (x + y) + z, mode=counting_mode)

        # After k calls, the total count must be k times the number of nodes.
        for n_runs in (1, 2):
            fn(1, 2, 3)
            n_nodes = len(fn.maker.env.toposort())
            assert sum(self.n_callbacks.values()) == n_nodes * n_runs

    def test_callback_with_ifelse(self):
        """A single call reports the `ifelse` op twice.

        NOTE(review): presumably the lazy op's thunk runs once to request
        its inputs and once more to produce its output -- confirm.
        """
        x, y, z = tensor.scalars('abc')
        counting_mode = Mode(
            optimizer=None,
            linker=vm.VM_Linker(callback=self.callback))
        fn = function([x, y, z], ifelse(x, 2 * y, 2 * z), mode=counting_mode)

        fn(1, 2, 3)
        assert self.n_callbacks[ifelse] == 2
def test_speed():
def build_graph(x, depth=5):
......
"""
VMs that run Theano graph computations.
"""
import logging
import sys
import time
import link
......@@ -13,6 +14,8 @@ config = theano.config
from theano.configparser import config, AddConfigVar, BoolParam
from theano import config
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Register the `profile` configuration flag (a BoolParam constructed with
# False -- presumably its default value). The Stack VM reads `config.profile`
# to decide whether to accumulate per-node timing and memory information.
AddConfigVar('profile',
"If VM should collect profile information",
BoolParam(False))
......@@ -187,7 +190,8 @@ class Stack(VM):
def __init__(self, nodes, thunks, pre_call_clear,
storage_map, compute_map,
env, allow_gc):
env, allow_gc,
callback=None):
super(Stack, self).__init__(nodes, thunks, pre_call_clear)
self.allow_gc = allow_gc
......@@ -199,6 +203,7 @@ class Stack(VM):
self.outputs_size = {}
self.compute_map = compute_map
self.node_idx = node_idx = {}
self.callback = callback
ords = env.orderings()
......@@ -226,6 +231,28 @@ class Stack(VM):
self.memory_size_map = {"nt8": 1, "t16": 2, "t32": 4, "t64": 8, "128": 16}
atexit.register(self.atexit_print_all)
def run_thunk_of_node(self, node):
    """Execute the thunk associated with Apply instance `node`.

    The thunk is looked up through `self.node_idx` and timed with
    `time.time()`. If `self.callback` is set, it is invoked afterwards with
    the keyword arguments `node`, `thunk`, `storage_map` and `compute_map`.

    Returns a pair `(rval, dt)`: the value returned by the thunk and the
    elapsed wall-clock time in seconds (never less than 1e-10).
    """
    position = self.node_idx[node]
    started = time.time()
    thunk = self.thunks[position]
    result = thunk()
    elapsed = time.time() - started
    # time.time() can be too coarse to measure a very fast thunk; a node
    # that ran but shows 0 time makes the profile output look buggy, so
    # clamp the duration to a tiny positive value.
    if elapsed < 1e-10:
        elapsed = 1e-10

    if self.callback is None:
        return result, elapsed

    self.callback(
        node=node,
        thunk=thunk,
        storage_map=self.storage_map,
        compute_map=self.compute_map,
    )
    return result, elapsed
def __call__(self):
storage_map = self.storage_map
compute_map = self.compute_map
......@@ -276,10 +303,9 @@ class Stack(VM):
if computed_ins and not computed_outs:
try:
t0 = time.time()
thunks[self.node_idx[current_apply]]()
_, dt = self.run_thunk_of_node(current_apply)
del _
if config.profile:
dt = time.time() - t0
self.apply_time[current_apply] += dt
## Computing the memory footprint of the the op
# ?? What about inplace .. if the op is inplace
......@@ -321,9 +347,7 @@ class Stack(VM):
elif not computed_outs:
# Try and run it to see if it works
try:
t0 = time.time()
requires = thunks[self.node_idx[current_apply]]()
dt = time.time() - t0
requires, dt = self.run_thunk_of_node(current_apply)
self.apply_time[current_apply] += dt
except Exception:
......@@ -336,13 +360,11 @@ class Stack(VM):
apply_stack.append(current_apply)
if current_apply.inputs[r].owner:
apply_stack.append(current_apply.inputs[r].owner)
else:
if config.profile:
size = []
for (idx,o) in enumerate(thunks[self.node_idx[current_apply]].outputs):
if not hasattr(o[0],'size'):
if not hasattr(o[0], 'size'):
size.append(-1)
continue
s=o[0].size
......@@ -377,10 +399,23 @@ class VM_Linker(link.LocalLinker):
Class that satisfies the Linker interface by acting as a VM factory.
"""
def __init__(self, allow_gc=True, use_cloop=False, callback=None):
    """
    Store the options used later to build the virtual machine.

    allow_gc - force the virtual machine to clean up unnecessary references,
        in order to allow garbage collection on intermediate values during
        computation of a function.

    use_cloop - use the C-based virtual machine if possible

    callback - a callable object to call after each call to a thunk within
        the virtual machine.  It will be called with four arguments called
        'node', 'thunk', 'storage_map', and 'compute_map'.  Defaults to
        None (no callback).
    """
    # The superseded two-argument signature and the duplicate
    # `self.use_cloop` assignment left over from the diff are dropped; the
    # new `callback` parameter is optional, so existing callers still work.
    self.env = None
    self.allow_gc = allow_gc
    self.use_cloop = use_cloop
    self.callback = callback
def accept(self, env, no_recycling = []):
"""
......@@ -406,7 +441,15 @@ class VM_Linker(link.LocalLinker):
pre_call_clear = [storage_map[v] for v in self.no_recycling]
if self.use_cloop:
if self.callback is not None:
    # A callback can only be honoured by the Python Stack VM, so it takes
    # precedence over the C loop.  `use_cloop` is an instance attribute,
    # not a local name: reading it unqualified raised NameError as soon as
    # a callback was supplied.
    if self.use_cloop:
        logger.warning('CLoop does not support callback, using Stack VM.')
    vm = Stack(
        nodes, thunks, pre_call_clear,
        storage_map, compute_map,
        self.env, self.allow_gc,
        callback=self.callback)
elif self.use_cloop:
# create a map from nodes to ints and vars to ints
nodes_idx = {}
vars_idx = {}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论