提交 b7f377f3 authored 作者: AlexLamb's avatar AlexLamb

Moved everything into function.py. Added sorting on output variables.

上级 814b92b8
'''
Generates a wrapper around theano functions that allows the user to receive outputs in a dictionary.
'''
def createFunctionReturningDictionary(args, kwargs, fn, keys):
outputLst = fn(*args, **kwargs)
outputDict = {}
for i in range(0, len(keys)):
outputDict[keys[i]] = outputLst[i]
return outputDict
...@@ -17,7 +17,6 @@ import warnings ...@@ -17,7 +17,6 @@ import warnings
from theano import gof from theano import gof
from theano import compat from theano import compat
from theano.compile.dictionaryOutputWrapper import createFunctionReturningDictionary
def function_dump(filename, inputs, outputs=None, mode=None, updates=None, def function_dump(filename, inputs, outputs=None, mode=None, updates=None,
givens=None, givens=None,
...@@ -38,6 +37,18 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None, ...@@ -38,6 +37,18 @@ def function_dump(filename, inputs, outputs=None, mode=None, updates=None,
with open(filename, 'wb') as f: with open(filename, 'wb') as f:
cPickle.dump(d, f, -1) cPickle.dump(d, f, -1)
def output_dictionary_wrapper(args, kwargs, fn, keys):
outputLst = fn(*args, **kwargs)
outputDict = {}
for i in range(0, len(keys)):
outputDict[keys[i]] = outputLst[i]
return outputDict
def function(inputs, outputs=None, mode=None, updates=None, givens=None, def function(inputs, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
...@@ -189,8 +200,17 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -189,8 +200,17 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
if type(outputs) is dict: if type(outputs) is dict:
outputsDictFormat = True outputsDictFormat = True
outputKeys = outputs.keys() outputItems = outputs.items()
outputs = outputs.values()
outputItemsSorted = sorted(outputItems, key = lambda x: x[0])
outputKeys = []
outputs = []
for pair in outputItemsSorted:
outputKeys.append(pair[0])
outputs.append(pair[1])
else: else:
outputsDictFormat = False outputsDictFormat = False
...@@ -280,7 +300,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -280,7 +300,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
if outputsDictFormat: if outputsDictFormat:
fnDictOutput = (lambda *args, **kwargs: createFunctionReturningDictionary(args, kwargs, fn = fn, keys = outputKeys)) fnDictOutput = (lambda *args, **kwargs: output_dictionary_wrapper(args, kwargs, fn = fn, keys = outputKeys))
return fnDictOutput return fnDictOutput
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论