提交 148814bf authored 作者: nouiz's avatar nouiz

Merge pull request #311 from jaberg/cvm_runs_testsuite

Cvm runs testsuite
......@@ -21,6 +21,7 @@ from theano.configdefaults import config
import logging
_logger = logging.getLogger('theano.compile.function_module')
def alias_root(v):
"""Return the variable to which v is aliased by view_maps and destroy_maps"""
if v.owner is None: return v
......@@ -35,6 +36,7 @@ def alias_root(v):
else:
return v
def view_tree_set(v, treeset):
"""Add to `treeset` all variables that are views of v, given that v is not a view"""
treeset.add(v)
......@@ -48,6 +50,7 @@ def view_tree_set(v, treeset):
if cl.outputs[opos] not in treeset:
view_tree_set(cl.outputs[opos], treeset)
def infer_reuse_pattern(env, outputs_to_disown):
"""
Given an env and a list of variables, returns the list or set of all variables which may
......@@ -65,6 +68,23 @@ def infer_reuse_pattern(env, outputs_to_disown):
return rval
def env_updated_vars(env, expanded_inputs):
"""
Reconstruct the full "updates" dictionary, mapping from Env input
variables to the env outputs that will replace their values.
:rtype: dict variable -> variable
"""
updated_vars = {}
potential_values = list(env.outputs) # copy the list
if len(expanded_inputs) != len(env.inputs):
raise ValueError('expanded_inputs must match len(env.inputs)')
for e_input, ivar in reversed(zip(expanded_inputs, env.inputs)):
if e_input.update is not None:
updated_vars[ivar] = potential_values.pop()
return updated_vars
class Supervisor:
"""
Listener for Env events which makes sure that no operation overwrites the
......@@ -550,6 +570,7 @@ class Function(object):
# Set keyword arguments
if kwargs: # for speed, skip the iteritems for empty kwargs
for k, arg in kwargs.iteritems():
self[k] = arg
......@@ -610,7 +631,7 @@ class Function(object):
# Do the actual work
t0_fn = time.time()
try:
self.fn()
outputs = self.fn()
except Exception:
if hasattr(self.fn, 'position_of_error'):
# this is a new vm-provided function
......@@ -627,7 +648,9 @@ class Function(object):
profile.vm_call_time += dt_fn
# Retrieve the values that were computed
if outputs is None:
outputs = [x.data for x in self.output_storage]
assert len(outputs) == len(self.output_storage)
# Remove internal references to required inputs.
# These cannot be re-used anyway.
......@@ -1027,8 +1050,10 @@ class FunctionMaker(object):
else:
self.linker = linker.accept(env)
#hacky thing so VMLinker
self.linker.expanded_inputs = expanded_inputs
if hasattr(linker, 'accept_var_updates'):
# hacky thing so VMLinker knows about updates
self.linker.accept_var_updates(
env_updated_vars(env, expanded_inputs))
self.indices = indices
self.inputs = inputs
......
......@@ -28,9 +28,9 @@ class T_bunch_of_modes(unittest.TestCase):
x = T.matrix()
y = T.vector()
f = theano.function([x,y], x+y, mode=mode)
f = theano.function([x, y], x + y, mode=mode)
# test that it runs something
f([[1,2],[3,4]], [5, 6])
f([[1, 2], [3, 4]], [5, 6])
linker_classes_involved.append(f.maker.mode.linker.__class__)
print 'MODE:', mode, f.maker.mode.linker, 'stop'
# regression check:
......
......@@ -470,16 +470,25 @@ static PyObject * pycall(CLazyLinker * self, Py_ssize_t node_idx, int verbose)
double t0 = pytime(NULL);
if (verbose) fprintf(stderr, "calling via Python (node %i)\n", (int)node_idx);
rval = PyObject_CallObject(thunk, NULL);
if (rval)
{
double t1 = pytime(NULL);
double ti = PyFloat_AsDouble(PyList_GetItem(self->call_times, node_idx));
PyList_SetItem(self->call_times, node_idx, PyFloat_FromDouble(t1 - t0 + ti));
double ti = PyFloat_AsDouble(
PyList_GetItem(self->call_times, node_idx));
PyList_SetItem(self->call_times, node_idx,
PyFloat_FromDouble(t1 - t0 + ti));
PyObject * count = PyList_GetItem(self->call_counts, node_idx);
long icount = PyInt_AsLong(count);
PyList_SetItem(self->call_counts, node_idx, PyInt_FromLong(icount+1));
PyList_SetItem(self->call_counts, node_idx,
PyInt_FromLong(icount + 1));
}
}
else
{
if (verbose) fprintf(stderr, "calling via Python (node %i)\n", (int)node_idx);
if (verbose)
{
fprintf(stderr, "calling via Python (node %i)\n", (int)node_idx);
}
rval = PyObject_CallObject(thunk, NULL);
}
return rval;
......@@ -730,8 +739,13 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds)
return NULL;
int err = 0;
self->position_of_error = -1;
// create constants used to fill the var_compute_cells
PyObject * one = PyInt_FromLong(1);
PyObject * zero = PyInt_FromLong(0);
// pre-allocate our return value
Py_INCREF(Py_None);
PyObject * rval = Py_None;
//clear storage of pre_call_clear elements
for (int call_i = 0; call_i < n_calls && (!err); ++call_i)
{
......@@ -764,22 +778,41 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds)
err = lazy_rec_eval(self, self->output_vars[i], one, zero);
}
if (!err)
{
// save references to outputs prior to updating storage containers
if ((call_i + 1) == n_calls)
{
assert (self->n_output_vars >= self->n_updates);
Py_DECREF(rval);
rval = PyList_New(self->n_output_vars);
for (int i = 0; i < (self->n_output_vars); ++i)
{
Py_ssize_t src = self->output_vars[i];
PyObject * item = PyList_GetItem(self->var_value_cells[src], 0);
Py_INCREF(item);
PyList_SetItem(rval, i, item);
}
}
for (int i = 0; i < self->n_updates; ++i)
{
Py_ssize_t dst = self->update_storage[2*i];
Py_ssize_t src = self->update_storage[2*i+1];
PyObject* tmp = PyList_GetItem(self->var_value_cells[src], 0);
Py_INCREF(Py_None);
Py_INCREF(tmp);
PyList_SetItem(self->var_value_cells[dst], 0, tmp);
PyList_SetItem(self->var_value_cells[src], 0, Py_None);
}
}
}
Py_DECREF(one);
Py_DECREF(zero);
if (err) return NULL;
Py_INCREF(Py_None);
return Py_None;
if (err)
{
Py_DECREF(rval);
return NULL;
}
return rval;
}
#if 0
......@@ -853,7 +886,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = {
static PyObject * get_version(PyObject *dummy, PyObject *args)
{
PyObject *result = PyFloat_FromDouble(0.1);
PyObject *result = PyFloat_FromDouble(0.13);
return result;
}
......
......@@ -12,12 +12,17 @@ _logger = logging.getLogger('theano.gof.lazylinker_c')
if config.compiledir not in sys.path:
sys.path.append(config.compiledir)
version = 0.1 # must match constant returned in function get_version()
force_compile = False
version = 0.13 # must match constant returned in function get_version()
need_reload = False
try:
_need_reload = False
if force_compile:
raise ImportError()
else:
import lazylinker_ext
need_reload = True
_need_reload = True
if version != getattr(lazylinker_ext, '_version', None):
raise ImportError()
except ImportError:
......@@ -26,13 +31,15 @@ except ImportError:
# Maybe someone else already finished compiling it while we were
# waiting for the lock?
try:
if need_reload:
if force_compile:
raise ImportError()
if _need_reload:
# The module was successfully imported earlier: we need to
# reload it to check if the version was updated.
reload(lazylinker_ext)
else:
import lazylinker_ext
need_reload = True
_need_reload = True
if version != getattr(lazylinker_ext, '_version', None):
raise ImportError()
except ImportError:
......@@ -67,4 +74,4 @@ except ImportError:
release_lock()
from lazylinker_ext.lazylinker_ext import *
assert version == get_version()
assert force_compile or (version == get_version())
......@@ -61,6 +61,10 @@ class VM(object):
self.call_counts = [0]*len(nodes)
self.call_times = [0]*len(nodes)
self.time_thunks = False
# This variable (self.need_update_inputs) is overshadowed by
# CLazyLinker in CVM which has an attribute of the same name that
# defaults to 0 (aka False).
self.need_update_inputs = True
def __call__(self):
......@@ -405,6 +409,7 @@ class VM_Linker(link.LocalLinker):
self.allow_gc = allow_gc
self.use_cloop = use_cloop
self.callback = callback
self.updated_vars = {}
def accept(self, env, no_recycling = []):
"""
......@@ -420,6 +425,14 @@ class VM_Linker(link.LocalLinker):
self.no_recycling = no_recycling
return self
def accept_var_updates(self, updated_vars):
self.updated_vars = updated_vars
# This method simply records in the linker which variables have update
# expressions. It does not imply that the linker will actually
# implement these updates (see need_update_inputs). This mechanism is
# admittedly confusing, and it could use some cleaning up. The base
# Linker object should probably go away completely.
def make_vm(self, nodes, thunks,
input_storage, output_storage, storage_map,
post_thunk_clear,
......@@ -559,7 +572,6 @@ class VM_Linker(link.LocalLinker):
def make_all(self, profiler = None, input_storage = None,
output_storage = None,
):
expanded_inputs=self.expanded_inputs # hacky argumentpassing workaround
env = self.env
order = list(env.toposort())
no_recycling = self.no_recycling
......@@ -590,24 +602,12 @@ class VM_Linker(link.LocalLinker):
else:
post_thunk_clear = None
# calculate the update_storage map whose keys are shared var inputs
# and whose values are the outputs that hold their updates
updated_vars = {}
if expanded_inputs:
# Update the inputs that have an update function
potential_values = list(env.outputs)
assert len(expanded_inputs)==len(env.inputs)
for e_input, ivar in reversed(zip(expanded_inputs, env.inputs)):
if e_input.update is not None:
updated_vars[ivar] = potential_values.pop()
vm = self.make_vm(order, thunks,
input_storage, output_storage, storage_map,
post_thunk_clear,
computed,
compute_map,
updated_vars
self.updated_vars
)
return (vm,
......
......@@ -191,7 +191,7 @@ class RandomFunction(gof.Op):
# Numbers are drawn from r if self.inplace is True, and from a copy of r if
# self.inplace is False
r, shape, args = inputs[0], inputs[1], inputs[2:]
assert type(r) == numpy.random.RandomState
assert type(r) == numpy.random.RandomState, (type(r), r)
r_orig = r
# If shape == [], that means no shape is enforced, and numpy is
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论