提交 47367022 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #4279 from sygi/partial-evaluation

Partial function evaluation
......@@ -214,4 +214,4 @@ Reference
.. autofunction:: theano.compile.function.function_dump
.. autoclass:: theano.compile.function_module.Function
:members: free, copy
\ No newline at end of file
:members: free, copy, __call__
......@@ -586,7 +586,8 @@ class Function(object):
Returns
-------
Copied theano.Function
theano.Function
Copied theano.Function
"""
# helper function
def checkSV(sv_ori, sv_rpl):
......@@ -752,9 +753,37 @@ class Function(object):
return f_cpy
def __call__(self, *args, **kwargs):
"""
Evaluates value of a function on given arguments.
Parameters
----------
args : list
List of inputs to the function. All inputs are required, even when
some of them are not necessary to calculate requested subset of
outputs.
kwargs : dict
The function inputs can be passed as keyword argument. For this, use
the name of the input or the input instance as the key.
Keyword argument ``output_subset`` is a list of either indices of the
function's outputs or the keys belonging to the `output_keys` dict
and represent outputs that are requested to be calculated.
Returns
-------
list
List of outputs on indices/keys from ``output_subset`` or all of them,
if ``output_subset`` is not passed.
"""
profile = self.profile
t0 = time.time()
output_subset = kwargs.pop('output_subset', None)
if output_subset is not None and self.output_keys is not None:
output_subset =\
[self.output_keys.index(key) for key in output_subset]
# Reinitialize each container's 'provided' counter
if self.trust_input:
i = 0
......@@ -856,7 +885,9 @@ class Function(object):
# Do the actual work
t0_fn = time.time()
try:
outputs = self.fn()
outputs =\
self.fn() if output_subset is None else\
self.fn(output_subset=output_subset)
except Exception:
if hasattr(self.fn, 'position_of_error'):
# this is a new vm-provided function or c linker
......@@ -933,7 +964,8 @@ class Function(object):
profile.ignore_first_call = False
if self.return_none:
return None
elif self.unpack_single and len(outputs) == 1:
elif self.unpack_single and len(outputs) == 1 and\
output_subset is None:
return outputs[0]
else:
......@@ -941,9 +973,16 @@ class Function(object):
assert len(self.output_keys) == len(outputs)
return dict(izip(self.output_keys, outputs))
if output_subset is None:
return dict(izip(self.output_keys, outputs))
else:
return dict((self.output_keys[index], outputs[index])
for index in output_subset)
return outputs
if output_subset is None:
return outputs
else:
return [outputs[i] for i in output_subset]
value = property(
lambda self: self._value,
......
......@@ -194,6 +194,28 @@ def test_speed_lazy():
use_cloop=True))
def test_partial_function():
import numpy as np
from theano.tests import unittest_tools as utt
x = tensor.scalar('input')
y = x ** 2
f = theano.function([x], [y + 7, y - 9, y / 14.], mode=Mode(
optimizer=None, linker=vm.VM_Linker(allow_partial_eval=True)))
assert f(3, output_subset=[0, 1, 2]) == f(3)
assert f(4, output_subset=[0, 2]) == [f(4)[0], f(4)[2]]
utt.assert_allclose(f(5), np.array([32., 16., 1.7857142857142858]))
def test_partial_function_output_keys():
x = tensor.scalar('input')
y = 3 * x
f = theano.function([x], {'a': y * 5, 'b': y - 7}, mode=Mode(
optimizer=None, linker=vm.VM_Linker(allow_partial_eval=True)))
assert f(5, output_subset=['a'])['a'] == f(5)['a']
def test_allow_gc_cvm():
mode = theano.config.mode
if mode in ['DEBUG_MODE', 'DebugMode']:
......
......@@ -396,7 +396,7 @@ class Stack(VM):
)
return rval, dt
def __call__(self):
def __call__(self, output_subset=None):
storage_map = self.storage_map
compute_map = self.compute_map
thunks = self.thunks
......@@ -408,7 +408,13 @@ class Stack(VM):
compute_map[k][0] = (k.owner is None)
# apply_stack contains nodes
apply_stack = list(self.base_apply_stack)
if output_subset is not None:
apply_stack =\
[self.outputs[i].owner for i in output_subset
if self.outputs[i].owner]
else:
apply_stack = list(self.base_apply_stack)
last_apply_stack_len = -1
# This record all function inputs/shared varibles and constants
......@@ -682,11 +688,15 @@ class VM_Linker(link.LocalLinker):
c_thunks
If None or True, don't change the default. If False,
don't compile c code for the thunks.
allow_partial_eval
If True, enforces usage of Stack or CVM, to allow for partial
evaluation of functions (calculating a subset of outputs).
"""
def __init__(self, allow_gc=None, use_cloop=False, callback=None,
lazy=None, schedule=None, c_thunks=None):
lazy=None, schedule=None, c_thunks=None,
allow_partial_eval=None):
# Note: if more parameters are added to __init__, make sure to forward
# them in the "type(self)(...)" call in the "accept" method below.
if allow_gc is None:
......@@ -697,6 +707,7 @@ class VM_Linker(link.LocalLinker):
self.callback = callback
self.lazy = lazy
self.c_thunks = c_thunks
self.allow_partial_eval = allow_partial_eval
self.updated_vars = {}
if schedule:
self.schedule = schedule
......@@ -811,13 +822,18 @@ class VM_Linker(link.LocalLinker):
pre_call_clear = [storage_map[v] for v in self.no_recycling]
if (self.callback is not None or
(config.profile and config.profile_memory)):
(config.profile and config.profile_memory) or
getattr(self, 'allow_partial_eval', False)):
if self.use_cloop and self.callback is not None:
logger.warn('CVM does not support callback, using Stack VM.')
if self.use_cloop and config.profile_memory:
warnings.warn(
'CVM does not support memory profile, using Stack VM.')
if self.use_cloop and getattr(self, 'allow_partial_eval', False):
warnings.warn(
'CVM does not support partial evaluation yet, '
'using Stack VM.')
# Needed for allow_gc=True, profiling and storage_map reuse
deps = self.compute_gc_dependencies(storage_map)
vm = Stack(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论