提交 6e28e1b2 authored 作者: AlexLamb's avatar AlexLamb

Update to fix issues with code review

上级 2d9dc557
...@@ -186,7 +186,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -186,7 +186,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
""" """
if isinstance(outputs, dict): if isinstance(outputs, dict):
outputs_dict_format = True
output_items = outputs.items() output_items = outputs.items()
output_items_sorted = sorted(output_items) output_items_sorted = sorted(output_items)
...@@ -200,7 +199,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -200,7 +199,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
else: else:
outputs_dict_format = False
output_keys = None output_keys = None
if name is None: if name is None:
...@@ -282,7 +280,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -282,7 +280,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
allow_input_downcast=allow_input_downcast, allow_input_downcast=allow_input_downcast,
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
profile=profile, profile=profile,
output_dictionary_flag=outputs_dict_format,
output_keys=output_keys) output_keys=output_keys)
# We need to add the flag check_aliased inputs if we have any mutable or # We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs # borrowed used defined inputs
......
...@@ -9,7 +9,6 @@ import cPickle ...@@ -9,7 +9,6 @@ import cPickle
import itertools import itertools
import time import time
import warnings import warnings
import numpy import numpy
import theano import theano
...@@ -280,8 +279,7 @@ class Function(object): ...@@ -280,8 +279,7 @@ class Function(object):
""" """
def __init__(self, fn, input_storage, output_storage, indices, outputs, def __init__(self, fn, input_storage, output_storage, indices, outputs,
defaults, unpack_single, return_none, output_dictionary_flag, defaults, unpack_single, return_none, output_keys, maker):
output_keys, maker):
""" """
Initialize attributes. create finder, inv_finder. Initialize attributes. create finder, inv_finder.
""" """
...@@ -299,7 +297,6 @@ class Function(object): ...@@ -299,7 +297,6 @@ class Function(object):
self.trust_input = False # If True, we don't check the input parameter self.trust_input = False # If True, we don't check the input parameter
self.name = None self.name = None
self.nodes_with_inner_function = [] self.nodes_with_inner_function = []
self.output_dictionary_flag = output_dictionary_flag
self.output_keys = output_keys self.output_keys = output_keys
# We will be popping stuff off this `containers` object. It is a copy. # We will be popping stuff off this `containers` object. It is a copy.
...@@ -680,9 +677,10 @@ class Function(object): ...@@ -680,9 +677,10 @@ class Function(object):
return outputs[0] return outputs[0]
else: else:
if self.output_dictionary_flag: if self.output_keys != None:
outputDict = {} outputDict = {}
assert len(self.output_keys) == len(outputs) assert len(self.output_keys) == len(outputs)
for i in range(0, len(self.output_keys)): for i in range(0, len(self.output_keys)):
...@@ -1062,7 +1060,7 @@ class FunctionMaker(object): ...@@ -1062,7 +1060,7 @@ class FunctionMaker(object):
def __init__(self, inputs, outputs, def __init__(self, inputs, outputs,
mode=None, accept_inplace=False, function_builder=Function, mode=None, accept_inplace=False, function_builder=Function,
profile=None, on_unused_input=None, fgraph=None, profile=None, on_unused_input=None, fgraph=None,
output_dictionary_flag=False,output_keys=None): output_keys=None):
""" """
:type inputs: a list of SymbolicInput instances :type inputs: a list of SymbolicInput instances
...@@ -1216,7 +1214,6 @@ class FunctionMaker(object): ...@@ -1216,7 +1214,6 @@ class FunctionMaker(object):
self.accept_inplace = accept_inplace self.accept_inplace = accept_inplace
self.function_builder = function_builder self.function_builder = function_builder
self.on_unused_input = on_unused_input # Used only for the pickling self.on_unused_input = on_unused_input # Used only for the pickling
self.output_dictionary_flag = output_dictionary_flag
self.output_keys = output_keys self.output_keys = output_keys
self.required = [(i.value is None) for i in self.inputs] self.required = [(i.value is None) for i in self.inputs]
...@@ -1353,7 +1350,7 @@ class FunctionMaker(object): ...@@ -1353,7 +1350,7 @@ class FunctionMaker(object):
self.profile.import_time += import_time self.profile.import_time += import_time
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs, fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs,
defaults, self.unpack_single, self.return_none, self.output_dictionary_flag, self.output_keys, self) defaults, self.unpack_single, self.return_none, self.output_keys, self)
fn.profile = self.profile fn.profile = self.profile
return fn return fn
...@@ -1414,7 +1411,7 @@ def register_checker(checker): ...@@ -1414,7 +1411,7 @@ def register_checker(checker):
def orig_function(inputs, outputs, mode=None, accept_inplace=False, def orig_function(inputs, outputs, mode=None, accept_inplace=False,
name=None, profile=None, on_unused_input=None, name=None, profile=None, on_unused_input=None,
output_dictionary_flag=False,output_keys=None): output_keys=None):
""" """
Return a Function that will calculate the outputs from the inputs. Return a Function that will calculate the outputs from the inputs.
...@@ -1481,7 +1478,6 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1481,7 +1478,6 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
accept_inplace=accept_inplace, accept_inplace=accept_inplace,
profile=profile, profile=profile,
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
output_dictionary_flag = output_dictionary_flag,
output_keys = output_keys).create( output_keys = output_keys).create(
defaults) defaults)
......
...@@ -333,10 +333,10 @@ class Param(object): ...@@ -333,10 +333,10 @@ class Param(object):
self.implicit = implicit self.implicit = implicit
def pfunc(params, output_dictionary_flag,output_keys, outputs=None, mode=None, updates=None, givens=None, def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, rebuild_strict=True, allow_input_downcast=None,
profile=None, on_unused_input=None): profile=None, on_unused_input=None,output_keys=None):
"""Function-constructor for graphs with shared variables. """Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances. :type params: list of either Variable or Param instances.
...@@ -506,9 +506,9 @@ def pfunc(params, output_dictionary_flag,output_keys, outputs=None, mode=None, u ...@@ -506,9 +506,9 @@ def pfunc(params, output_dictionary_flag,output_keys, outputs=None, mode=None, u
mutable=False, borrow=True, shared=True) mutable=False, borrow=True, shared=True)
inputs.append(si) inputs.append(si)
return orig_function(inputs, cloned_outputs, mode, output_dictionary_flag=output_dictionary_flag,output_keys = output_keys, return orig_function(inputs, cloned_outputs, mode,
accept_inplace=accept_inplace, name=name, profile=profile, accept_inplace=accept_inplace, name=name, profile=profile,
on_unused_input=on_unused_input) on_unused_input=on_unused_input, output_keys=output_keys)
def _pfunc_param_to_in(param, strict=False, allow_downcast=None): def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论