提交 c4d1145b authored 作者: Jakub Sygnowski's avatar Jakub Sygnowski

handling output_keys in partial function evaluation

上级 c54cfde7
...@@ -764,19 +764,25 @@ class Function(object): ...@@ -764,19 +764,25 @@ class Function(object):
kwargs : dict kwargs : dict
TODO: other kwargs? TODO: other kwargs?
Keyword argument `output_subset` is a list of indices of the Keyword argument `output_subset` is a list of either indices of the
function's outputs that are requested to be calculated. function's outputs or the keys belonging to the `output_keys` dict
and represent outputs that are requested to be calculated.
Returns Returns
------- -------
List of outputs on indices `output_subset` or all of them, if List of outputs on indices/keys from `output_subset` or all of them, if
`outpus_subset` is not passed. If there's only one output, returns just `outputs_subset` is not passed.
the value.
""" """
profile = self.profile profile = self.profile
t0 = time.time() t0 = time.time()
output_subset = kwargs.pop('output_subset', None) output_subset = kwargs.pop('output_subset', None)
if output_subset is not None and self.output_keys is not None:
if not all(key in self.output_keys for key in output_subset):
raise KeyError('Output_subset should be a list of keys '
'from output_keys')
output_subset =\
[self.output_keys.index(key) for key in output_subset]
# Reinitialize each container's 'provided' counter # Reinitialize each container's 'provided' counter
if self.trust_input: if self.trust_input:
...@@ -879,7 +885,9 @@ class Function(object): ...@@ -879,7 +885,9 @@ class Function(object):
# Do the actual work # Do the actual work
t0_fn = time.time() t0_fn = time.time()
try: try:
outputs = self.fn(output_subset=output_subset) outputs =\
self.fn() if output_subset is None else\
self.fn(output_subset=output_subset)
except Exception: except Exception:
if hasattr(self.fn, 'position_of_error'): if hasattr(self.fn, 'position_of_error'):
# this is a new vm-provided function or c linker # this is a new vm-provided function or c linker
...@@ -963,11 +971,18 @@ class Function(object): ...@@ -963,11 +971,18 @@ class Function(object):
assert len(self.output_keys) == len(outputs) assert len(self.output_keys) == len(outputs)
return dict(izip(self.output_keys, outputs)) if output_subset is None:
return dict(izip(self.output_keys, outputs))
else:
return dict((self.output_keys[index], outputs[index])
for index in output_subset)
if output_subset is None: if output_subset is None:
return outputs return outputs
else: else:
if not all(isinstance(index, int) for index in output_subset):
raise TypeError('Output_subset should be '
'a list of indices of output variables')
return [outputs[i] for i in output_subset] return [outputs[i] for i in output_subset]
value = property( value = property(
......
...@@ -196,6 +196,7 @@ def test_speed_lazy(): ...@@ -196,6 +196,7 @@ def test_speed_lazy():
def test_partial_function(): def test_partial_function():
import numpy as np import numpy as np
from theano.tests import unittest_tools as utt
x = tensor.scalar('input') x = tensor.scalar('input')
y = x ** 2 y = x ** 2
f = theano.function([x], [y + 7, y - 9, y / 14.], mode=Mode( f = theano.function([x], [y + 7, y - 9, y / 14.], mode=Mode(
...@@ -203,7 +204,16 @@ def test_partial_function(): ...@@ -203,7 +204,16 @@ def test_partial_function():
assert f(3, output_subset=[0, 1, 2]) == f(3) assert f(3, output_subset=[0, 1, 2]) == f(3)
assert f(4, output_subset=[0, 2]) == [f(4)[0], f(4)[2]] assert f(4, output_subset=[0, 2]) == [f(4)[0], f(4)[2]]
assert abs(max(f(5) - np.array([32., 16., 1.7857142857142858]))) < 1e-9 utt.assert_allclose(f(5), np.array([32., 16., 1.7857142857142858]))
def test_partial_function_output_keys():
x = tensor.scalar('input')
y = 3 * x
f = theano.function([x], {'a': y * 5, 'b': y - 7}, mode=Mode(
optimizer=None, linker=vm.VM_Linker(allow_partial_eval=True)))
assert f(5, output_subset=['a'])['a'] == f(5)['a']
def test_allow_gc_cvm(): def test_allow_gc_cvm():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论