提交 73620123 authored 作者: Jakub Sygnowski's avatar Jakub Sygnowski

handling updates with partial evaluation

上级 44e5c993
......@@ -865,9 +865,10 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds)
}
}
int first_updated = self->n_output_vars - self->n_updates;
for (int i = 0; i < self->n_output_vars && (!err); ++i)
{
if (output_subset == NULL || output_subset[i] == 1)
if (i >= first_updated || output_subset == NULL || output_subset[i] == 1)
{
err = lazy_rec_eval(self, self->output_vars[i], one, zero);
}
......@@ -877,7 +878,6 @@ CLazyLinker_call(PyObject *_self, PyObject *args, PyObject *kwds)
{
// save references to outputs prior to updating storage containers
assert (self->n_output_vars >= self->n_updates);
assert (output_subset_size < 0 || output_subset_size >= self->n_updates);
Py_DECREF(rval);
rval = PyList_New(self->n_output_vars);
for (int i = 0; i < (self->n_output_vars); ++i)
......
......@@ -212,6 +212,7 @@ def test_partial_function():
check_partial_function('cvm')
# TODO: implement output_keys with CVM
def test_partial_function_output_keys():
x = tensor.scalar('input')
y = 3 * x
......@@ -221,6 +222,25 @@ def test_partial_function_output_keys():
assert f(5, output_subset=['a'])['a'] == f(5)['a']
def test_partial_function_with_updates():
def check_updates(linker_name):
x = tensor.lscalar('input')
y = theano.shared(1, name='global')
f = theano.function([x], [x, x + 34], updates=[(y, x + 1)], mode=Mode(
optimizer=None, linker=linker_name))
g = theano.function([x], [x - 6], updates=[(y, y + 3)], mode=Mode(
optimizer=None, linker=linker_name))
f(3, output_subset=[])
assert(y.get_value() == 4)
assert(g(30, output_subset=[0]) == [24])
g(40, output_subset=[])
assert(y.get_value() == 10)
check_updates(vm.VM_Linker(allow_partial_eval=True))
check_updates('cvm')
def test_allow_gc_cvm():
mode = theano.config.mode
if mode in ['DEBUG_MODE', 'DebugMode']:
......
......@@ -332,7 +332,8 @@ class Stack(VM):
def __init__(self, nodes, thunks, pre_call_clear,
storage_map, compute_map, fgraph, allow_gc,
dependencies=None, callback=None, callback_input=None):
dependencies=None, callback=None, callback_input=None,
n_updates=0):
super(Stack, self).__init__(nodes, thunks, pre_call_clear)
self.allow_gc = allow_gc
......@@ -346,6 +347,7 @@ class Stack(VM):
self.node_idx = node_idx = {}
self.callback = callback
self.callback_input = callback_input
self.n_updates = n_updates
ords = fgraph.orderings()
......@@ -417,6 +419,9 @@ class Stack(VM):
# apply_stack contains nodes
if output_subset is not None:
first_updated = len(self.outputs) - self.n_updates
output_subset = output_subset + range(first_updated,
len(self.outputs))
apply_stack =\
[self.outputs[i].owner for i in output_subset
if self.outputs[i].owner]
......@@ -425,7 +430,7 @@ class Stack(VM):
last_apply_stack_len = -1
# This record all function inputs/shared varibles and constants
# This record all function inputs/shared variables and constants
for var, data in iteritems(self.storage_map):
if data[0] is None:
continue
......@@ -852,7 +857,7 @@ class VM_Linker(link.LocalLinker):
'CVM does not support memory profile, using Stack VM.')
if not self.use_cloop and self.allow_partial_eval:
warnings.warn(
'LoopGCdoes not support partial evaluation, '
'LoopGC does not support partial evaluation, '
'using Stack VM.')
# Needed for allow_gc=True, profiling and storage_map reuse
deps = self.compute_gc_dependencies(storage_map)
......@@ -862,7 +867,8 @@ class VM_Linker(link.LocalLinker):
self.fgraph, self.allow_gc,
dependencies=deps,
callback=self.callback,
callback_input=self.callback_input)
callback_input=self.callback_input,
n_updates=len(updated_vars))
elif self.use_cloop:
# create a map from nodes to ints and vars to ints
nodes_idx = {}
......@@ -1000,7 +1006,8 @@ class VM_Linker(link.LocalLinker):
nodes, thunks, pre_call_clear,
storage_map, compute_map,
self.fgraph, self.allow_gc,
dependencies=deps
dependencies=deps,
n_updates=len(updated_vars)
)
return vm
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论