提交 0e14d351 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

[scan][coding-style][doc] revampled the reduce function

上级 606b8ce7
...@@ -123,48 +123,66 @@ def map( fn ...@@ -123,48 +123,66 @@ def map( fn
, name = name ) , name = name )
def reduce(fn, sequences, outputs_info, non_sequences = [], go_backwards = False, # The ``reduce`` view of Scan Op.
mode = None, name = None): def reduce( fn
""" Similar behaviour as python reduce , sequences
, outputs_info
, non_sequences = None
, go_backwards = False
, mode = None
, name = None ):
"""
Similar behaviour as python's reduce
:param fn: the function to be applied over the elements in :param fn: The function that ``reduce`` applies at each iteration step
sequences ( see scan `fn` for more info) (see ``scan`` for more info).
:param outputs_info: information about outputs (mainly the initial state :param sequences: List of sequences over which ``reduce`` iterates
of each, but other options are available ), see scan for more (see ``scan`` for more info)
info
:param sequences: list of arrays over which reduce should :param outputs_info: List of dictionaries describing the outputs of
iterate (see scan for more info) reduce (see ``scan`` for more info).
:param non_sequences: list of other arguments of `fn` over which :param non_sequences: List of arguments passed to ``fn``. ``reduce`` will
reduce shouldn't iterate (see scan for more info) not iterate over these arguments (see ``scan`` for
more info).
:param go_backwards: set to true if you want map to start at the end of the :param go_backwards: Boolean value that decides the direction of
provided arrays in ``sequences`` going towards 0 (back in time) iteration. True means that sequences are parsed
from the end towards the begining, while False
is the other way around.
:param mode: see scan :param mode: See ``scan``.
:param name: see scan
:param name: See ``scan``.
""" """
# Specify that you only want the last value of the output # Makes sure the outputs_info is a list.
if type(outputs_info) not in (list,tuple): if type(outputs_info) not in (list,tuple):
outs_info = [outputs_info] outs_info = [outputs_info]
else: else:
outs_info = outputs_info outs_info = list(outputs_info)
for i,out_info in enumerate(outs_info): for i,out_info in enumerate(outs_info):
if out_info: if out_info:
if not type(out_info) == dict: if not type(out_info) == dict:
outs_info[i] = dict(initial = out_info, return_steps = 1) # Specifies that it should return only the last step.
outs_info[i] = dict(
initial = out_info, return_steps = 1, store_steps = 1)
else: else:
# we tell scan to store only the last step # Specifies that it should return only the last step.
# this will implicitly tell scan to also return just that
outs_info[i]['store_steps'] = 1 outs_info[i]['store_steps'] = 1
# NOTE : Maybe some errors can be detected here and outs_info[i]['return_steps'] = 1
# we could give more meaningfull error messages then in scan ? # NOTE : If the user asks for more then the last step,
return scan(fn, sequences = sequences, outputs_info = outs_info, # it means he does not understand ``reduce``. We could
non_sequences = non_sequences, go_backwards = go_backwards, # issue a warning in that case
truncate_gradient = 1, mode = mode, name = name) return scan( fn = fn
, sequences = sequences
, outputs_info = outs_info
, non_sequences = non_sequences
, go_backwards = go_backwards
, truncate_gradient = 1
, mode = mode
, name = name )
def foldl(fn, sequences, outputs_info, non_sequences = [], mode = None, name = None): def foldl(fn, sequences, outputs_info, non_sequences = [], mode = None, name = None):
""" Similar behaviour as haskell foldl """ Similar behaviour as haskell foldl
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论