提交 b9344a3c authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Function to deal with the case when n_steps is 1

上级 de1377e7
......@@ -320,3 +320,43 @@ def scan(fn,
# Note : see the internal documentation of the scan op for naming
# conventions and all other details
def one_step_scan(fn, us, xys_info, ws, T, truncate_gradient):
"""
This function is evaluated if `n_steps` evaluates to either 1 or -1.
"""
# 1. Grab slices of sequences
us_slices = [u[0] for u in us]
# 2. Grab slices of states
xs_slices = []
for n, x in enumerate(xys_info):
if x.get('taps', None) == [-1]:
xs_slices.append(x['initial'])
elif init_out.get('taps', None):
if numpy.any(numpy.array(init_out.get('taps', [])) > 0):
# Make sure we do not have requests for future values of a
# sequence we can not provide such values
raise ValueError('Can not use future taps of outputs',
init_out)
# go through the taps
mintap = abs(numpy.min(init_out['taps']))
xs_slices.append(x['initial'][k+mintap])
# Re-order args
args = (us_slices + xs_slices + non_seqs)
cond, xys_results, updates = scan_utils.get_updates_and_outputs(fn(*args))
# We do not need to use the scan op anymore, so we can just return
# the outputs and updates we have
if cond is not None:
_logger.warning(('When the number of steps is fixed and equal '
'to 1, the provided stopping condition, ',
str(cond), ' is ignored'))
xys_results = [tensor.unbroadcast(
tensor.shape_padleft(xy_results), 0) for xy in xys]
if len(xys) == 1:
xys_results = xys_results[0]
return (xys_results, updates)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论