提交 034edd70 authored 作者: James Bergstra's avatar James Bergstra

adding callback argument to VM_Linker

上级 a156506d
import gc
import sys
import time
import unittest
try:
import line_profiler
except ImportError:
......@@ -8,13 +9,50 @@ except ImportError:
import numpy
from theano import function
from theano.gof import vm,link, OpWiseCLinker
from theano.gof import vm
from theano.gof import link
from theano.gof import OpWiseCLinker
from theano.compile import Mode
from theano import tensor
from theano.lazycond import ifelse
import theano
class TestCallbacks(unittest.TestCase):
"""
Test the VM_Linker's callback argument, which can be useful for debugging.
"""
def setUp(self):
self.n_callbacks = {}
def callback(self, node, thunk, storage_map, compute_map):
self.n_callbacks.setdefault(node.op, 0)
self.n_callbacks[node.op] += 1
def test_callback(self):
a, b, c = tensor.scalars('abc')
f = function([a,b,c], (a + b) + c,
mode=Mode(
optimizer=None,
linker=vm.VM_Linker(callback=self.callback)))
f(1, 2, 3)
assert sum(self.n_callbacks.values()) == len(f.maker.env.toposort())
f(1, 2, 3)
assert sum(self.n_callbacks.values()) == len(f.maker.env.toposort()) * 2
def test_callback_with_ifelse(self):
a, b, c = tensor.scalars('abc')
f = function([a,b,c], ifelse(a, 2*b, 2*c),
mode=Mode(
optimizer=None,
linker=vm.VM_Linker(callback=self.callback)))
f(1, 2, 3)
assert self.n_callbacks[ifelse] == 2
def test_speed():
def build_graph(x, depth=5):
......
......@@ -187,7 +187,8 @@ class Stack(VM):
def __init__(self, nodes, thunks, pre_call_clear,
storage_map, compute_map,
env, allow_gc):
env, allow_gc,
callback=None):
super(Stack, self).__init__(nodes, thunks, pre_call_clear)
self.allow_gc = allow_gc
......@@ -199,6 +200,7 @@ class Stack(VM):
self.outputs_size = {}
self.compute_map = compute_map
self.node_idx = node_idx = {}
self.callback = callback
ords = env.orderings()
......@@ -278,6 +280,13 @@ class Stack(VM):
try:
t0 = time.time()
thunks[self.node_idx[current_apply]]()
if self.callback:
self.callback(
current_apply,
thunk=thunks[self.node_idx[current_apply]],
storage_map=storage_map,
compute_map=compute_map,
)
if config.profile:
dt = time.time() - t0
self.apply_time[current_apply] += dt
......@@ -324,6 +333,13 @@ class Stack(VM):
t0 = time.time()
requires = thunks[self.node_idx[current_apply]]()
dt = time.time() - t0
if self.callback:
self.callback(
current_apply,
thunk=thunks[self.node_idx[current_apply]],
storage_map=storage_map,
compute_map=compute_map,
)
self.apply_time[current_apply] += dt
except Exception:
......@@ -377,10 +393,11 @@ class VM_Linker(link.LocalLinker):
Class that satisfies the Linker interface by acting as a VM factory.
"""
def __init__(self, allow_gc=True, use_cloop = False):
def __init__(self, allow_gc=True, use_cloop=False, callback=None):
self.env = None
self.allow_gc = allow_gc
self.use_cloop=use_cloop
self.use_cloop = use_cloop
self.callback = callback
def accept(self, env, no_recycling = []):
"""
......@@ -406,7 +423,13 @@ class VM_Linker(link.LocalLinker):
pre_call_clear = [storage_map[v] for v in self.no_recycling]
if self.use_cloop:
if self.callback is not None:
vm = Stack(
nodes, thunks, pre_call_clear,
storage_map, compute_map,
self.env, self.allow_gc,
callback=self.callback)
elif self.use_cloop:
# create a map from nodes to ints and vars to ints
nodes_idx = {}
vars_idx = {}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论