提交 031400a6 authored 作者: Jakub Sygnowski's avatar Jakub Sygnowski

partial evaluation in CVM

上级 a3bbfb8c
......@@ -789,14 +789,35 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds)
{
CLazyLinker * self = (CLazyLinker*)_self;
static char *kwlist[] = {
(char*)"time_thunks",
(char *)"time_thunks",
(char *)"n_calls",
(char *)"output_subset",
NULL};
int n_calls=1;
if (! PyArg_ParseTupleAndKeywords(args, kwds, "|ii", kwlist,
PyObject *output_subset_ptr = NULL;
if (! PyArg_ParseTupleAndKeywords(args, kwds, "|iiO", kwlist,
&self->do_timing,
&n_calls))
&n_calls,
&output_subset_ptr))
return NULL;
// parse an output_subset list
// it is stored as a bool list of length n_output_vars: calculate a var or not
char *output_subset = NULL;
int output_subset_size = -1;
if (output_subset_ptr != NULL)
{
assert (PyList_Check(output_subset_ptr));
output_subset_size = PyList_Size(output_subset_ptr);
output_subset = (char*)calloc(self->n_output_vars, sizeof(char));
for (int it = 0; it < output_subset_size; ++it)
{
PyObject *elem = PyList_GetItem(output_subset_ptr, it);
assert (PyInt_Check(elem));
output_subset[PyInt_AsLong(elem)] = 1;
}
}
int err = 0;
self->position_of_error = -1;
// create constants used to fill the var_compute_cells
......@@ -834,21 +855,26 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds)
}
for (int i = 0; i < self->n_output_vars && (!err); ++i)
{
if (output_subset == NULL || output_subset[i] == 1)
{
err = lazy_rec_eval(self, self->output_vars[i], one, zero);
}
}
if (!err)
{
// save references to outputs prior to updating storage containers
assert (self->n_output_vars >= self->n_updates);
assert (output_subset_size < 0 || output_subset_size >= self->n_updates);
Py_DECREF(rval);
rval = PyList_New(self->n_output_vars);
for (int i = 0; i < (self->n_output_vars); ++i)
{
Py_ssize_t src = self->output_vars[i];
PyObject * item = PyList_GetItem(self->var_value_cells[src], 0);
if (self->var_computed[src] != 1)
if ((output_subset == NULL || output_subset[i]) &&
self->var_computed[src] != 1)
{
err = 1;
PyErr_Format(PyExc_AssertionError,
......@@ -901,6 +927,9 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds)
PyList_SetItem(self->var_value_cells[i], 0, Py_None);
}
}
if (output_subset != NULL)
free(output_subset);
Py_DECREF(one);
Py_DECREF(zero);
if (err)
......@@ -1014,7 +1043,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = {
static PyObject * get_version(PyObject *dummy, PyObject *args)
{
PyObject *result = PyFloat_FromDouble(0.21);
PyObject *result = PyFloat_FromDouble(0.211);
return result;
}
......
......@@ -15,7 +15,7 @@ from theano.gof import cmodule
_logger = logging.getLogger('theano.gof.lazylinker_c')
force_compile = False
version = 0.21 # must match constant returned in function get_version()
version = 0.211 # must match constant returned in function get_version()
lazylinker_ext = None
......
......@@ -197,15 +197,20 @@ def test_speed_lazy():
def test_partial_function():
import numpy as np
from theano.tests import unittest_tools as utt
def check_partial_function(linker_name):
x = tensor.scalar('input')
y = x ** 2
f = theano.function([x], [y + 7, y - 9, y / 14.], mode=Mode(
optimizer=None, linker=vm.VM_Linker(allow_partial_eval=True)))
optimizer=None, linker=linker_name))
assert f(3, output_subset=[0, 1, 2]) == f(3)
assert f(4, output_subset=[0, 2]) == [f(4)[0], f(4)[2]]
utt.assert_allclose(f(5), np.array([32., 16., 1.7857142857142858]))
check_partial_function(vm.VM_Linker(allow_partial_eval=True))
check_partial_function('cvm')
def test_partial_function_output_keys():
x = tensor.scalar('input')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论