提交 0e14d351 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

[scan][coding-style][doc] revampled the reduce function

上级 606b8ce7
......@@ -123,48 +123,66 @@ def map( fn
, name = name )
def reduce(fn, sequences, outputs_info, non_sequences = [], go_backwards = False,
mode = None, name = None):
""" Similar behaviour as python reduce
:param fn: the function to be applied over the elements in
sequences ( see scan `fn` for more info)
# The ``reduce`` view of Scan Op.
def reduce( fn
, sequences
, outputs_info
, non_sequences = None
, go_backwards = False
, mode = None
, name = None ):
"""
Similar behaviour as python's reduce
:param outputs_info: information about outputs (mainly the initial state
of each, but other options are available ), see scan for more
info
:param fn: The function that ``reduce`` applies at each iteration step
(see ``scan`` for more info).
:param sequences: list of arrays over which reduce should
iterate (see scan for more info)
:param sequences: List of sequences over which ``reduce`` iterates
(see ``scan`` for more info)
:param non_sequences: list of other arguments of `fn` over which
reduce shouldn't iterate (see scan for more info)
:param outputs_info: List of dictionaries describing the outputs of
reduce (see ``scan`` for more info).
:param go_backwards: set to true if you want map to start at the end of the
provided arrays in ``sequences`` going towards 0 (back in time)
:param non_sequences: List of arguments passed to ``fn``. ``reduce`` will
not iterate over these arguments (see ``scan`` for
more info).
:param mode: see scan
:param name: see scan
:param go_backwards: Boolean value that decides the direction of
iteration. True means that sequences are parsed
from the end towards the begining, while False
is the other way around.
:param mode: See ``scan``.
:param name: See ``scan``.
"""
# Specify that you only want the last value of the output
# Makes sure the outputs_info is a list.
if type(outputs_info) not in (list,tuple):
outs_info = [outputs_info]
else:
outs_info = outputs_info
outs_info = list(outputs_info)
for i,out_info in enumerate(outs_info):
if out_info:
if not type(out_info) == dict:
outs_info[i] = dict(initial = out_info, return_steps = 1)
# Specifies that it should return only the last step.
outs_info[i] = dict(
initial = out_info, return_steps = 1, store_steps = 1)
else:
# we tell scan to store only the last step
# this will implicitly tell scan to also return just that
outs_info[i]['store_steps'] = 1
# NOTE : Maybe some errors can be detected here and
# we could give more meaningfull error messages then in scan ?
return scan(fn, sequences = sequences, outputs_info = outs_info,
non_sequences = non_sequences, go_backwards = go_backwards,
truncate_gradient = 1, mode = mode, name = name)
# Specifies that it should return only the last step.
outs_info[i]['store_steps'] = 1
outs_info[i]['return_steps'] = 1
# NOTE : If the user asks for more then the last step,
# it means he does not understand ``reduce``. We could
# issue a warning in that case
return scan( fn = fn
, sequences = sequences
, outputs_info = outs_info
, non_sequences = non_sequences
, go_backwards = go_backwards
, truncate_gradient = 1
, mode = mode
, name = name )
def foldl(fn, sequences, outputs_info, non_sequences = [], mode = None, name = None):
""" Similar behaviour as haskell foldl
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论