提交 6e28e1b2 authored 作者: AlexLamb's avatar AlexLamb

Update to fix issues with code review

上级 2d9dc557
......@@ -186,7 +186,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
"""
if isinstance(outputs, dict):
outputs_dict_format = True
output_items = outputs.items()
output_items_sorted = sorted(output_items)
......@@ -200,7 +199,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
else:
outputs_dict_format = False
output_keys = None
if name is None:
......@@ -282,7 +280,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
allow_input_downcast=allow_input_downcast,
on_unused_input=on_unused_input,
profile=profile,
output_dictionary_flag=outputs_dict_format,
output_keys=output_keys)
# We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs
......
......@@ -9,7 +9,6 @@ import cPickle
import itertools
import time
import warnings
import numpy
import theano
......@@ -280,8 +279,7 @@ class Function(object):
"""
def __init__(self, fn, input_storage, output_storage, indices, outputs,
defaults, unpack_single, return_none, output_dictionary_flag,
output_keys, maker):
defaults, unpack_single, return_none, output_keys, maker):
"""
Initialize attributes. create finder, inv_finder.
"""
......@@ -299,7 +297,6 @@ class Function(object):
self.trust_input = False # If True, we don't check the input parameter
self.name = None
self.nodes_with_inner_function = []
self.output_dictionary_flag = output_dictionary_flag
self.output_keys = output_keys
# We will be popping stuff off this `containers` object. It is a copy.
......@@ -680,9 +677,10 @@ class Function(object):
return outputs[0]
else:
if self.output_dictionary_flag:
if self.output_keys != None:
outputDict = {}
assert len(self.output_keys) == len(outputs)
for i in range(0, len(self.output_keys)):
......@@ -1062,7 +1060,7 @@ class FunctionMaker(object):
def __init__(self, inputs, outputs,
mode=None, accept_inplace=False, function_builder=Function,
profile=None, on_unused_input=None, fgraph=None,
output_dictionary_flag=False,output_keys=None):
output_keys=None):
"""
:type inputs: a list of SymbolicInput instances
......@@ -1216,7 +1214,6 @@ class FunctionMaker(object):
self.accept_inplace = accept_inplace
self.function_builder = function_builder
self.on_unused_input = on_unused_input # Used only for the pickling
self.output_dictionary_flag = output_dictionary_flag
self.output_keys = output_keys
self.required = [(i.value is None) for i in self.inputs]
......@@ -1353,7 +1350,7 @@ class FunctionMaker(object):
self.profile.import_time += import_time
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs,
defaults, self.unpack_single, self.return_none, self.output_dictionary_flag, self.output_keys, self)
defaults, self.unpack_single, self.return_none, self.output_keys, self)
fn.profile = self.profile
return fn
......@@ -1414,7 +1411,7 @@ def register_checker(checker):
def orig_function(inputs, outputs, mode=None, accept_inplace=False,
name=None, profile=None, on_unused_input=None,
output_dictionary_flag=False,output_keys=None):
output_keys=None):
"""
Return a Function that will calculate the outputs from the inputs.
......@@ -1481,7 +1478,6 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
accept_inplace=accept_inplace,
profile=profile,
on_unused_input=on_unused_input,
output_dictionary_flag = output_dictionary_flag,
output_keys = output_keys).create(
defaults)
......
......@@ -333,10 +333,10 @@ class Param(object):
self.implicit = implicit
def pfunc(params, output_dictionary_flag,output_keys, outputs=None, mode=None, updates=None, givens=None,
def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None,
profile=None, on_unused_input=None):
profile=None, on_unused_input=None,output_keys=None):
"""Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances.
......@@ -506,9 +506,9 @@ def pfunc(params, output_dictionary_flag,output_keys, outputs=None, mode=None, u
mutable=False, borrow=True, shared=True)
inputs.append(si)
return orig_function(inputs, cloned_outputs, mode, output_dictionary_flag=output_dictionary_flag,output_keys = output_keys,
return orig_function(inputs, cloned_outputs, mode,
accept_inplace=accept_inplace, name=name, profile=profile,
on_unused_input=on_unused_input)
on_unused_input=on_unused_input, output_keys=output_keys)
def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论