提交 b3bf808f authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2645 from DeathMonster666/master

Added ability to store results from theano function in a python dictiona...
...@@ -2181,7 +2181,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2181,7 +2181,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
accept_inplace=False, accept_inplace=False,
function_builder=Function, function_builder=Function,
profile=None, profile=None,
on_unused_input=None): on_unused_input=None,
output_keys=None):
""" """
:type inputs: a list of SymbolicInput instances :type inputs: a list of SymbolicInput instances
...@@ -2197,6 +2198,10 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2197,6 +2198,10 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
:param on_unused_input: What to do if a variable in the 'inputs' list is :param on_unused_input: What to do if a variable in the 'inputs' list is
not used in the graph. Possible values are 'raise', 'warn', and 'ignore'. not used in the graph. Possible values are 'raise', 'warn', and 'ignore'.
:param output_keys: If the outputs argument for theano.function was a
list, then output_keys is None. If the outputs argument was a dict,
then output_keys is a sorted list of the keys from that dict.
:note: this function sets TensorType.filter_checks_isfinite :note: this function sets TensorType.filter_checks_isfinite
when `mode.check_isfinite` is True when `mode.check_isfinite` is True
...@@ -2316,6 +2321,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2316,6 +2321,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
self.accept_inplace = accept_inplace self.accept_inplace = accept_inplace
self.function_builder = function_builder self.function_builder = function_builder
self.mode = mode self.mode = mode
self.output_keys = output_keys
def create(self, defaults=None, trustme=False): def create(self, defaults=None, trustme=False):
""" """
...@@ -2424,7 +2430,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2424,7 +2430,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
_fn, _i, _o = self.linker.make_thunk(input_storage=input_storage) _fn, _i, _o = self.linker.make_thunk(input_storage=input_storage)
fn = self.function_builder(_fn, _i, _o, self.indices, fn = self.function_builder(_fn, _i, _o, self.indices,
self.outputs, defaults, self.unpack_single, self.outputs, defaults, self.unpack_single,
self.return_none, self) self.return_none, self.output_keys, self)
return fn return fn
......
...@@ -49,7 +49,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -49,7 +49,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
:param inputs: function parameters, these are not allowed to be shared :param inputs: function parameters, these are not allowed to be shared
variables variables
:type outputs: list of Variables or Out instances :type outputs: list or dict of Variables or Out instances. If it is a
dict, the keys must be strings
:param outputs: expressions to compute :param outputs: expressions to compute
:type mode: string or `Mode` instance. :type mode: string or `Mode` instance.
...@@ -185,6 +186,24 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -185,6 +186,24 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
""" """
if isinstance(outputs, dict):
output_items = outputs.items()
for item_pair in output_items:
assert isinstance(item_pair[0], basestring)
output_items_sorted = sorted(output_items)
output_keys = []
outputs = []
for pair in output_items_sorted:
output_keys.append(pair[0])
outputs.append(pair[1])
else:
output_keys = None
if name is None: if name is None:
# Determine possible file names # Determine possible file names
source_file = re.sub('\.pyc?', '.py', __file__) source_file = re.sub('\.pyc?', '.py', __file__)
...@@ -263,7 +282,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -263,7 +282,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
rebuild_strict=rebuild_strict, rebuild_strict=rebuild_strict,
allow_input_downcast=allow_input_downcast, allow_input_downcast=allow_input_downcast,
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
profile=profile) profile=profile,
output_keys=output_keys)
# We need to add the flag check_aliased inputs if we have any mutable or # We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs # borrowed used defined inputs
fn._check_for_aliased_inputs = check_for_aliased_inputs fn._check_for_aliased_inputs = check_for_aliased_inputs
......
...@@ -9,7 +9,6 @@ import cPickle ...@@ -9,7 +9,6 @@ import cPickle
import itertools import itertools
import time import time
import warnings import warnings
import numpy import numpy
import theano import theano
...@@ -282,7 +281,7 @@ class Function(object): ...@@ -282,7 +281,7 @@ class Function(object):
""" """
def __init__(self, fn, input_storage, output_storage, indices, outputs, def __init__(self, fn, input_storage, output_storage, indices, outputs,
defaults, unpack_single, return_none, maker): defaults, unpack_single, return_none, output_keys, maker):
""" """
Initialize attributes. create finder, inv_finder. Initialize attributes. create finder, inv_finder.
""" """
...@@ -300,6 +299,7 @@ class Function(object): ...@@ -300,6 +299,7 @@ class Function(object):
self.trust_input = False # If True, we don't check the input parameter self.trust_input = False # If True, we don't check the input parameter
self.name = None self.name = None
self.nodes_with_inner_function = [] self.nodes_with_inner_function = []
self.output_keys = output_keys
# We will be popping stuff off this `containers` object. It is a copy. # We will be popping stuff off this `containers` object. It is a copy.
containers = list(self.input_storage) containers = list(self.input_storage)
...@@ -678,6 +678,13 @@ class Function(object): ...@@ -678,6 +678,13 @@ class Function(object):
elif self.unpack_single and len(outputs) == 1: elif self.unpack_single and len(outputs) == 1:
return outputs[0] return outputs[0]
else: else:
if self.output_keys is not None:
assert len(self.output_keys) == len(outputs)
return dict(itertools.izip(self.output_keys, outputs))
return outputs return outputs
value = property( value = property(
...@@ -1049,7 +1056,8 @@ class FunctionMaker(object): ...@@ -1049,7 +1056,8 @@ class FunctionMaker(object):
def __init__(self, inputs, outputs, def __init__(self, inputs, outputs,
mode=None, accept_inplace=False, function_builder=Function, mode=None, accept_inplace=False, function_builder=Function,
profile=None, on_unused_input=None, fgraph=None): profile=None, on_unused_input=None, fgraph=None,
output_keys=None):
""" """
:type inputs: a list of SymbolicInput instances :type inputs: a list of SymbolicInput instances
...@@ -1203,6 +1211,7 @@ class FunctionMaker(object): ...@@ -1203,6 +1211,7 @@ class FunctionMaker(object):
self.accept_inplace = accept_inplace self.accept_inplace = accept_inplace
self.function_builder = function_builder self.function_builder = function_builder
self.on_unused_input = on_unused_input # Used only for the pickling self.on_unused_input = on_unused_input # Used only for the pickling
self.output_keys = output_keys
self.required = [(i.value is None) for i in self.inputs] self.required = [(i.value is None) for i in self.inputs]
self.refeed = [ self.refeed = [
...@@ -1338,7 +1347,7 @@ class FunctionMaker(object): ...@@ -1338,7 +1347,7 @@ class FunctionMaker(object):
self.profile.import_time += import_time self.profile.import_time += import_time
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs, fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs,
defaults, self.unpack_single, self.return_none, self) defaults, self.unpack_single, self.return_none, self.output_keys, self)
fn.profile = self.profile fn.profile = self.profile
return fn return fn
...@@ -1398,7 +1407,8 @@ def register_checker(checker): ...@@ -1398,7 +1407,8 @@ def register_checker(checker):
def orig_function(inputs, outputs, mode=None, accept_inplace=False, def orig_function(inputs, outputs, mode=None, accept_inplace=False,
name=None, profile=None, on_unused_input=None): name=None, profile=None, on_unused_input=None,
output_keys=None):
""" """
Return a Function that will calculate the outputs from the inputs. Return a Function that will calculate the outputs from the inputs.
...@@ -1434,6 +1444,11 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1434,6 +1444,11 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
:param on_unused_input: What to do if a variable in the 'inputs' list is :param on_unused_input: What to do if a variable in the 'inputs' list is
not used in the graph. Possible values are 'raise', 'warn', 'ignore' not used in the graph. Possible values are 'raise', 'warn', 'ignore'
and None and None
:param output_keys: If the outputs were provided to theano.function as a
list, then output_keys is None. Otherwise, if outputs were provided
as a dict, output_keys is the sorted list of keys from the outputs
""" """
# Every element of the input list will be upgraded to an `In` instance if # Every element of the input list will be upgraded to an `In` instance if
...@@ -1464,7 +1479,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1464,7 +1479,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
mode, mode,
accept_inplace=accept_inplace, accept_inplace=accept_inplace,
profile=profile, profile=profile,
on_unused_input=on_unused_input).create( on_unused_input=on_unused_input,
output_keys = output_keys).create(
defaults) defaults)
t2 = time.time() t2 = time.time()
......
...@@ -336,7 +336,7 @@ class Param(object): ...@@ -336,7 +336,7 @@ class Param(object):
def pfunc(params, outputs=None, mode=None, updates=None, givens=None, def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, rebuild_strict=True, allow_input_downcast=None,
profile=None, on_unused_input=None): profile=None, on_unused_input=None,output_keys=None):
"""Function-constructor for graphs with shared variables. """Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances. :type params: list of either Variable or Param instances.
...@@ -508,7 +508,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -508,7 +508,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
return orig_function(inputs, cloned_outputs, mode, return orig_function(inputs, cloned_outputs, mode,
accept_inplace=accept_inplace, name=name, profile=profile, accept_inplace=accept_inplace, name=name, profile=profile,
on_unused_input=on_unused_input) on_unused_input=on_unused_input, output_keys=output_keys)
def _pfunc_param_to_in(param, strict=False, allow_downcast=None): def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
......
import unittest
import theano
import theano.tensor as T
class dictionary_output_checker(unittest.TestCase):
def test_output_dictionary(self):
'''
Tests that theano.function works when outputs is a dictionary
'''
x = T.scalar()
f = theano.function([x], outputs={'a': x, 'c': x*2,
'b': x*3, '1': x*4})
outputs = f(10.0)
assert outputs['a'] == 10.0
assert outputs['b'] == 30.0
assert outputs['1'] == 40.0
assert outputs['c'] == 20.0
def test_input_named_variables(self):
'''
Tests that named variables work when outputs is a dictionary
'''
x = T.scalar('x')
y = T.scalar('y')
f = theano.function([x, y], outputs={'a': x + y, 'b': x * y})
assert f(2, 4) == {'a': 6, 'b': 8}
assert f(2, y=4) == f(2, 4)
assert f(x=2, y=4) == f(2, 4)
def test_output_order_sorted(self):
'''
Tests that the output keys are sorted correctly.
'''
x = T.scalar('x')
y = T.scalar('y')
z = T.scalar('z')
e1 = T.scalar('1')
e2 = T.scalar('2')
f = theano.function([x, y, z, e1, e2], outputs={'x': x, 'y': y, 'z': z,
'1': e1, '2': e2})
assert '1' in str(f.outputs[0])
assert '2' in str(f.outputs[1])
assert 'x' in str(f.outputs[2])
assert 'y' in str(f.outputs[3])
assert 'z' in str(f.outputs[4])
def test_composing_function(self):
'''
Tests that one can compose two theano functions when the outputs are
provided in a dictionary.
'''
x = T.scalar('x')
y = T.scalar('y')
a = x + y
b = x * y
f = theano.function([x, y], outputs={'a': a, 'b': b})
a = T.scalar('a')
b = T.scalar('b')
l = a + b
r = a * b
g = theano.function([a, b], outputs=[l, r])
result = g(**f(5, 7))
assert result[0] == 47.0
assert result[1] == 420.0
def test_output_list_still_works(self):
'''
Test that theano.function works if outputs is a list.
'''
x = T.scalar('x')
f = theano.function([x], outputs=[x * 3, x * 2, x * 4, x])
result = f(5.0)
assert result[0] == 15.0
assert result[1] == 10.0
assert result[2] == 20.0
assert result[3] == 5.0
def test_debug_mode_dict(self):
'''
Tests that debug mode works where outputs is a dictionary.
'''
x = T.scalar('x')
f = theano.function([x], outputs={'1': x, '2': 2 * x,
'3': 3 * x}, mode="DEBUG_MODE")
result = f(3.0)
assert result['1'] == 3.0
assert result['2'] == 6.0
assert result['3'] == 9.0
def test_debug_mode_list(self):
'''
Tests that debug mode works where the outputs argument is a list.
'''
x = T.scalar('x')
f = theano.function([x], outputs=[x, 2 * x, 3 * x], mode="DEBUG_MODE")
result = f(5.0)
assert result[0] == 5.0
assert result[1] == 10.0
assert result[2] == 15.0
def test_key_string_requirement(self):
'''
Tests that an exception is thrown if a non-string key is used in
the outputs dictionary.
'''
x = T.scalar('x')
try:
theano.function([x], outputs={1.0: x})
raise Exception("Did not throw exception with 1.0 as only key")
except AssertionError:
pass
try:
theano.function([x], outputs={1.0: x, "a": x**2})
raise Exception("Did not throw exception with 1.0 as one key")
except AssertionError:
pass
try:
theano.function([x], outputs={(1, "b"): x, 1.0: x**2})
raise Exception("Did not throw exception with tuple as key")
except AssertionError:
pass
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论