提交 07f12948 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

better naming convention

上级 b86abee4
...@@ -319,12 +319,12 @@ def scan(fn, ...@@ -319,12 +319,12 @@ def scan(fn,
""" """
# Note : see the internal documentation of the scan op for naming # Note : see the internal documentation of the scan op for naming
# conventions and all other details # conventions and all other details
us, xys_info, ws, T = scan_utils.canonical_arguments(sequences, rvals = scan_utils.canonical_arguments(sequences,
outputs_info, outputs_info,
non_sequences, non_sequences,
go_backwards, go_backwards,
n_steps) n_steps)
inputs, states_and_outputs_info, parameters, T = rvals
# If we provided a known number of steps ( before compilation) # If we provided a known number of steps ( before compilation)
# and if that number is 1 or -1, then we can skip the Scan Op, # and if that number is 1 or -1, then we can skip the Scan Op,
# and just apply the inner function once # and just apply the inner function once
...@@ -340,9 +340,9 @@ def scan(fn, ...@@ -340,9 +340,9 @@ def scan(fn,
if T_value in (1, -1): if T_value in (1, -1):
return one_step_scan(fn, return one_step_scan(fn,
us, inputs,
xys_info, states_and_outputs_info,
ws, parameters,
T_value, T_value,
truncate_gradient) truncate_gradient)
...@@ -352,22 +352,23 @@ def scan(fn, ...@@ -352,22 +352,23 @@ def scan(fn,
# 2. Allocate memory for the states of scan. # 2. Allocate memory for the states of scan.
mintaps = [] mintaps = []
lengths = [] lengths = []
for xy in xys_info: for arg_info in states_and_outputs_info:
if xy.get('taps', None) == [-1]: if arg_info.get('taps', None) == [-1]:
mintaps.append(1) mintaps.append(1)
lengths.append(scalar_shared(numpy.int64(0))) lengths.append(scalar_shared(numpy.int64(0)))
xy['initial'] = scan_utils.expand(tensor.unbroadcast( arg_info['initial'] = scan_utils.expand(tensor.unbroadcast(
tensor.shape_padfelt(xy['initial'], 0), T)) tensor.shape_padfelt(state['initial'], 0), T))
elif xy.get('taps', None): elif arg_info.get('taps', None):
if numpy.any(numpy.array(xy.get('taps', [])) > 0): if numpy.any(numpy.array(arg_info.get('taps', [])) > 0):
# Make sure we do not have requests for future values of a # Make sure we do not have requests for future values of a
# sequence we can not provide such values # sequence we can not provide such values
raise ValueError('Can not use future taps of outputs', raise ValueError('Can not use future taps of outputs',
init_out) arg_info)
mintap = abs(numpy.min(xy['taps'])) mintap = abs(numpy.min(arg_info['taps']))
lengths.append(scalar_shared(numpy.int64(0))) lengths.append(scalar_shared(numpy.int64(0)))
mintaps.append(mintap) mintaps.append(mintap)
xy['initial'] = scan_utils.expand(xy['initial'][:mintap], T) arg_info['initial'] = scan_utils.expand(
arg_info['initial'][:mintap], T)
else: else:
mintaps.append(0) mintaps.append(0)
lengths.append(scalar_shared(numpy.int64(0))) lengths.append(scalar_shared(numpy.int64(0)))
...@@ -375,34 +376,37 @@ def scan(fn, ...@@ -375,34 +376,37 @@ def scan(fn,
# 3. Generate arguments for the function passed to scan. This will # 3. Generate arguments for the function passed to scan. This will
# function will return the outputs that need to be computed at every # function will return the outputs that need to be computed at every
# timesteps # timesteps
us_slices = [u[t] for u in us] inputs_slices = [input[t] for input in inputs]
xs_slices = [] states_slices = []
for n, xy in enumerate(xys_info): for n, state in enumerate(states_and_outputs_info):
# Check if it is actually a state and not an output
if mintaps[n] != 0: if mintaps[n] != 0:
for k in init_out['taps']: for k in state['taps']:
xs_slices.append( states_slices.append(
xy['initial'][(t + mintaps[n] - k) % lengths[n]]) state['initial'][(t + mintaps[n] - k) % lengths[n]])
# 4. Construct outputs that are to be computed by the inner # 4. Construct outputs that are to be computed by the inner
# function of scan # function of scan
args = us_slices + xs_slices + ws args = inputs_slices + states_slices + parameters
cond, xys_results, updates = scan_utils.get_updates_and_outputs(fn(*args)) cond, states_and_outputs, updates = \
scan_utils.get_updates_and_outputs(fn(*args))
if cond is not None: if cond is not None:
as_while = True as_repeatUntil = True
else: else:
as_while = False as_repeatUntil = False
# User is allowed to provide no information if it only behaves like a # User is allowed to provide no information if it only behaves like a
# map # map
if len(xys_outputs) != len(xys_info) and len(xys_info) == 0: if (len(states_and_outputs) != len(states_and_outputs_info) and
xys_info = [None] * len(xys_outputs) len(states_and_outputs_info) == 0):
states_and_outputs_info = [None] * len(states_and_outputs)
# 5. Construct the scan op # 5. Construct the scan op
# 5.1 Construct list of shared variables with updates (those that # 5.1 Construct list of shared variables with updates (those that
# can be treated as states (i.e. of TensorType) and those that can not # can be treated as states (i.e. of TensorType) and those that can not
# (like Random States) # (like Random States)
rvals = rebuild_collect_shared( rvals = rebuild_collect_shared(
xys_results + [cond], states_and_outputs + [cond],
updates=updates, updates=updates,
rebuild_strict=True, rebuild_strict=True,
copy_inputs_over=True, copy_inputs_over=True,
...@@ -411,80 +415,87 @@ def scan(fn, ...@@ -411,80 +415,87 @@ def scan(fn,
# extracting the arguments # extracting the arguments
input_variables, cloned_outputs, other_rval = rvals input_variables, cloned_outputs, other_rval = rvals
clone_d, update_d, update_expr, shared_inputs = other_rval clone_d, update_d, update_expr, shared_inputs = other_rval
additional_xs_outer = [] additional_input_states_outer = []
additional_xs_inner = [] additional_input_states_inner = []
additional_xs_results = [] additional_output_states = []
additional_lengths = [] additional_lengths = []
zs_outer = [] non_numeric_input_states_outer = []
zs_inner = [] non_numeric_input_states_inner = []
zs_results = [] non_numeric_output_states = []
for sv in shared_inputs: for sv in shared_inputs:
if sv in update_d: if sv in update_d:
if isinstance(sv, TensorType): if isinstance(sv, TensorType):
# We can treat it as a sit sot # We can treat it as a sit sot
nw_x = scan_utils.expand( nw_state = scan_utils.expand(
tensor.unbroadcast( tensor.unbroadcast(tensor.shape_padleft(sv, 0), T))
tensor.shape_padleft(sv, 0), actual_n_steps))
additional_lengths.append(scalar_shared(numpy.int64(0))) additional_lengths.append(scalar_shared(numpy.int64(0)))
additional_xs_outer.append(nw_x) additional_input_states_outer.append(nw_state)
additional_xs_inner.append(nw_x.type()) additional_input_states_inner.append(nw_state.type())
additional_xs_results.append( additional_output_states.append(
scan_utils.clone(tensor.set_subtensor( scan_utils.clone(tensor.set_subtensor(
nw_x[(t + 1) % additional_lengths[-1]], nw_state[(t + 1) % additional_lengths[-1]],
update_d[sv]))) update_d[sv])))
else: else:
zs_outer.append(sv) non_numeric_input_states_outer.append(sv)
zs_inner.append(sv.type()) non_numeric_input_states_inner.append(sv.type())
zs_results.append(update_d[sv]) non_numeric_output_states.append(update_d[sv])
# 5.2 Collect and order inputs of the inner function # 5.2 Collect and order inputs of the inner function
xs_outer = [] input_states_outer = []
xs_results = [] output_states = []
ys_outer = [] memory_buffers_for_outputs = []
ys_results = [] outputs = []
for n, mintap in enumerate(mintaps): for n, mintap in enumerate(mintaps):
if mintap != 0: if mintap != 0:
x = xys_info[n]['initial'] input_state = states_and_outputs_info[n]['initial']
xs_outer.append(x) input_states_outer.append(input_state)
xs_results.append( output_states.append(
tensor.set_subtensor(x[(t + 1) % lengths[n]], tensor.set_subtensor(input_state[(t + 1) % lengths[n]],
xys_results[n])) states_and_outputs[n]))
else: else:
y = scan_utils.allocate_memory(T, xys_info[n], xys_results[n]) output = scan_utils.allocate_memory(
ys_outer.append(y) T, states_and_outputs_info[n], states_and_outputs[n])
ys_results.append( memory_buffers_for_outputs.append(output)
tensor.set_subtensor(y[t % lengths[n]], xys_results[n]) outputs.append(
tensor.set_subtensor(output[t % lengths[n]],
states_and_outputs[n])
# 5.3 Construct the scan op # 5.3 Construct the scan op
def one_step_scan(fn, us, xys_info, ws, T, truncate_gradient): def one_step_scan(fn,
inputs,
states_and_outputs_info,
parameters,
T,
truncate_gradient):
""" """
This function is evaluated if `n_steps` evaluates to either 1 or -1. This function is evaluated if `n_steps` evaluates to either 1 or -1.
""" """
# 1. Grab slices of sequences # 1. Grab slices of sequences
us_slices = [u[0] for u in us] inputs_slices = [input[0] for input in inputs]
# 2. Grab slices of states # 2. Grab slices of states
xs_slices = [] states_slices = []
for n, x in enumerate(xys_info): for n, arg_info in enumerate(states_and_outputs_info):
if x.get('taps', None) == [-1]: if arg_info.get('taps', None) == [-1]:
xs_slices.append(x['initial']) states_slices.append(arg_info['initial'])
elif init_out.get('taps', None): elif arg_info.get('taps', None):
if numpy.any(numpy.array(init_out.get('taps', [])) > 0): if numpy.any(numpy.array(arg_info.get('taps', [])) > 0):
# Make sure we do not have requests for future values of a # Make sure we do not have requests for future values of a
# sequence we can not provide such values # sequence we can not provide such values
raise ValueError('Can not use future taps of outputs', raise ValueError('Can not use future taps of outputs',
init_out) arg_info)
# go through the taps # go through the taps
mintap = abs(numpy.min(init_out['taps'])) mintap = abs(numpy.min(arg_info['taps']))
xs_slices.append(x['initial'][k+mintap]) states_slices.append(arg_info['initial'][k+mintap])
# Re-order args # Re-order args
args = (us_slices + xs_slices + non_seqs) args = (inputs_slices + states_slices + parameters)
cond, xys_results, updates = scan_utils.get_updates_and_outputs(fn(*args)) cond, states_and_outputs, updates = \
scan_utils.get_updates_and_outputs(fn(*args))
# We do not need to use the scan op anymore, so we can just return # We do not need to use the scan op anymore, so we can just return
# the outputs and updates we have # the outputs and updates we have
...@@ -492,9 +503,9 @@ def one_step_scan(fn, us, xys_info, ws, T, truncate_gradient): ...@@ -492,9 +503,9 @@ def one_step_scan(fn, us, xys_info, ws, T, truncate_gradient):
_logger.warning(('When the number of steps is fixed and equal ' _logger.warning(('When the number of steps is fixed and equal '
'to 1, the provided stopping condition, ', 'to 1, the provided stopping condition, ',
str(cond), ' is ignored')) str(cond), ' is ignored'))
xys_results = [tensor.unbroadcast( states_and_outputs = [tensor.unbroadcast(
tensor.shape_padleft(xy_results), 0) for xy in xys] tensor.shape_padleft(arg), 0) for arg in states_and_outputs]
if len(xys) == 1: if len(states_and_outputs) == 1:
xys_results = xys_results[0] states_and_outputs = states_and_outputs[0]
return (xys_results, updates) return (states_and_outputs, updates)
...@@ -183,20 +183,19 @@ def canonical_arguments(sequences, ...@@ -183,20 +183,19 @@ def canonical_arguments(sequences,
and that the different fields of of a dictionary are set to default and that the different fields of of a dictionary are set to default
value if the user has not provided any. value if the user has not provided any.
""" """
us = to_list(sequences) states_info = to_list(outputs_info)
xys_info = to_list(outputs_info) parameters = [tensor.as_tensor_variable(x) for x in to_list(non_sequences)]
ws = [tensor.as_tensor_variable(x) for x in to_list(non_sequences)]
us = [] inputs = []
for u in to_list(sequences): for input in to_list(sequences):
if not isinstance(u, dict): if not isinstance(u, dict):
us.append(u) inputs.append(input)
elif u.get('taps', True) is None: elif input.get('taps', True) is None:
us.append(u) inputs.append(input)
elif u.get('taps', None): elif input.get('taps', None):
mintap = numpy.min(u['taps']) mintap = numpy.min(input['taps'])
maxtap = numpy.max(u['taps']) maxtap = numpy.max(input['taps'])
for k in u['taps']: for k in input['taps']:
# We cut the sequence such that seq[i] to correspond to # We cut the sequence such that seq[i] to correspond to
# seq[i-k] # seq[i-k]
if maxtap < 0: if maxtap < 0:
...@@ -204,57 +203,57 @@ def canonical_arguments(sequences, ...@@ -204,57 +203,57 @@ def canonical_arguments(sequences,
else: else:
offset = 0 offset = 0
if maxtap == mintap and maxtap != 0: if maxtap == mintap and maxtap != 0:
nw_u = u['input'][:abs(maxtap)] nw_input = input['input'][:abs(maxtap)]
elif maxtap - k != 0: elif maxtap - k != 0:
nw_u = u['input'][offset + k - mintap: -(maxtap - k)] nw_input = input['input'][offset + k - mintap: -(maxtap - k)]
else: else:
nw_u = u['input'][offset + k - mintap:] nw_input = input['input'][offset + k - mintap:]
if go_backwards: if go_backwards:
nw_u = nw_u[::-1] nw_input = nw_input[::-1]
us.append(nw_u) inputs.append(nw_input)
else: else:
raise ValueError('Provided sequence makes no sense', str(u)) raise ValueError('Provided sequence makes no sense', str(input))
# Since we've added all sequences now we need to level them up based on # Since we've added all sequences now we need to level them up based on
# n_steps or their different shapes # n_steps or their different shapes
if n_steps is None: if n_steps is None:
if len(us) == 0: if len(inputs) == 0:
# No information about the number of steps # No information about the number of steps
raise ValueError('You need to provide either at least ' raise ValueError('You need to provide either at least '
'one sequence over which scan should loop ' 'one sequence over which scan should loop '
'or a number of steps for scan to loop. ' 'or a number of steps for scan to loop. '
'Neither of the two had been provided !') 'Neither of the two had been provided !')
T = us[0].shape[0] T = inputs[0].shape[0]
for u in us[1:]: for input in inputs[1:]:
T = tensor.minimum(T, u.shape[0]) T = tensor.minimum(T, input.shape[0])
else: else:
T = tensor.as_tensor(n_steps) T = tensor.as_tensor(n_steps)
# Level up sequences # Level up sequences
us = [u[:T] for u in us] inputs = [input[:T] for input in inputs]
# wrap outputs info in a dictionary if they are not already in one # wrap outputs info in a dictionary if they are not already in one
for i, xy in enumerate(xys_info): for i, state in enumerate(states_info):
if xy is not None and not isinstance(xy, dict): if state is not None and not isinstance(state, dict):
xys_info[i] = dict(initial=xy, taps=[-1]) states_info[i] = dict(initial=state, taps=[-1])
elif isinstance(xy, dict): elif isinstance(state, dict):
if not xy.get('initial', None) and xy.get('taps', None): if not state.get('initial', None) and state.get('taps', None):
raise ValueError(('If you are using slices of an output ' raise ValueError(('If you are using slices of an output '
'you need to provide a initial state ' 'you need to provide a initial state '
'for it'), xy) 'for it'), state)
elif xy.get('initial', None) and not xy.get('taps', None): elif state.get('initial', None) and not state.get('taps', None):
# ^ initial state but taps not provided # ^ initial state but taps not provided
if 'taps' in xy: if 'taps' in state:
# ^ explicitly provided a None for taps # ^ explicitly provided a None for taps
_logger.warning('Output %s ( index %d) has a initial ' _logger.warning('Output %s ( index %d) has a initial '
'state but taps is explicitly set to None ', 'state but taps is explicitly set to None ',
getattr(outs_info[i]['initial'], 'name', 'None'), getattr(states_info[i]['initial'], 'name', 'None'),
i) i)
xys_info[i]['taps'] = [-1] states_info[i]['taps'] = [-1]
else: else:
# if a None is provided as the output info we replace it # if a None is provided as the output info we replace it
# with an empty dict() to simplify handling # with an empty dict() to simplify handling
xys_info[i] = dict() states_info[i] = dict()
return seqs, outs_info, non_seqs, actual_n_steps return inputs, staess_info, parameters, T
def infer_shape(outs, inputs, input_shapes): def infer_shape(outs, inputs, input_shapes):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论