提交 b3bf808f authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2645 from DeathMonster666/master

Added ability to store results from theano function in a python dictiona...
......@@ -2181,7 +2181,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
accept_inplace=False,
function_builder=Function,
profile=None,
on_unused_input=None):
on_unused_input=None,
output_keys=None):
"""
:type inputs: a list of SymbolicInput instances
......@@ -2197,6 +2198,10 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
:param on_unused_input: What to do if a variable in the 'inputs' list is
not used in the graph. Possible values are 'raise', 'warn', and 'ignore'.
:param output_keys: If the outputs argument for theano.function was a
list, then output_keys is None. If the outputs argument was a dict,
then output_keys is a sorted list of the keys from that dict.
:note: this function sets TensorType.filter_checks_isfinite
when `mode.check_isfinite` is True
......@@ -2316,6 +2321,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
self.accept_inplace = accept_inplace
self.function_builder = function_builder
self.mode = mode
self.output_keys = output_keys
def create(self, defaults=None, trustme=False):
"""
......@@ -2424,7 +2430,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
_fn, _i, _o = self.linker.make_thunk(input_storage=input_storage)
fn = self.function_builder(_fn, _i, _o, self.indices,
self.outputs, defaults, self.unpack_single,
self.return_none, self)
self.return_none, self.output_keys, self)
return fn
......
......@@ -49,7 +49,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
:param inputs: function parameters, these are not allowed to be shared
variables
:type outputs: list of Variables or Out instances
:type outputs: list or dict of Variables or Out instances. If it is a
dict, the keys must be strings
:param outputs: expressions to compute
:type mode: string or `Mode` instance.
......@@ -185,6 +186,24 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
"""
if isinstance(outputs, dict):
output_items = outputs.items()
for item_pair in output_items:
assert isinstance(item_pair[0], basestring)
output_items_sorted = sorted(output_items)
output_keys = []
outputs = []
for pair in output_items_sorted:
output_keys.append(pair[0])
outputs.append(pair[1])
else:
output_keys = None
if name is None:
# Determine possible file names
source_file = re.sub('\.pyc?', '.py', __file__)
......@@ -263,7 +282,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
rebuild_strict=rebuild_strict,
allow_input_downcast=allow_input_downcast,
on_unused_input=on_unused_input,
profile=profile)
profile=profile,
output_keys=output_keys)
# We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs
fn._check_for_aliased_inputs = check_for_aliased_inputs
......
......@@ -9,7 +9,6 @@ import cPickle
import itertools
import time
import warnings
import numpy
import theano
......@@ -282,7 +281,7 @@ class Function(object):
"""
def __init__(self, fn, input_storage, output_storage, indices, outputs,
defaults, unpack_single, return_none, maker):
defaults, unpack_single, return_none, output_keys, maker):
"""
Initialize attributes. create finder, inv_finder.
"""
......@@ -300,6 +299,7 @@ class Function(object):
self.trust_input = False # If True, we don't check the input parameter
self.name = None
self.nodes_with_inner_function = []
self.output_keys = output_keys
# We will be popping stuff off this `containers` object. It is a copy.
containers = list(self.input_storage)
......@@ -678,6 +678,13 @@ class Function(object):
elif self.unpack_single and len(outputs) == 1:
return outputs[0]
else:
if self.output_keys is not None:
assert len(self.output_keys) == len(outputs)
return dict(itertools.izip(self.output_keys, outputs))
return outputs
value = property(
......@@ -1049,7 +1056,8 @@ class FunctionMaker(object):
def __init__(self, inputs, outputs,
mode=None, accept_inplace=False, function_builder=Function,
profile=None, on_unused_input=None, fgraph=None):
profile=None, on_unused_input=None, fgraph=None,
output_keys=None):
"""
:type inputs: a list of SymbolicInput instances
......@@ -1203,6 +1211,7 @@ class FunctionMaker(object):
self.accept_inplace = accept_inplace
self.function_builder = function_builder
self.on_unused_input = on_unused_input # Used only for the pickling
self.output_keys = output_keys
self.required = [(i.value is None) for i in self.inputs]
self.refeed = [
......@@ -1338,7 +1347,7 @@ class FunctionMaker(object):
self.profile.import_time += import_time
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs,
defaults, self.unpack_single, self.return_none, self)
defaults, self.unpack_single, self.return_none, self.output_keys, self)
fn.profile = self.profile
return fn
......@@ -1398,7 +1407,8 @@ def register_checker(checker):
def orig_function(inputs, outputs, mode=None, accept_inplace=False,
name=None, profile=None, on_unused_input=None):
name=None, profile=None, on_unused_input=None,
output_keys=None):
"""
Return a Function that will calculate the outputs from the inputs.
......@@ -1434,6 +1444,11 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
:param on_unused_input: What to do if a variable in the 'inputs' list is
not used in the graph. Possible values are 'raise', 'warn', 'ignore'
and None
:param output_keys: If the outputs were provided to theano.function as a
list, then output_keys is None. Otherwise, if outputs were provided
as a dict, output_keys is the sorted list of keys from the outputs
"""
# Every element of the input list will be upgraded to an `In` instance if
......@@ -1464,7 +1479,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
mode,
accept_inplace=accept_inplace,
profile=profile,
on_unused_input=on_unused_input).create(
on_unused_input=on_unused_input,
output_keys = output_keys).create(
defaults)
t2 = time.time()
......
......@@ -336,7 +336,7 @@ class Param(object):
def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None,
profile=None, on_unused_input=None):
profile=None, on_unused_input=None,output_keys=None):
"""Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances.
......@@ -508,7 +508,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
return orig_function(inputs, cloned_outputs, mode,
accept_inplace=accept_inplace, name=name, profile=profile,
on_unused_input=on_unused_input)
on_unused_input=on_unused_input, output_keys=output_keys)
def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
......
import unittest
import theano
import theano.tensor as T
class dictionary_output_checker(unittest.TestCase):
def test_output_dictionary(self):
'''
Tests that theano.function works when outputs is a dictionary
'''
x = T.scalar()
f = theano.function([x], outputs={'a': x, 'c': x*2,
'b': x*3, '1': x*4})
outputs = f(10.0)
assert outputs['a'] == 10.0
assert outputs['b'] == 30.0
assert outputs['1'] == 40.0
assert outputs['c'] == 20.0
def test_input_named_variables(self):
'''
Tests that named variables work when outputs is a dictionary
'''
x = T.scalar('x')
y = T.scalar('y')
f = theano.function([x, y], outputs={'a': x + y, 'b': x * y})
assert f(2, 4) == {'a': 6, 'b': 8}
assert f(2, y=4) == f(2, 4)
assert f(x=2, y=4) == f(2, 4)
def test_output_order_sorted(self):
'''
Tests that the output keys are sorted correctly.
'''
x = T.scalar('x')
y = T.scalar('y')
z = T.scalar('z')
e1 = T.scalar('1')
e2 = T.scalar('2')
f = theano.function([x, y, z, e1, e2], outputs={'x': x, 'y': y, 'z': z,
'1': e1, '2': e2})
assert '1' in str(f.outputs[0])
assert '2' in str(f.outputs[1])
assert 'x' in str(f.outputs[2])
assert 'y' in str(f.outputs[3])
assert 'z' in str(f.outputs[4])
def test_composing_function(self):
'''
Tests that one can compose two theano functions when the outputs are
provided in a dictionary.
'''
x = T.scalar('x')
y = T.scalar('y')
a = x + y
b = x * y
f = theano.function([x, y], outputs={'a': a, 'b': b})
a = T.scalar('a')
b = T.scalar('b')
l = a + b
r = a * b
g = theano.function([a, b], outputs=[l, r])
result = g(**f(5, 7))
assert result[0] == 47.0
assert result[1] == 420.0
def test_output_list_still_works(self):
'''
Test that theano.function works if outputs is a list.
'''
x = T.scalar('x')
f = theano.function([x], outputs=[x * 3, x * 2, x * 4, x])
result = f(5.0)
assert result[0] == 15.0
assert result[1] == 10.0
assert result[2] == 20.0
assert result[3] == 5.0
def test_debug_mode_dict(self):
'''
Tests that debug mode works where outputs is a dictionary.
'''
x = T.scalar('x')
f = theano.function([x], outputs={'1': x, '2': 2 * x,
'3': 3 * x}, mode="DEBUG_MODE")
result = f(3.0)
assert result['1'] == 3.0
assert result['2'] == 6.0
assert result['3'] == 9.0
def test_debug_mode_list(self):
'''
Tests that debug mode works where the outputs argument is a list.
'''
x = T.scalar('x')
f = theano.function([x], outputs=[x, 2 * x, 3 * x], mode="DEBUG_MODE")
result = f(5.0)
assert result[0] == 5.0
assert result[1] == 10.0
assert result[2] == 15.0
def test_key_string_requirement(self):
'''
Tests that an exception is thrown if a non-string key is used in
the outputs dictionary.
'''
x = T.scalar('x')
try:
theano.function([x], outputs={1.0: x})
raise Exception("Did not throw exception with 1.0 as only key")
except AssertionError:
pass
try:
theano.function([x], outputs={1.0: x, "a": x**2})
raise Exception("Did not throw exception with 1.0 as one key")
except AssertionError:
pass
try:
theano.function([x], outputs={(1, "b"): x, 1.0: x**2})
raise Exception("Did not throw exception with tuple as key")
except AssertionError:
pass
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论