提交 81e52783 作者: Frédéric Bastien 提交者: GitHub

Merge pull request #4669 from sygi/partial-eval-cvm

Partial evaluation in CVM
...@@ -735,9 +735,13 @@ class Function(object): ...@@ -735,9 +735,13 @@ class Function(object):
kwargs : dict kwargs : dict
The function inputs can be passed as keyword argument. For this, use The function inputs can be passed as keyword argument. For this, use
the name of the input or the input instance as the key. the name of the input or the input instance as the key.
Keyword argument ``output_subset`` is a list of either indices of the Keyword argument ``output_subset`` is a list of either indices of the
function's outputs or the keys belonging to the `output_keys` dict function's outputs or the keys belonging to the `output_keys` dict
and represent outputs that are requested to be calculated. and represent outputs that are requested to be calculated. Regardless
of the presence of ``output_subset``, the updates are always calculated
and processed. To disable the updates, you should use the ``copy``
method with ``delete_updates=True``.
Returns Returns
------- -------
......
...@@ -789,15 +789,47 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds) ...@@ -789,15 +789,47 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds)
{ {
CLazyLinker * self = (CLazyLinker*)_self; CLazyLinker * self = (CLazyLinker*)_self;
static char *kwlist[] = { static char *kwlist[] = {
(char*)"time_thunks", (char *)"time_thunks",
(char *)"n_calls", (char *)"n_calls",
(char *)"output_subset",
NULL}; NULL};
int n_calls=1; int n_calls=1;
if (! PyArg_ParseTupleAndKeywords(args, kwds, "|ii", kwlist, PyObject *output_subset_ptr = NULL;
if (! PyArg_ParseTupleAndKeywords(args, kwds, "|iiO", kwlist,
&self->do_timing, &self->do_timing,
&n_calls)) &n_calls,
&output_subset_ptr))
return NULL; return NULL;
int err = 0; int err = 0;
// parse an output_subset list
// it is stored as a bool list of length n_output_vars: calculate a var or not
char *output_subset = NULL;
int output_subset_size = -1;
if (output_subset_ptr != NULL)
{
if (! PyList_Check(output_subset_ptr))
{
err = 1;
PyErr_SetString(PyExc_RuntimeError, "Output_subset is not a list");
}
else
{
output_subset_size = PyList_Size(output_subset_ptr);
output_subset = (char*)calloc(self->n_output_vars, sizeof(char));
for (int it = 0; it < output_subset_size; ++it)
{
PyObject *elem = PyList_GetItem(output_subset_ptr, it);
if (! PyInt_Check(elem))
{
err = 1;
PyErr_SetString(PyExc_RuntimeError, "Some elements of output_subset list are not int");
}
output_subset[PyInt_AsLong(elem)] = 1;
}
}
}
self->position_of_error = -1; self->position_of_error = -1;
// create constants used to fill the var_compute_cells // create constants used to fill the var_compute_cells
PyObject * one = PyInt_FromLong(1); PyObject * one = PyInt_FromLong(1);
...@@ -833,10 +865,14 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds) ...@@ -833,10 +865,14 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds)
} }
} }
int first_updated = self->n_output_vars - self->n_updates;
for (int i = 0; i < self->n_output_vars && (!err); ++i) for (int i = 0; i < self->n_output_vars && (!err); ++i)
{
if (i >= first_updated || output_subset == NULL || output_subset[i] == 1)
{ {
err = lazy_rec_eval(self, self->output_vars[i], one, zero); err = lazy_rec_eval(self, self->output_vars[i], one, zero);
} }
}
if (!err) if (!err)
{ {
...@@ -848,7 +884,8 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds) ...@@ -848,7 +884,8 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds)
{ {
Py_ssize_t src = self->output_vars[i]; Py_ssize_t src = self->output_vars[i];
PyObject * item = PyList_GetItem(self->var_value_cells[src], 0); PyObject * item = PyList_GetItem(self->var_value_cells[src], 0);
if (self->var_computed[src] != 1) if ((output_subset == NULL || output_subset[i]) &&
self->var_computed[src] != 1)
{ {
err = 1; err = 1;
PyErr_Format(PyExc_AssertionError, PyErr_Format(PyExc_AssertionError,
...@@ -901,6 +938,9 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds) ...@@ -901,6 +938,9 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds)
PyList_SetItem(self->var_value_cells[i], 0, Py_None); PyList_SetItem(self->var_value_cells[i], 0, Py_None);
} }
} }
if (output_subset != NULL)
free(output_subset);
Py_DECREF(one); Py_DECREF(one);
Py_DECREF(zero); Py_DECREF(zero);
if (err) if (err)
...@@ -1014,7 +1054,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = { ...@@ -1014,7 +1054,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = {
static PyObject * get_version(PyObject *dummy, PyObject *args) static PyObject * get_version(PyObject *dummy, PyObject *args)
{ {
PyObject *result = PyFloat_FromDouble(0.21); PyObject *result = PyFloat_FromDouble(0.211);
return result; return result;
} }
......
...@@ -15,7 +15,7 @@ from theano.gof import cmodule ...@@ -15,7 +15,7 @@ from theano.gof import cmodule
_logger = logging.getLogger('theano.gof.lazylinker_c') _logger = logging.getLogger('theano.gof.lazylinker_c')
force_compile = False force_compile = False
version = 0.21 # must match constant returned in function get_version() version = 0.211 # must match constant returned in function get_version()
lazylinker_ext = None lazylinker_ext = None
......
...@@ -197,24 +197,54 @@ def test_speed_lazy(): ...@@ -197,24 +197,54 @@ def test_speed_lazy():
def test_partial_function(): def test_partial_function():
import numpy as np import numpy as np
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
def check_partial_function(linker_name):
x = tensor.scalar('input') x = tensor.scalar('input')
y = x ** 2 y = x ** 2
f = theano.function([x], [y + 7, y - 9, y / 14.], mode=Mode( f = theano.function([x], [y + 7, y - 9, y / 14.], mode=Mode(
optimizer=None, linker=vm.VM_Linker(allow_partial_eval=True))) optimizer=None, linker=linker_name))
assert f(3, output_subset=[0, 1, 2]) == f(3) assert f(3, output_subset=[0, 1, 2]) == f(3)
assert f(4, output_subset=[0, 2]) == [f(4)[0], f(4)[2]] assert f(4, output_subset=[0, 2]) == [f(4)[0], f(4)[2]]
utt.assert_allclose(f(5), np.array([32., 16., 1.7857142857142858])) utt.assert_allclose(f(5), np.array([32., 16., 1.7857142857142858]))
check_partial_function(vm.VM_Linker(allow_partial_eval=True, use_cloop=False))
check_partial_function('cvm')
def test_partial_function_with_output_keys():
def test_partial_function_output_keys(): def check_partial_function_output_keys(linker_name):
x = tensor.scalar('input') x = tensor.scalar('input')
y = 3 * x y = 3 * x
f = theano.function([x], {'a': y * 5, 'b': y - 7}, mode=Mode( f = theano.function([x], {'a': y * 5, 'b': y - 7}, mode=Mode(
optimizer=None, linker=vm.VM_Linker(allow_partial_eval=True))) optimizer=None, linker=linker_name))
assert f(5, output_subset=['a'])['a'] == f(5)['a'] assert f(5, output_subset=['a'])['a'] == f(5)['a']
check_partial_function_output_keys(vm.VM_Linker(allow_partial_eval=True, use_cloop=False))
check_partial_function_output_keys('cvm')
def test_partial_function_with_updates():
def check_updates(linker_name):
x = tensor.lscalar('input')
y = theano.shared(1, name='global')
f = theano.function([x], [x, x + 34], updates=[(y, x + 1)], mode=Mode(
optimizer=None, linker=linker_name))
g = theano.function([x], [x - 6], updates=[(y, y + 3)], mode=Mode(
optimizer=None, linker=linker_name))
assert f(3, output_subset=[]) == []
assert y.get_value() == 4
assert g(30, output_subset=[0]) == [24]
assert g(40, output_subset=[]) == []
assert y.get_value() == 10
check_updates(vm.VM_Linker(allow_partial_eval=True, use_cloop=False))
check_updates('cvm')
def test_allow_gc_cvm(): def test_allow_gc_cvm():
mode = theano.config.mode mode = theano.config.mode
......
...@@ -332,7 +332,8 @@ class Stack(VM): ...@@ -332,7 +332,8 @@ class Stack(VM):
def __init__(self, nodes, thunks, pre_call_clear, def __init__(self, nodes, thunks, pre_call_clear,
storage_map, compute_map, fgraph, allow_gc, storage_map, compute_map, fgraph, allow_gc,
dependencies=None, callback=None, callback_input=None): n_updates, dependencies=None, callback=None,
callback_input=None):
super(Stack, self).__init__(nodes, thunks, pre_call_clear) super(Stack, self).__init__(nodes, thunks, pre_call_clear)
self.allow_gc = allow_gc self.allow_gc = allow_gc
...@@ -346,6 +347,7 @@ class Stack(VM): ...@@ -346,6 +347,7 @@ class Stack(VM):
self.node_idx = node_idx = {} self.node_idx = node_idx = {}
self.callback = callback self.callback = callback
self.callback_input = callback_input self.callback_input = callback_input
self.n_updates = n_updates
ords = fgraph.orderings() ords = fgraph.orderings()
...@@ -417,6 +419,9 @@ class Stack(VM): ...@@ -417,6 +419,9 @@ class Stack(VM):
# apply_stack contains nodes # apply_stack contains nodes
if output_subset is not None: if output_subset is not None:
first_updated = len(self.outputs) - self.n_updates
output_subset = output_subset + list(range(first_updated,
len(self.outputs)))
apply_stack =\ apply_stack =\
[self.outputs[i].owner for i in output_subset [self.outputs[i].owner for i in output_subset
if self.outputs[i].owner] if self.outputs[i].owner]
...@@ -425,7 +430,7 @@ class Stack(VM): ...@@ -425,7 +430,7 @@ class Stack(VM):
last_apply_stack_len = -1 last_apply_stack_len = -1
# This record all function inputs/shared varibles and constants # This record all function inputs/shared variables and constants
for var, data in iteritems(self.storage_map): for var, data in iteritems(self.storage_map):
if data[0] is None: if data[0] is None:
continue continue
...@@ -842,7 +847,7 @@ class VM_Linker(link.LocalLinker): ...@@ -842,7 +847,7 @@ class VM_Linker(link.LocalLinker):
if (self.callback is not None or self.callback_input is not None or if (self.callback is not None or self.callback_input is not None or
(config.profile and config.profile_memory) or (config.profile and config.profile_memory) or
self.allow_partial_eval): (self.allow_partial_eval and not self.use_cloop)):
if self.use_cloop and (self.callback is not None or if self.use_cloop and (self.callback is not None or
self.callback_input is not None): self.callback_input is not None):
...@@ -850,9 +855,9 @@ class VM_Linker(link.LocalLinker): ...@@ -850,9 +855,9 @@ class VM_Linker(link.LocalLinker):
if self.use_cloop and config.profile_memory: if self.use_cloop and config.profile_memory:
warnings.warn( warnings.warn(
'CVM does not support memory profile, using Stack VM.') 'CVM does not support memory profile, using Stack VM.')
if self.use_cloop and self.allow_partial_eval: if not self.use_cloop and self.allow_partial_eval:
warnings.warn( warnings.warn(
'CVM does not support partial evaluation yet, ' 'LoopGC does not support partial evaluation, '
'using Stack VM.') 'using Stack VM.')
# Needed for allow_gc=True, profiling and storage_map reuse # Needed for allow_gc=True, profiling and storage_map reuse
deps = self.compute_gc_dependencies(storage_map) deps = self.compute_gc_dependencies(storage_map)
...@@ -860,6 +865,7 @@ class VM_Linker(link.LocalLinker): ...@@ -860,6 +865,7 @@ class VM_Linker(link.LocalLinker):
nodes, thunks, pre_call_clear, nodes, thunks, pre_call_clear,
storage_map, compute_map, storage_map, compute_map,
self.fgraph, self.allow_gc, self.fgraph, self.allow_gc,
len(updated_vars),
dependencies=deps, dependencies=deps,
callback=self.callback, callback=self.callback,
callback_input=self.callback_input) callback_input=self.callback_input)
...@@ -1000,7 +1006,8 @@ class VM_Linker(link.LocalLinker): ...@@ -1000,7 +1006,8 @@ class VM_Linker(link.LocalLinker):
nodes, thunks, pre_call_clear, nodes, thunks, pre_call_clear,
storage_map, compute_map, storage_map, compute_map,
self.fgraph, self.allow_gc, self.fgraph, self.allow_gc,
dependencies=deps len(updated_vars),
dependencies=deps,
) )
return vm return vm
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论