提交 18205006 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

when None is given as the outputs list, function and Method now return None

上级 cfa74c57
......@@ -170,6 +170,9 @@ class Function(object):
unpack_single = None
"""Bool: for outputs lists of length 1, should the 0'th element be returned directly?"""
return_none = None
"""Bool: whether the function should return None or not"""
maker = None
"""FunctionMaker instance"""
......@@ -197,7 +200,7 @@ class Function(object):
It maps container -> SymbolicInput
"""
def __init__(self, fn, input_storage, output_storage, indices, outputs, defaults, unpack_single, maker):
def __init__(self, fn, input_storage, output_storage, indices, outputs, defaults, unpack_single, return_none, maker):
"""
Initialize attributes. create finder, inv_finder.
"""
......@@ -209,6 +212,7 @@ class Function(object):
self.outputs = outputs
self.defaults = defaults
self.unpack_single = unpack_single
self.return_none = return_none
self.maker = maker
# we'll be popping stuff off this `containers` object. It's a copy
......@@ -379,7 +383,9 @@ class Function(object):
for i, (required, refeed, value) in enumerate(self.defaults):
if refeed:
self[i] = value
if self.unpack_single and len(outputs) == 1:
if self.return_none:
return None
elif self.unpack_single and len(outputs) == 1:
return outputs[0]
else:
return outputs
......@@ -571,6 +577,10 @@ class FunctionMaker(object):
# Handle the case where inputs and/or outputs is a single Variable (not in a list)
unpack_single = False
return_none = False
if outputs is None:
return_none = True
outputs = []
if not isinstance(outputs, (list, tuple)):
unpack_single = True
outputs = [outputs]
......@@ -611,6 +621,7 @@ class FunctionMaker(object):
self.expanded_inputs = expanded_inputs
self.outputs = outputs
self.unpack_single = unpack_single
self.return_none = return_none
self.mode = mode
self.accept_inplace = accept_inplace
self.function_builder = function_builder
......@@ -706,12 +717,13 @@ class FunctionMaker(object):
# Get a function instance
_fn, _i, _o = self.linker.make_thunk(input_storage = input_storage)
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs, defaults, self.unpack_single, self)
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs, defaults, self.unpack_single, self.return_none, self)
return fn
def _pickle_FunctionMaker(fm):
rval = (_constructor_FunctionMaker, (fm.inputs, fm.outputs[0] if fm.unpack_single else fm.outputs, fm.mode, fm.accept_inplace))
outputs = None if fm.return_none else (fm.outputs[0] if fm.unpack_single else fm.outputs)
rval = (_constructor_FunctionMaker, (fm.inputs, outputs, fm.mode, fm.accept_inplace))
return rval
def _constructor_FunctionMaker(*args):
......@@ -794,9 +806,7 @@ def function(inputs, outputs, mode=None, accept_inplace = False):
mode = mode if mode is not None else mode_module.default_mode
inputs = map(convert_function_input, inputs)
if outputs is None:
outputs = []
else:
if outputs is not None:
outputs = map(FunctionMaker.wrap_out, outputs) if isinstance(outputs, (list, tuple)) else FunctionMaker.wrap_out(outputs)
defaults = [getattr(input, 'value', None) for input in inputs]
......
......@@ -347,8 +347,6 @@ class Method(Component):
"""
super(Method, self).__init__()
if outputs is None:
outputs = []
self.inputs = inputs
self.outputs = outputs
self.updates = dict(updates)
......@@ -483,7 +481,7 @@ class Method(Component):
outputs = self.outputs
_inputs = [x.variable for x in inputs]
# Grab the variables that are not accessible from either the inputs or the updates.
outputs_list = list(outputs) if isinstance(outputs, (list, tuple)) else [outputs]
outputs_list = [] if outputs is None else (list(outputs) if isinstance(outputs, (list, tuple)) else [outputs])
outputs_variable_list = [o.variable if isinstance(o, io.Out) else o for o in outputs_list]
for input in gof.graph.inputs(outputs_variable_list
+ [x.update for x in inputs if getattr(x, 'update', False)],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论