提交 678816fa authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix updates in CVM

Changed the semantics of update_storage in CVM, so it contains only the input storage cells to update, and ensure they are in the same order as the computed outputs.
上级 d1e06e3a
...@@ -103,7 +103,7 @@ typedef struct { ...@@ -103,7 +103,7 @@ typedef struct {
Py_ssize_t * node_n_prereqs; Py_ssize_t * node_n_prereqs;
Py_ssize_t ** node_prereqs; Py_ssize_t ** node_prereqs;
Py_ssize_t * update_storage; // dst0, src0, dst1, src1, ... cells to switch after a call Py_ssize_t * update_storage; // input cells to update with the last outputs in output_vars
Py_ssize_t n_updates; Py_ssize_t n_updates;
void ** thunk_cptr_fn; void ** thunk_cptr_fn;
...@@ -467,8 +467,6 @@ CLazyLinker_init(CLazyLinker *self, PyObject *args, PyObject *kwds) ...@@ -467,8 +467,6 @@ CLazyLinker_init(CLazyLinker *self, PyObject *args, PyObject *kwds)
if (unpack_list_of_ssize_t(update_storage, &self->update_storage, &self->n_updates, if (unpack_list_of_ssize_t(update_storage, &self->update_storage, &self->n_updates,
"updates_storage")) "updates_storage"))
return -1; return -1;
assert((self->n_updates % 2) == 0);
self->n_updates /= 2;
return 0; return 0;
} }
static void set_position_of_error(CLazyLinker * self, int owner_idx) static void set_position_of_error(CLazyLinker * self, int owner_idx)
...@@ -841,26 +839,23 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds) ...@@ -841,26 +839,23 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds)
if (!err) if (!err)
{ {
// save references to outputs prior to updating storage containers // save references to outputs prior to updating storage containers
if ((call_i + 1) == n_calls) assert (self->n_output_vars >= self->n_updates);
Py_DECREF(rval);
rval = PyList_New(self->n_output_vars);
for (int i = 0; i < (self->n_output_vars); ++i)
{ {
assert (self->n_output_vars >= self->n_updates); Py_ssize_t src = self->output_vars[i];
Py_DECREF(rval); PyObject * item = PyList_GetItem(self->var_value_cells[src], 0);
rval = PyList_New(self->n_output_vars); Py_INCREF(item);
for (int i = 0; i < (self->n_output_vars); ++i) PyList_SetItem(rval, i, item);
{
Py_ssize_t src = self->output_vars[i];
PyObject * item = PyList_GetItem(self->var_value_cells[src], 0);
Py_INCREF(item);
PyList_SetItem(rval, i, item);
}
} }
// Update the inputs that have an update rule
for (int i = 0; i < self->n_updates; ++i) for (int i = 0; i < self->n_updates; ++i)
{ {
Py_ssize_t dst = self->update_storage[2*i]; PyObject* tmp = PyList_GetItem(rval, self->n_output_vars - self->n_updates + i);
Py_ssize_t src = self->update_storage[2*i+1];
PyObject* tmp = PyList_GetItem(self->var_value_cells[src], 0);
Py_INCREF(tmp); Py_INCREF(tmp);
Py_ssize_t dst = self->update_storage[i];
PyList_SetItem(self->var_value_cells[dst], 0, tmp); PyList_SetItem(self->var_value_cells[dst], 0, tmp);
} }
} }
...@@ -973,7 +968,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = { ...@@ -973,7 +968,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = {
static PyObject * get_version(PyObject *dummy, PyObject *args) static PyObject * get_version(PyObject *dummy, PyObject *args)
{ {
PyObject *result = PyFloat_FromDouble(0.18); PyObject *result = PyFloat_FromDouble(0.19);
return result; return result;
} }
......
...@@ -13,7 +13,7 @@ if config.compiledir not in sys.path: ...@@ -13,7 +13,7 @@ if config.compiledir not in sys.path:
sys.path.append(config.compiledir) sys.path.append(config.compiledir)
force_compile = False force_compile = False
version = 0.18 # must match constant returned in function get_version() version = 0.19 # must match constant returned in function get_version()
try: try:
......
...@@ -707,11 +707,19 @@ class VM_Linker(link.LocalLinker): ...@@ -707,11 +707,19 @@ class VM_Linker(link.LocalLinker):
prereq_var_idxs.sort() # TODO: why sort? prereq_var_idxs.sort() # TODO: why sort?
node_prereqs.append(prereq_var_idxs) node_prereqs.append(prereq_var_idxs)
# Builds the list of input storage to update (according to update
# rules) when the outputs are computed.
# They are in the same order as the second part of output_vars
# (output_vars contains first the returned outputs, then the
# values of the update expressions).
update_storage = [] update_storage = []
update_in_from_out = {}
for (ivar, ovar) in updated_vars.items(): for (ivar, ovar) in updated_vars.items():
if ivar != ovar: if ivar != ovar:
update_storage.append(vars_idx[ivar]) # dst update_in_from_out[vars_idx[ovar]] = vars_idx[ivar]
update_storage.append(vars_idx[ovar]) # src for oidx in output_vars:
if oidx in update_in_from_out:
update_storage.append(update_in_from_out[oidx])
c0 = sys.getrefcount(node_n_inputs) c0 = sys.getrefcount(node_n_inputs)
vm = CVM( vm = CVM(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论