提交 148814bf authored 作者: nouiz's avatar nouiz

Merge pull request #311 from jaberg/cvm_runs_testsuite

Cvm runs testsuite
...@@ -21,6 +21,7 @@ from theano.configdefaults import config ...@@ -21,6 +21,7 @@ from theano.configdefaults import config
import logging import logging
_logger = logging.getLogger('theano.compile.function_module') _logger = logging.getLogger('theano.compile.function_module')
def alias_root(v): def alias_root(v):
"""Return the variable to which v is aliased by view_maps and destroy_maps""" """Return the variable to which v is aliased by view_maps and destroy_maps"""
if v.owner is None: return v if v.owner is None: return v
...@@ -35,6 +36,7 @@ def alias_root(v): ...@@ -35,6 +36,7 @@ def alias_root(v):
else: else:
return v return v
def view_tree_set(v, treeset): def view_tree_set(v, treeset):
"""Add to `treeset` all variables that are views of v, given that v is not a view""" """Add to `treeset` all variables that are views of v, given that v is not a view"""
treeset.add(v) treeset.add(v)
...@@ -48,6 +50,7 @@ def view_tree_set(v, treeset): ...@@ -48,6 +50,7 @@ def view_tree_set(v, treeset):
if cl.outputs[opos] not in treeset: if cl.outputs[opos] not in treeset:
view_tree_set(cl.outputs[opos], treeset) view_tree_set(cl.outputs[opos], treeset)
def infer_reuse_pattern(env, outputs_to_disown): def infer_reuse_pattern(env, outputs_to_disown):
""" """
Given an env and a list of variables, returns the list or set of all variables which may Given an env and a list of variables, returns the list or set of all variables which may
...@@ -65,6 +68,23 @@ def infer_reuse_pattern(env, outputs_to_disown): ...@@ -65,6 +68,23 @@ def infer_reuse_pattern(env, outputs_to_disown):
return rval return rval
def env_updated_vars(env, expanded_inputs):
"""
Reconstruct the full "updates" dictionary, mapping from Env input
variables to the env outputs that will replace their values.
:rtype: dict variable -> variable
"""
updated_vars = {}
potential_values = list(env.outputs) # copy the list
if len(expanded_inputs) != len(env.inputs):
raise ValueError('expanded_inputs must match len(env.inputs)')
for e_input, ivar in reversed(zip(expanded_inputs, env.inputs)):
if e_input.update is not None:
updated_vars[ivar] = potential_values.pop()
return updated_vars
class Supervisor: class Supervisor:
""" """
Listener for Env events which makes sure that no operation overwrites the Listener for Env events which makes sure that no operation overwrites the
...@@ -550,6 +570,7 @@ class Function(object): ...@@ -550,6 +570,7 @@ class Function(object):
# Set keyword arguments # Set keyword arguments
if kwargs: # for speed, skip the iteritems for empty kwargs
for k, arg in kwargs.iteritems(): for k, arg in kwargs.iteritems():
self[k] = arg self[k] = arg
...@@ -610,7 +631,7 @@ class Function(object): ...@@ -610,7 +631,7 @@ class Function(object):
# Do the actual work # Do the actual work
t0_fn = time.time() t0_fn = time.time()
try: try:
self.fn() outputs = self.fn()
except Exception: except Exception:
if hasattr(self.fn, 'position_of_error'): if hasattr(self.fn, 'position_of_error'):
# this is a new vm-provided function # this is a new vm-provided function
...@@ -627,7 +648,9 @@ class Function(object): ...@@ -627,7 +648,9 @@ class Function(object):
profile.vm_call_time += dt_fn profile.vm_call_time += dt_fn
# Retrieve the values that were computed # Retrieve the values that were computed
if outputs is None:
outputs = [x.data for x in self.output_storage] outputs = [x.data for x in self.output_storage]
assert len(outputs) == len(self.output_storage)
# Remove internal references to required inputs. # Remove internal references to required inputs.
# These cannot be re-used anyway. # These cannot be re-used anyway.
...@@ -1027,8 +1050,10 @@ class FunctionMaker(object): ...@@ -1027,8 +1050,10 @@ class FunctionMaker(object):
else: else:
self.linker = linker.accept(env) self.linker = linker.accept(env)
#hacky thing so VMLinker if hasattr(linker, 'accept_var_updates'):
self.linker.expanded_inputs = expanded_inputs # hacky thing so VMLinker knows about updates
self.linker.accept_var_updates(
env_updated_vars(env, expanded_inputs))
self.indices = indices self.indices = indices
self.inputs = inputs self.inputs = inputs
......
...@@ -28,9 +28,9 @@ class T_bunch_of_modes(unittest.TestCase): ...@@ -28,9 +28,9 @@ class T_bunch_of_modes(unittest.TestCase):
x = T.matrix() x = T.matrix()
y = T.vector() y = T.vector()
f = theano.function([x,y], x+y, mode=mode) f = theano.function([x, y], x + y, mode=mode)
# test that it runs something # test that it runs something
f([[1,2],[3,4]], [5, 6]) f([[1, 2], [3, 4]], [5, 6])
linker_classes_involved.append(f.maker.mode.linker.__class__) linker_classes_involved.append(f.maker.mode.linker.__class__)
print 'MODE:', mode, f.maker.mode.linker, 'stop' print 'MODE:', mode, f.maker.mode.linker, 'stop'
# regression check: # regression check:
......
...@@ -470,16 +470,25 @@ static PyObject * pycall(CLazyLinker * self, Py_ssize_t node_idx, int verbose) ...@@ -470,16 +470,25 @@ static PyObject * pycall(CLazyLinker * self, Py_ssize_t node_idx, int verbose)
double t0 = pytime(NULL); double t0 = pytime(NULL);
if (verbose) fprintf(stderr, "calling via Python (node %i)\n", (int)node_idx); if (verbose) fprintf(stderr, "calling via Python (node %i)\n", (int)node_idx);
rval = PyObject_CallObject(thunk, NULL); rval = PyObject_CallObject(thunk, NULL);
if (rval)
{
double t1 = pytime(NULL); double t1 = pytime(NULL);
double ti = PyFloat_AsDouble(PyList_GetItem(self->call_times, node_idx)); double ti = PyFloat_AsDouble(
PyList_SetItem(self->call_times, node_idx, PyFloat_FromDouble(t1 - t0 + ti)); PyList_GetItem(self->call_times, node_idx));
PyList_SetItem(self->call_times, node_idx,
PyFloat_FromDouble(t1 - t0 + ti));
PyObject * count = PyList_GetItem(self->call_counts, node_idx); PyObject * count = PyList_GetItem(self->call_counts, node_idx);
long icount = PyInt_AsLong(count); long icount = PyInt_AsLong(count);
PyList_SetItem(self->call_counts, node_idx, PyInt_FromLong(icount+1)); PyList_SetItem(self->call_counts, node_idx,
PyInt_FromLong(icount + 1));
}
} }
else else
{ {
if (verbose) fprintf(stderr, "calling via Python (node %i)\n", (int)node_idx); if (verbose)
{
fprintf(stderr, "calling via Python (node %i)\n", (int)node_idx);
}
rval = PyObject_CallObject(thunk, NULL); rval = PyObject_CallObject(thunk, NULL);
} }
return rval; return rval;
...@@ -730,8 +739,13 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds) ...@@ -730,8 +739,13 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds)
return NULL; return NULL;
int err = 0; int err = 0;
self->position_of_error = -1; self->position_of_error = -1;
// create constants used to fill the var_compute_cells
PyObject * one = PyInt_FromLong(1); PyObject * one = PyInt_FromLong(1);
PyObject * zero = PyInt_FromLong(0); PyObject * zero = PyInt_FromLong(0);
// pre-allocate our return value
Py_INCREF(Py_None);
PyObject * rval = Py_None;
//clear storage of pre_call_clear elements //clear storage of pre_call_clear elements
for (int call_i = 0; call_i < n_calls && (!err); ++call_i) for (int call_i = 0; call_i < n_calls && (!err); ++call_i)
{ {
...@@ -764,22 +778,41 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds) ...@@ -764,22 +778,41 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds)
err = lazy_rec_eval(self, self->output_vars[i], one, zero); err = lazy_rec_eval(self, self->output_vars[i], one, zero);
} }
if (!err)
{
// save references to outputs prior to updating storage containers
if ((call_i + 1) == n_calls)
{
assert (self->n_output_vars >= self->n_updates);
Py_DECREF(rval);
rval = PyList_New(self->n_output_vars);
for (int i = 0; i < (self->n_output_vars); ++i)
{
Py_ssize_t src = self->output_vars[i];
PyObject * item = PyList_GetItem(self->var_value_cells[src], 0);
Py_INCREF(item);
PyList_SetItem(rval, i, item);
}
}
for (int i = 0; i < self->n_updates; ++i) for (int i = 0; i < self->n_updates; ++i)
{ {
Py_ssize_t dst = self->update_storage[2*i]; Py_ssize_t dst = self->update_storage[2*i];
Py_ssize_t src = self->update_storage[2*i+1]; Py_ssize_t src = self->update_storage[2*i+1];
PyObject* tmp = PyList_GetItem(self->var_value_cells[src], 0); PyObject* tmp = PyList_GetItem(self->var_value_cells[src], 0);
Py_INCREF(Py_None);
Py_INCREF(tmp); Py_INCREF(tmp);
PyList_SetItem(self->var_value_cells[dst], 0, tmp); PyList_SetItem(self->var_value_cells[dst], 0, tmp);
PyList_SetItem(self->var_value_cells[src], 0, Py_None); }
} }
} }
Py_DECREF(one); Py_DECREF(one);
Py_DECREF(zero); Py_DECREF(zero);
if (err) return NULL; if (err)
Py_INCREF(Py_None); {
return Py_None; Py_DECREF(rval);
return NULL;
}
return rval;
} }
#if 0 #if 0
...@@ -853,7 +886,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = { ...@@ -853,7 +886,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = {
static PyObject * get_version(PyObject *dummy, PyObject *args) static PyObject * get_version(PyObject *dummy, PyObject *args)
{ {
PyObject *result = PyFloat_FromDouble(0.1); PyObject *result = PyFloat_FromDouble(0.13);
return result; return result;
} }
......
...@@ -12,12 +12,17 @@ _logger = logging.getLogger('theano.gof.lazylinker_c') ...@@ -12,12 +12,17 @@ _logger = logging.getLogger('theano.gof.lazylinker_c')
if config.compiledir not in sys.path: if config.compiledir not in sys.path:
sys.path.append(config.compiledir) sys.path.append(config.compiledir)
version = 0.1 # must match constant returned in function get_version() force_compile = False
version = 0.13 # must match constant returned in function get_version()
need_reload = False
try: try:
_need_reload = False
if force_compile:
raise ImportError()
else:
import lazylinker_ext import lazylinker_ext
need_reload = True _need_reload = True
if version != getattr(lazylinker_ext, '_version', None): if version != getattr(lazylinker_ext, '_version', None):
raise ImportError() raise ImportError()
except ImportError: except ImportError:
...@@ -26,13 +31,15 @@ except ImportError: ...@@ -26,13 +31,15 @@ except ImportError:
# Maybe someone else already finished compiling it while we were # Maybe someone else already finished compiling it while we were
# waiting for the lock? # waiting for the lock?
try: try:
if need_reload: if force_compile:
raise ImportError()
if _need_reload:
# The module was successfully imported earlier: we need to # The module was successfully imported earlier: we need to
# reload it to check if the version was updated. # reload it to check if the version was updated.
reload(lazylinker_ext) reload(lazylinker_ext)
else: else:
import lazylinker_ext import lazylinker_ext
need_reload = True _need_reload = True
if version != getattr(lazylinker_ext, '_version', None): if version != getattr(lazylinker_ext, '_version', None):
raise ImportError() raise ImportError()
except ImportError: except ImportError:
...@@ -67,4 +74,4 @@ except ImportError: ...@@ -67,4 +74,4 @@ except ImportError:
release_lock() release_lock()
from lazylinker_ext.lazylinker_ext import * from lazylinker_ext.lazylinker_ext import *
assert version == get_version() assert force_compile or (version == get_version())
...@@ -61,6 +61,10 @@ class VM(object): ...@@ -61,6 +61,10 @@ class VM(object):
self.call_counts = [0]*len(nodes) self.call_counts = [0]*len(nodes)
self.call_times = [0]*len(nodes) self.call_times = [0]*len(nodes)
self.time_thunks = False self.time_thunks = False
# This variable (self.need_update_inputs) is overshadowed by
# CLazyLinker in CVM which has an attribute of the same name that
# defaults to 0 (aka False).
self.need_update_inputs = True self.need_update_inputs = True
def __call__(self): def __call__(self):
...@@ -405,6 +409,7 @@ class VM_Linker(link.LocalLinker): ...@@ -405,6 +409,7 @@ class VM_Linker(link.LocalLinker):
self.allow_gc = allow_gc self.allow_gc = allow_gc
self.use_cloop = use_cloop self.use_cloop = use_cloop
self.callback = callback self.callback = callback
self.updated_vars = {}
def accept(self, env, no_recycling = []): def accept(self, env, no_recycling = []):
""" """
...@@ -420,6 +425,14 @@ class VM_Linker(link.LocalLinker): ...@@ -420,6 +425,14 @@ class VM_Linker(link.LocalLinker):
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
def accept_var_updates(self, updated_vars):
self.updated_vars = updated_vars
# This method simply records in the linker which variables have update
# expressions. It does not imply that the linker will actually
# implement these updates (see need_update_inputs). This mechanism is
# admittedly confusing, and it could use some cleaning up. The base
# Linker object should probably go away completely.
def make_vm(self, nodes, thunks, def make_vm(self, nodes, thunks,
input_storage, output_storage, storage_map, input_storage, output_storage, storage_map,
post_thunk_clear, post_thunk_clear,
...@@ -559,7 +572,6 @@ class VM_Linker(link.LocalLinker): ...@@ -559,7 +572,6 @@ class VM_Linker(link.LocalLinker):
def make_all(self, profiler = None, input_storage = None, def make_all(self, profiler = None, input_storage = None,
output_storage = None, output_storage = None,
): ):
expanded_inputs=self.expanded_inputs # hacky argumentpassing workaround
env = self.env env = self.env
order = list(env.toposort()) order = list(env.toposort())
no_recycling = self.no_recycling no_recycling = self.no_recycling
...@@ -590,24 +602,12 @@ class VM_Linker(link.LocalLinker): ...@@ -590,24 +602,12 @@ class VM_Linker(link.LocalLinker):
else: else:
post_thunk_clear = None post_thunk_clear = None
# calculate the update_storage map whose keys are shared var inputs
# and whose values are the outputs that hold their updates
updated_vars = {}
if expanded_inputs:
# Update the inputs that have an update function
potential_values = list(env.outputs)
assert len(expanded_inputs)==len(env.inputs)
for e_input, ivar in reversed(zip(expanded_inputs, env.inputs)):
if e_input.update is not None:
updated_vars[ivar] = potential_values.pop()
vm = self.make_vm(order, thunks, vm = self.make_vm(order, thunks,
input_storage, output_storage, storage_map, input_storage, output_storage, storage_map,
post_thunk_clear, post_thunk_clear,
computed, computed,
compute_map, compute_map,
updated_vars self.updated_vars
) )
return (vm, return (vm,
......
...@@ -191,7 +191,7 @@ class RandomFunction(gof.Op): ...@@ -191,7 +191,7 @@ class RandomFunction(gof.Op):
# Numbers are drawn from r if self.inplace is True, and from a copy of r if # Numbers are drawn from r if self.inplace is True, and from a copy of r if
# self.inplace is False # self.inplace is False
r, shape, args = inputs[0], inputs[1], inputs[2:] r, shape, args = inputs[0], inputs[1], inputs[2:]
assert type(r) == numpy.random.RandomState assert type(r) == numpy.random.RandomState, (type(r), r)
r_orig = r r_orig = r
# If shape == [], that means no shape is enforced, and numpy is # If shape == [], that means no shape is enforced, and numpy is
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论