提交 c4d1145b authored 作者: Jakub Sygnowski's avatar Jakub Sygnowski

handling output_keys in partial function evaluation

上级 c54cfde7
......@@ -764,19 +764,25 @@ class Function(object):
kwargs : dict
TODO: other kwargs?
Keyword argument `output_subset` is a list of indices of the
function's outputs that are requested to be calculated.
Keyword argument `output_subset` is a list of either indices of the
function's outputs or the keys belonging to the `output_keys` dict
and represent outputs that are requested to be calculated.
Returns
-------
List of outputs on indices `output_subset` or all of them, if
`outpus_subset` is not passed. If there's only one output, returns just
the value.
List of outputs on indices/keys from `output_subset` or all of them, if
`outputs_subset` is not passed.
"""
profile = self.profile
t0 = time.time()
output_subset = kwargs.pop('output_subset', None)
if output_subset is not None and self.output_keys is not None:
if not all(key in self.output_keys for key in output_subset):
raise KeyError('Output_subset should be a list of keys '
'from output_keys')
output_subset =\
[self.output_keys.index(key) for key in output_subset]
# Reinitialize each container's 'provided' counter
if self.trust_input:
......@@ -879,7 +885,9 @@ class Function(object):
# Do the actual work
t0_fn = time.time()
try:
outputs = self.fn(output_subset=output_subset)
outputs =\
self.fn() if output_subset is None else\
self.fn(output_subset=output_subset)
except Exception:
if hasattr(self.fn, 'position_of_error'):
# this is a new vm-provided function or c linker
......@@ -963,11 +971,18 @@ class Function(object):
assert len(self.output_keys) == len(outputs)
return dict(izip(self.output_keys, outputs))
if output_subset is None:
return dict(izip(self.output_keys, outputs))
else:
return dict((self.output_keys[index], outputs[index])
for index in output_subset)
if output_subset is None:
return outputs
else:
if not all(isinstance(index, int) for index in output_subset):
raise TypeError('Output_subset should be '
'a list of indices of output variables')
return [outputs[i] for i in output_subset]
value = property(
......
......@@ -196,6 +196,7 @@ def test_speed_lazy():
def test_partial_function():
import numpy as np
from theano.tests import unittest_tools as utt
x = tensor.scalar('input')
y = x ** 2
f = theano.function([x], [y + 7, y - 9, y / 14.], mode=Mode(
......@@ -203,7 +204,16 @@ def test_partial_function():
assert f(3, output_subset=[0, 1, 2]) == f(3)
assert f(4, output_subset=[0, 2]) == [f(4)[0], f(4)[2]]
assert abs(max(f(5) - np.array([32., 16., 1.7857142857142858]))) < 1e-9
utt.assert_allclose(f(5), np.array([32., 16., 1.7857142857142858]))
def test_partial_function_output_keys():
x = tensor.scalar('input')
y = 3 * x
f = theano.function([x], {'a': y * 5, 'b': y - 7}, mode=Mode(
optimizer=None, linker=vm.VM_Linker(allow_partial_eval=True)))
assert f(5, output_subset=['a'])['a'] == f(5)['a']
def test_allow_gc_cvm():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论