提交 009177d7 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

New argument `return_list` to PureOp.__call__()

This makes it possible to obtain the Op's outputs as a list without having to go through ``make_node(..).outputs``.
上级 346f651f
...@@ -373,7 +373,16 @@ class PureOp(object): ...@@ -373,7 +373,16 @@ class PureOp(object):
`default_output`, but subclasses are free to override this function and ignore `default_output`, but subclasses are free to override this function and ignore
`default_output`. `default_output`.
:param inputs: The Op's inputs, forwarded to the call to `make_node()`.
:param kwargs: Additional keyword arguments to be forwarded to
`make_node()` *except* for optional argument `return_list` (which
defaults to False). If `return_list` is True, then the returned
value is always a list. Otherwise it is either a single Variable
when the output of `make_node()` contains a single element, or this
output (unchanged) when it contains multiple elements.
""" """
return_list = kwargs.pop('return_list', False)
node = self.make_node(*inputs, **kwargs) node = self.make_node(*inputs, **kwargs)
if self.add_stack_trace_on_call: if self.add_stack_trace_on_call:
self.add_tag_trace(node) self.add_tag_trace(node)
...@@ -434,9 +443,14 @@ class PureOp(object): ...@@ -434,9 +443,14 @@ class PureOp(object):
output.tag.test_value = storage_map[output][0] output.tag.test_value = storage_map[output][0]
if self.default_output is not None: if self.default_output is not None:
return node.outputs[self.default_output] rval = node.outputs[self.default_output]
if return_list:
rval = [rval]
return rval
else: else:
if len(node.outputs) == 1: if return_list:
return list(node.outputs)
elif len(node.outputs) == 1:
return node.outputs[0] return node.outputs[0]
else: else:
return node.outputs return node.outputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论