Commit 19a501bb authored by Razvan Pascanu

code that prepares arguments for the scan op

Parent: b9344a3c
...@@ -319,7 +319,145 @@ def scan(fn,
"""
# Note : see the internal documentation of the scan op for naming
# conventions and all other details
us, xys_info, ws, T = scan_utils.canonical_arguments(sequences,
outputs_info,
non_sequences,
go_backwards,
n_steps)
# If we provided a known number of steps ( before compilation)
# and if that number is 1 or -1, then we can skip the Scan Op,
# and just apply the inner function once
# To do that we check here to see the nature of n_steps
T_value = None
if isinstance(n_steps, (float, int)):
T_value = int(n_steps)
else:
try:
T_value = opt.get_constant_value(n_steps)
except (TypeError, AttributeError):
T_value = None
if T_value in (1, -1):
return one_step_scan(fn,
us,
xys_info,
ws,
T_value,
truncate_gradient)
# 1. Variable representing the current time step
t = scalar_shared(numpy.int64(0))
# 2. Allocate memory for the states of scan.
mintaps = []
lengths = []
for xy in xys_info:
if xy.get('taps', None) == [-1]:
mintaps.append(1)
lengths.append(scalar_shared(numpy.int64(0)))
xy['initial'] = scan_utils.expand(tensor.unbroadcast(
tensor.shape_padfelt(xy['initial'], 0), T))
elif xy.get('taps', None):
if numpy.any(numpy.array(xy.get('taps', [])) > 0):
# Make sure we do not have requests for future values of a
# sequence we can not provide such values
raise ValueError('Can not use future taps of outputs',
init_out)
mintap = abs(numpy.min(xy['taps']))
lengths.append(scalar_shared(numpy.int64(0)))
mintaps.append(mintap)
xy['initial'] = scan_utils.expand(xy['initial'][:mintap], T)
else:
mintaps.append(0)
lengths.append(scalar_shared(numpy.int64(0)))
# 3. Generate arguments for the function passed to scan.  These are the
# slices the inner function sees at one generic time step `t`.
us_slices = [u[t] for u in us]
xs_slices = []
for n, xy in enumerate(xys_info):
    if mintaps[n] != 0:
        # BUGFIX: `init_out` was an undefined name; the taps iterated
        # here belong to the current output description `xy`.
        for k in xy['taps']:
            xs_slices.append(
                xy['initial'][(t + mintaps[n] - k) % lengths[n]])
# 4. Construct outputs that are to be computed by the inner function of
# scan.
args = us_slices + xs_slices + ws
cond, xys_results, updates = scan_utils.get_updates_and_outputs(fn(*args))
# A non-None condition means the user asked for a while-loop scan.
as_while = cond is not None
# The user is allowed to provide no output information when scan only
# behaves like a map.
# BUGFIX: `xys_outputs` was an undefined name; the inner function's
# outputs are `xys_results` (computed just above).
if len(xys_results) != len(xys_info) and len(xys_info) == 0:
    xys_info = [None] * len(xys_results)
# 5. Construct the scan op
# 5.1 Construct the list of shared variables with updates: those that
# can be treated as states (i.e. of TensorType) and those that can not
# (like Random States).
# BUGFIX: unconditionally appending `cond` would insert None into the
# outputs when scan is not used as a while-loop (cond may be None, see
# the `as_while` test above).
outputs_to_clone = xys_results if cond is None else xys_results + [cond]
rvals = rebuild_collect_shared(
    outputs_to_clone,
    updates=updates,
    rebuild_strict=True,
    copy_inputs_over=True,
    no_default_updates=False)
# Unpack the results of the cloning pass.
input_variables, cloned_outputs, other_rval = rvals
clone_d, update_d, update_expr, shared_inputs = other_rval
# Accumulators for shared variables promoted to scan states
# (`additional_*`) and for non-tensor shared state (`zs_*`).
additional_xs_outer = []
additional_xs_inner = []
additional_xs_results = []
additional_lengths = []
zs_outer = []
zs_inner = []
zs_results = []
for sv in shared_inputs:
    if sv in update_d:
        # BUGFIX: `sv` is a (shared) variable, not a type; the check
        # must inspect `sv.type` (its type object is instantiated as
        # `sv.type()` in the else branch below).
        if isinstance(sv.type, TensorType):
            # We can treat it as a sit-sot: successive values live in a
            # circular buffer of length T.
            # BUGFIX: `actual_n_steps` was an undefined name (the step
            # count in scope is `T`), and `scan_utils.expand` takes the
            # number of steps as its second argument.
            nw_x = scan_utils.expand(
                tensor.unbroadcast(tensor.shape_padleft(sv), 0),
                T)
            additional_lengths.append(scalar_shared(numpy.int64(0)))
            additional_xs_outer.append(nw_x)
            additional_xs_inner.append(nw_x.type())
            additional_xs_results.append(
                scan_utils.clone(tensor.set_subtensor(
                    nw_x[(t + 1) % additional_lengths[-1]],
                    update_d[sv])))
        else:
            # Non-tensor shared state (e.g. random number generators).
            zs_outer.append(sv)
            zs_inner.append(sv.type())
            zs_results.append(update_d[sv])
# 5.2 Collect and order the inputs of the inner function.
xs_outer = []
xs_results = []
ys_outer = []
ys_results = []
for n, mintap in enumerate(mintaps):
    if mintap != 0:
        # Recurrent output: write the new value into its circular
        # buffer, one slot past the current time step.
        x = xys_info[n]['initial']
        xs_outer.append(x)
        xs_results.append(
            tensor.set_subtensor(x[(t + 1) % lengths[n]],
                                 xys_results[n]))
    else:
        # Non-recurrent output: allocate fresh storage for all T steps.
        # NOTE(review): `lengths[n]` is created as 0 above; presumably
        # it is filled in before execution -- confirm at runtime.
        y = scan_utils.allocate_memory(T, xys_info[n], xys_results[n])
        ys_outer.append(y)
        # BUGFIX: the original call was missing its closing parenthesis
        # (syntax error).
        ys_results.append(
            tensor.set_subtensor(y[t % lengths[n]], xys_results[n]))
# 5.3 Construct the scan op
def one_step_scan(fn, us, xys_info, ws, T, truncate_gradient): def one_step_scan(fn, us, xys_info, ws, T, truncate_gradient):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论