提交 265311c0 authored 作者: AlexLamb's avatar AlexLamb

Refactored dictionary wrapper into Function.call

上级 a1f266b1
......@@ -211,7 +211,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
else:
outputs_dict_format = False
output_keys = None
if name is None:
# Determine possible file names
......@@ -291,17 +291,12 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
rebuild_strict=rebuild_strict,
allow_input_downcast=allow_input_downcast,
on_unused_input=on_unused_input,
profile=profile)
profile=profile,output_dictionary_flag=outputs_dict_format,
output_keys=output_keys)
# We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs
fn._check_for_aliased_inputs = check_for_aliased_inputs
if outputs_dict_format:
fn_dict_output = (lambda *args, **kwargs: output_dictionary_wrapper(args, kwargs, fn = fn, keys = output_keys))
return fn_dict_output
return fn
......
......@@ -280,7 +280,7 @@ class Function(object):
"""
def __init__(self, fn, input_storage, output_storage, indices, outputs,
defaults, unpack_single, return_none, maker):
defaults, unpack_single, return_none, output_dictionary_flag, output_keys, maker):
"""
Initialize attributes. create finder, inv_finder.
"""
......@@ -298,6 +298,8 @@ class Function(object):
self.trust_input = False # If True, we don't check the input parameter
self.name = None
self.nodes_with_inner_function = []
self.output_dictionary_flag = output_dictionary_flag
self.output_keys = output_keys
# We will be popping stuff off this `containers` object. It is a copy.
containers = list(self.input_storage)
......@@ -671,11 +673,26 @@ class Function(object):
if hasattr(self.fn, 'update_profile'):
self.fn.update_profile(profile)
if self.return_none:
return None
elif self.unpack_single and len(outputs) == 1:
return outputs[0]
else:
if self.output_dictionary_flag:
outputDict = {}
assert len(self.output_keys) == len(outputs)
for i in range(0, len(self.output_keys)):
outputDict[self.output_keys[i]] = outputs[i]
outputs = outputDict
return outputs
value = property(
......@@ -1047,7 +1064,8 @@ class FunctionMaker(object):
def __init__(self, inputs, outputs,
mode=None, accept_inplace=False, function_builder=Function,
profile=None, on_unused_input=None, fgraph=None):
profile=None, on_unused_input=None, fgraph=None,
output_dictionary_flag=False,output_keys=None):
"""
:type inputs: a list of SymbolicInput instances
......@@ -1201,6 +1219,8 @@ class FunctionMaker(object):
self.accept_inplace = accept_inplace
self.function_builder = function_builder
self.on_unused_input = on_unused_input # Used only for the pickling
self.output_dictionary_flag = output_dictionary_flag
self.output_keys = output_keys
self.required = [(i.value is None) for i in self.inputs]
self.refeed = [
......@@ -1336,7 +1356,7 @@ class FunctionMaker(object):
self.profile.import_time += import_time
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs,
defaults, self.unpack_single, self.return_none, self)
defaults, self.unpack_single, self.return_none, self.output_dictionary_flag, self.output_keys, self)
fn.profile = self.profile
return fn
......@@ -1396,7 +1416,8 @@ def register_checker(checker):
def orig_function(inputs, outputs, mode=None, accept_inplace=False,
name=None, profile=None, on_unused_input=None):
name=None, profile=None, on_unused_input=None,
output_dictionary_flag=False,output_keys=None):
"""
Return a Function that will calculate the outputs from the inputs.
......@@ -1462,7 +1483,9 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
mode,
accept_inplace=accept_inplace,
profile=profile,
on_unused_input=on_unused_input).create(
on_unused_input=on_unused_input,
output_dictionary_flag = output_dictionary_flag,
output_keys = output_keys).create(
defaults)
t2 = time.time()
......
......@@ -333,7 +333,7 @@ class Param(object):
self.implicit = implicit
def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
def pfunc(params, output_dictionary_flag,output_keys, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None,
profile=None, on_unused_input=None):
......@@ -506,7 +506,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
mutable=False, borrow=True, shared=True)
inputs.append(si)
return orig_function(inputs, cloned_outputs, mode,
return orig_function(inputs, cloned_outputs, mode, output_dictionary_flag=output_dictionary_flag,output_keys = output_keys,
accept_inplace=accept_inplace, name=name, profile=profile,
on_unused_input=on_unused_input)
......
......@@ -16,6 +16,9 @@ try:
except ImportError:
flake8_available = False
print theano.__file__
print "fake8available", flake8_available
whitelist_flake8 = [
"updates.py",
"printing.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论