提交 448b8070 authored 作者: James Bergstra's avatar James Bergstra

CLazyLinker returns pre-update outputs (bugfix)

上级 87c7d2fa
...@@ -631,7 +631,7 @@ class Function(object): ...@@ -631,7 +631,7 @@ class Function(object):
# Do the actual work # Do the actual work
t0_fn = time.time() t0_fn = time.time()
try: try:
self.fn() outputs = self.fn()
except Exception: except Exception:
if hasattr(self.fn, 'position_of_error'): if hasattr(self.fn, 'position_of_error'):
# this is a new vm-provided function # this is a new vm-provided function
...@@ -648,7 +648,9 @@ class Function(object): ...@@ -648,7 +648,9 @@ class Function(object):
profile.vm_call_time += dt_fn profile.vm_call_time += dt_fn
# Retrieve the values that were computed # Retrieve the values that were computed
if outputs is None:
outputs = [x.data for x in self.output_storage] outputs = [x.data for x in self.output_storage]
assert len(outputs) == len(self.output_storage)
# Remove internal references to required inputs. # Remove internal references to required inputs.
# These cannot be re-used anyway. # These cannot be re-used anyway.
......
...@@ -732,6 +732,8 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds) ...@@ -732,6 +732,8 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds)
self->position_of_error = -1; self->position_of_error = -1;
PyObject * one = PyInt_FromLong(1); PyObject * one = PyInt_FromLong(1);
PyObject * zero = PyInt_FromLong(0); PyObject * zero = PyInt_FromLong(0);
Py_INCREF(Py_None);
PyObject * rval = Py_None;
//clear storage of pre_call_clear elements //clear storage of pre_call_clear elements
for (int call_i = 0; call_i < n_calls && (!err); ++call_i) for (int call_i = 0; call_i < n_calls && (!err); ++call_i)
{ {
...@@ -764,6 +766,21 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds) ...@@ -764,6 +766,21 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds)
err = lazy_rec_eval(self, self->output_vars[i], one, zero); err = lazy_rec_eval(self, self->output_vars[i], one, zero);
} }
// save references to outputs prior to updating storage containers
if ((call_i + 1) == n_calls)
{
assert (self->n_output_vars > self->n_updates);
Py_DECREF(rval);
rval = PyList_New(self->n_output_vars);
for (int i = 0; i < (self->n_output_vars); ++i)
{
Py_ssize_t src = self->output_vars[i];
PyObject * item = PyList_GetItem(self->var_value_cells[src], 0);
Py_INCREF(item);
PyList_SetItem(rval, i, item);
}
}
for (int i = 0; i < self->n_updates; ++i) for (int i = 0; i < self->n_updates; ++i)
{ {
Py_ssize_t dst = self->update_storage[2*i]; Py_ssize_t dst = self->update_storage[2*i];
...@@ -777,9 +794,12 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds) ...@@ -777,9 +794,12 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds)
} }
Py_DECREF(one); Py_DECREF(one);
Py_DECREF(zero); Py_DECREF(zero);
if (err) return NULL; if (err)
Py_INCREF(Py_None); {
return Py_None; Py_DECREF(rval);
return NULL;
}
return rval;
} }
#if 0 #if 0
...@@ -853,7 +873,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = { ...@@ -853,7 +873,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = {
static PyObject * get_version(PyObject *dummy, PyObject *args) static PyObject * get_version(PyObject *dummy, PyObject *args)
{ {
PyObject *result = PyFloat_FromDouble(0.1); PyObject *result = PyFloat_FromDouble(0.11);
return result; return result;
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论