提交 265311c0 authored 作者: AlexLamb's avatar AlexLamb

Refactored dictionary wrapper into Function.call

上级 a1f266b1
...@@ -211,7 +211,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -211,7 +211,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
else: else:
outputs_dict_format = False outputs_dict_format = False
output_keys = None
if name is None: if name is None:
# Determine possible file names # Determine possible file names
...@@ -291,17 +291,12 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -291,17 +291,12 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
rebuild_strict=rebuild_strict, rebuild_strict=rebuild_strict,
allow_input_downcast=allow_input_downcast, allow_input_downcast=allow_input_downcast,
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
profile=profile) profile=profile,output_dictionary_flag=outputs_dict_format,
output_keys=output_keys)
# We need to add the flag check_aliased inputs if we have any mutable or # We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs # borrowed used defined inputs
fn._check_for_aliased_inputs = check_for_aliased_inputs fn._check_for_aliased_inputs = check_for_aliased_inputs
if outputs_dict_format:
fn_dict_output = (lambda *args, **kwargs: output_dictionary_wrapper(args, kwargs, fn = fn, keys = output_keys))
return fn_dict_output
return fn return fn
......
...@@ -280,7 +280,7 @@ class Function(object): ...@@ -280,7 +280,7 @@ class Function(object):
""" """
def __init__(self, fn, input_storage, output_storage, indices, outputs, def __init__(self, fn, input_storage, output_storage, indices, outputs,
defaults, unpack_single, return_none, maker): defaults, unpack_single, return_none, output_dictionary_flag, output_keys, maker):
""" """
Initialize attributes. create finder, inv_finder. Initialize attributes. create finder, inv_finder.
""" """
...@@ -298,6 +298,8 @@ class Function(object): ...@@ -298,6 +298,8 @@ class Function(object):
self.trust_input = False # If True, we don't check the input parameter self.trust_input = False # If True, we don't check the input parameter
self.name = None self.name = None
self.nodes_with_inner_function = [] self.nodes_with_inner_function = []
self.output_dictionary_flag = output_dictionary_flag
self.output_keys = output_keys
# We will be popping stuff off this `containers` object. It is a copy. # We will be popping stuff off this `containers` object. It is a copy.
containers = list(self.input_storage) containers = list(self.input_storage)
...@@ -671,11 +673,26 @@ class Function(object): ...@@ -671,11 +673,26 @@ class Function(object):
if hasattr(self.fn, 'update_profile'): if hasattr(self.fn, 'update_profile'):
self.fn.update_profile(profile) self.fn.update_profile(profile)
if self.return_none: if self.return_none:
return None return None
elif self.unpack_single and len(outputs) == 1: elif self.unpack_single and len(outputs) == 1:
return outputs[0] return outputs[0]
else: else:
if self.output_dictionary_flag:
outputDict = {}
assert len(self.output_keys) == len(outputs)
for i in range(0, len(self.output_keys)):
outputDict[self.output_keys[i]] = outputs[i]
outputs = outputDict
return outputs return outputs
value = property( value = property(
...@@ -1047,7 +1064,8 @@ class FunctionMaker(object): ...@@ -1047,7 +1064,8 @@ class FunctionMaker(object):
def __init__(self, inputs, outputs, def __init__(self, inputs, outputs,
mode=None, accept_inplace=False, function_builder=Function, mode=None, accept_inplace=False, function_builder=Function,
profile=None, on_unused_input=None, fgraph=None): profile=None, on_unused_input=None, fgraph=None,
output_dictionary_flag=False,output_keys=None):
""" """
:type inputs: a list of SymbolicInput instances :type inputs: a list of SymbolicInput instances
...@@ -1201,6 +1219,8 @@ class FunctionMaker(object): ...@@ -1201,6 +1219,8 @@ class FunctionMaker(object):
self.accept_inplace = accept_inplace self.accept_inplace = accept_inplace
self.function_builder = function_builder self.function_builder = function_builder
self.on_unused_input = on_unused_input # Used only for the pickling self.on_unused_input = on_unused_input # Used only for the pickling
self.output_dictionary_flag = output_dictionary_flag
self.output_keys = output_keys
self.required = [(i.value is None) for i in self.inputs] self.required = [(i.value is None) for i in self.inputs]
self.refeed = [ self.refeed = [
...@@ -1336,7 +1356,7 @@ class FunctionMaker(object): ...@@ -1336,7 +1356,7 @@ class FunctionMaker(object):
self.profile.import_time += import_time self.profile.import_time += import_time
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs, fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs,
defaults, self.unpack_single, self.return_none, self) defaults, self.unpack_single, self.return_none, self.output_dictionary_flag, self.output_keys, self)
fn.profile = self.profile fn.profile = self.profile
return fn return fn
...@@ -1396,7 +1416,8 @@ def register_checker(checker): ...@@ -1396,7 +1416,8 @@ def register_checker(checker):
def orig_function(inputs, outputs, mode=None, accept_inplace=False, def orig_function(inputs, outputs, mode=None, accept_inplace=False,
name=None, profile=None, on_unused_input=None): name=None, profile=None, on_unused_input=None,
output_dictionary_flag=False,output_keys=None):
""" """
Return a Function that will calculate the outputs from the inputs. Return a Function that will calculate the outputs from the inputs.
...@@ -1462,7 +1483,9 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1462,7 +1483,9 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
mode, mode,
accept_inplace=accept_inplace, accept_inplace=accept_inplace,
profile=profile, profile=profile,
on_unused_input=on_unused_input).create( on_unused_input=on_unused_input,
output_dictionary_flag = output_dictionary_flag,
output_keys = output_keys).create(
defaults) defaults)
t2 = time.time() t2 = time.time()
......
...@@ -333,7 +333,7 @@ class Param(object): ...@@ -333,7 +333,7 @@ class Param(object):
self.implicit = implicit self.implicit = implicit
def pfunc(params, outputs=None, mode=None, updates=None, givens=None, def pfunc(params, output_dictionary_flag,output_keys, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, rebuild_strict=True, allow_input_downcast=None,
profile=None, on_unused_input=None): profile=None, on_unused_input=None):
...@@ -506,7 +506,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -506,7 +506,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
mutable=False, borrow=True, shared=True) mutable=False, borrow=True, shared=True)
inputs.append(si) inputs.append(si)
return orig_function(inputs, cloned_outputs, mode, return orig_function(inputs, cloned_outputs, mode, output_dictionary_flag=output_dictionary_flag,output_keys = output_keys,
accept_inplace=accept_inplace, name=name, profile=profile, accept_inplace=accept_inplace, name=name, profile=profile,
on_unused_input=on_unused_input) on_unused_input=on_unused_input)
......
...@@ -16,6 +16,9 @@ try: ...@@ -16,6 +16,9 @@ try:
except ImportError: except ImportError:
flake8_available = False flake8_available = False
print theano.__file__
print "fake8available", flake8_available
whitelist_flake8 = [ whitelist_flake8 = [
"updates.py", "updates.py",
"printing.py", "printing.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论