提交 4ca744f0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up the variable names and signature in Scan's Cython code

上级 f89044a0
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -1405,10 +1405,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1405,10 +1405,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
] ]
else: else:
cython_destroy_map = [0 for x in range(len(node.outputs))] cython_destroy_map = [0 for x in range(len(node.outputs))]
cython_destroy_map = np.asarray(cython_destroy_map, dtype="int32") cython_destroy_map = np.asarray(cython_destroy_map, dtype="int32")
inner_input_storage = [s.storage for s in self.fn.input_storage]
inner_output_storage = [s.storage for s in self.fn.output_storage]
from . import scan_perform_ext from . import scan_perform_ext
def p(node, args, outs): def p(node, inputs, outputs):
return scan_perform_ext.perform( return scan_perform_ext.perform(
self.n_shared_outs, self.n_shared_outs,
self.n_mit_mot_outs, self.n_mit_mot_outs,
...@@ -1417,7 +1422,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1417,7 +1422,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.n_mit_sot, self.n_mit_sot,
self.n_sit_sot, self.n_sit_sot,
self.n_nit_sot, self.n_nit_sot,
args[0],
self.as_while, self.as_while,
cython_mintaps, cython_mintaps,
cython_tap_array, cython_tap_array,
...@@ -1425,15 +1429,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1425,15 +1429,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_vector_seqs, cython_vector_seqs,
cython_vector_outs, cython_vector_outs,
cython_mit_mot_out_slices, cython_mit_mot_out_slices,
cython_mit_mot_out_nslices,
cython_mitmots_preallocated, cython_mitmots_preallocated,
cython_inps_is_tensor, cython_inps_is_tensor,
cython_outs_is_tensor, cython_outs_is_tensor,
self.fn.fn, inner_input_storage,
inner_output_storage,
self.fn, self.fn,
cython_destroy_map, cython_destroy_map,
args, inputs,
outs, outputs,
self, self,
node, node,
) )
......
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
inner function as they are ( no slicing is performed) inner function as they are ( no slicing is performed)
All these outputs are one after the other in the inputs list (named in All these outputs are one after the other in the inputs list (named in
this code as args) in a given order ( namely the one described above this code as outer_inputs) in a given order ( namely the one described above
with little discrepancies depending if we are talking about the outputs with little discrepancies depending if we are talking about the outputs
of the Scan op or the inputs of the Scan op Node, and if we are talking of the Scan op or the inputs of the Scan op Node, and if we are talking
about the inputs of the inner function of scan or of the scan op). about the inputs of the inner function of scan or of the scan op).
...@@ -69,7 +69,6 @@ def perform( ...@@ -69,7 +69,6 @@ def perform(
unsigned int n_mit_sot, unsigned int n_mit_sot,
unsigned int n_sit_sot, unsigned int n_sit_sot,
unsigned int n_nit_sot, unsigned int n_nit_sot,
int n_steps,
bint as_while, bint as_while,
numpy.ndarray[numpy.int32_t,ndim=1] mintaps, numpy.ndarray[numpy.int32_t,ndim=1] mintaps,
numpy.ndarray[numpy.int32_t,ndim=2] tap_array, numpy.ndarray[numpy.int32_t,ndim=2] tap_array,
...@@ -77,15 +76,15 @@ def perform( ...@@ -77,15 +76,15 @@ def perform(
numpy.ndarray[numpy.int32_t,ndim=1] vector_seqs, numpy.ndarray[numpy.int32_t,ndim=1] vector_seqs,
numpy.ndarray[numpy.int32_t,ndim=1] vector_outs, numpy.ndarray[numpy.int32_t,ndim=1] vector_outs,
numpy.ndarray[numpy.int32_t,ndim=2] mit_mot_out_slices, numpy.ndarray[numpy.int32_t,ndim=2] mit_mot_out_slices,
numpy.ndarray[numpy.int32_t,ndim=1] mit_mot_out_nslices,
numpy.ndarray[numpy.int32_t,ndim=1] mitmots_preallocated, numpy.ndarray[numpy.int32_t,ndim=1] mitmots_preallocated,
numpy.ndarray[numpy.int32_t,ndim=1] inps_is_tensor, numpy.ndarray[numpy.int32_t,ndim=1] inps_is_tensor,
numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor, numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor,
fn, list inner_input_storage,
list inner_output_storage,
fnct, fnct,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map, numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
args, list outer_inputs,
outs, list outer_outputs,
self, self,
node): node):
""" """
...@@ -106,8 +105,6 @@ def perform( ...@@ -106,8 +105,6 @@ def perform(
Number of sit sot arguments Number of sit sot arguments
n_nit_sot: unsigned int n_nit_sot: unsigned int
Number of nit_sot arguments Number of nit_sot arguments
n_steps: unsigned int
Number of steps to loop over
mintaps: int32 ndarray (can also be a simple python list if that is better !) mintaps: int32 ndarray (can also be a simple python list if that is better !)
For any of the mit_mot, mit_sot, sit_sot says which is the furtherst For any of the mit_mot, mit_sot, sit_sot says which is the furtherst
away input tap from current position. For example, if the taps where [-2, away input tap from current position. For example, if the taps where [-2,
...@@ -131,29 +128,24 @@ def perform( ...@@ -131,29 +128,24 @@ def perform(
tensor, 0 otherwise. tensor, 0 otherwise.
mit_mot_out_slices : int32 ndarray( can be replaced by list of lists) mit_mot_out_slices : int32 ndarray( can be replaced by list of lists)
Same as tap_array, but for the output taps of mit_mot sequences Same as tap_array, but for the output taps of mit_mot sequences
mit_mot_out_nslices: int32 ndarray (Can be replaced by a list)
Same as tap_array_len, but is the number of output taps of the
mit_mot sequences (i.e. it corresponds to mit_mot_out_slices)
inps_is_tensor : int32 ndarray (Can be replaced by a list) inps_is_tensor : int32 ndarray (Can be replaced by a list)
Array of boolean indicating, for every input, whether it is a tensor Array of boolean indicating, for every input, whether it is a tensor
or not or not
outs_is_tensor : int32 ndarray (Can be replaced by a list) outs_is_tensor : int32 ndarray (Can be replaced by a list)
Array of boolean indicating, for every output, whether it is a tensor Array of boolean indicating, for every output, whether it is a tensor
or not or not
fn: callable inner_input_storage
This is the linker, i.e. the function that will loop over the The storage locations for the inner-function's inputs.
computational graph and call the perform of each operation. For this inner_output_storage
linker there is a c version in graph/lazy_linker.c that will be the The storage locations for the inner-function's outputs.
starting point of implementing this function in C ( we need to take
all the code around the call of this function and put in C inside
that code)
fnct: Function fnct: Function
The compiled Aesara inner-function object.
destroy_map destroy_map
Array of boolean saying if an output is computed inplace Array of boolean saying if an output is computed inplace
args: list of ndarrays (and random states) outer_inputs: list of ndarrays (and random states)
The inputs of scan in a given order ( n_steps, sequences, mit_mot, The inputs of scan in a given order ( n_steps, sequences, mit_mot,
mit_sot, sit_sot, nit_sot, shared_outs, other_args) mit_sot, sit_sot, nit_sot, shared_outs, other_args)
outs: list of 1 element list ( or storage objects?) outer_outputs: list of 1 element list ( or storage objects?)
This is where we need to copy our outputs ( we don't return the This is where we need to copy our outputs ( we don't return the
results, though we can change the code such that we return, and results, though we can change the code such that we return, and
figure things out on the outside - python) figure things out on the outside - python)
...@@ -166,6 +158,7 @@ def perform( ...@@ -166,6 +158,7 @@ def perform(
# negative flip sequences around, and make n_steps positive # negative flip sequences around, and make n_steps positive
t0_call = time.time() t0_call = time.time()
t_fn = 0 t_fn = 0
cdef unsigned int n_steps = outer_inputs[0].item()
cdef unsigned int n_outs = n_mit_mot + n_mit_sot + n_sit_sot cdef unsigned int n_outs = n_mit_mot + n_mit_sot + n_sit_sot
cdef unsigned int seqs_arg_offset = n_seqs + 1 cdef unsigned int seqs_arg_offset = n_seqs + 1
cdef unsigned int shared_arg_offset = ( 1 + n_seqs + n_mit_mot + cdef unsigned int shared_arg_offset = ( 1 + n_seqs + n_mit_mot +
...@@ -207,13 +200,13 @@ def perform( ...@@ -207,13 +200,13 @@ def perform(
n_steps) n_steps)
else: else:
for idx in range(n_seqs): for idx in range(n_seqs):
if args[<unsigned int>(1+idx)].shape[0] < n_steps: if outer_inputs[<unsigned int>(1+idx)].shape[0] < n_steps:
raise ValueError(( raise ValueError((
"Sequence %s has shape %s " "Sequence %s has shape %s "
"but the Scan's required number of steps is %s" "but the Scan's required number of steps is %s"
) % ( ) % (
idx, idx,
args[1+idx].shape, outer_inputs[1+idx].shape,
n_steps, n_steps,
)) ))
# 2. Allocate memory for the outputs. Construct the list: # 2. Allocate memory for the outputs. Construct the list:
...@@ -221,11 +214,11 @@ def perform( ...@@ -221,11 +214,11 @@ def perform(
# pos -- map containing the current position of each output # pos -- map containing the current position of each output
for idx in range(n_mit_mot + n_mit_sot + n_sit_sot): for idx in range(n_mit_mot + n_mit_sot + n_sit_sot):
store_steps[<unsigned int>idx] = args[<unsigned int>(idx+n_seqs+1)].shape[0] store_steps[<unsigned int>idx] = outer_inputs[<unsigned int>(idx+n_seqs+1)].shape[0]
for idx in range(n_nit_sot): for idx in range(n_nit_sot):
store_steps[<unsigned int>(idx + n_mit_mot + n_mit_sot + n_sit_sot)]=\ store_steps[<unsigned int>(idx + n_mit_mot + n_mit_sot + n_sit_sot)]=\
args[<unsigned int>(idx + n_mit_mot + n_mit_sot + n_sit_sot outer_inputs[<unsigned int>(idx + n_mit_mot + n_mit_sot + n_sit_sot
+ n_shared_outs + n_seqs+1)] + n_shared_outs + n_seqs+1)]
# 2.1 Create storage space for outputs # 2.1 Create storage space for outputs
...@@ -233,20 +226,20 @@ def perform( ...@@ -233,20 +226,20 @@ def perform(
if destroy_map[idx] != 0: if destroy_map[idx] != 0:
# ^ Case 1. Outputs should be computed inplace of their # ^ Case 1. Outputs should be computed inplace of their
# initial state # initial state
outs[idx][0] = args[ <unsigned int>(1+ n_seqs + idx)] outer_outputs[idx][0] = outer_inputs[ <unsigned int>(1+ n_seqs + idx)]
elif ( outs[idx][0] is not None and elif ( outer_outputs[idx][0] is not None and
outs[idx][0].shape[1:] == args[<unsigned int>(1+ n_seqs + idx)].shape[1:] outer_outputs[idx][0].shape[1:] == outer_inputs[<unsigned int>(1+ n_seqs + idx)].shape[1:]
and outs[idx][0].shape[0] >= store_steps[idx] ): and outer_outputs[idx][0].shape[0] >= store_steps[idx] ):
# Put in the values of the initial state # Put in the values of the initial state
outs[idx][0] = outs[idx][0][:store_steps[idx]] outer_outputs[idx][0] = outer_outputs[idx][0][:store_steps[idx]]
if idx > n_mit_mot: if idx > n_mit_mot:
l = - mintaps[idx] l = - mintaps[idx]
outs[idx][0][:l] = args[<unsigned int>(seqs_arg_offset + outer_outputs[idx][0][:l] = outer_inputs[<unsigned int>(seqs_arg_offset +
idx)][:l] idx)][:l]
else: else:
outs[idx][0][:] = args[<unsigned int>(seqs_arg_offset + idx)] outer_outputs[idx][0][:] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)]
else: else:
outs[idx][0] = args[<unsigned int>(seqs_arg_offset + idx)].copy() outer_outputs[idx][0] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)].copy()
if n_steps == 0: if n_steps == 0:
for idx in range(n_outs, n_outs + n_nit_sot): for idx in range(n_outs, n_outs + n_nit_sot):
...@@ -256,23 +249,23 @@ def perform( ...@@ -256,23 +249,23 @@ def perform(
# (The answer is that you shouldn't have a `node` object to # (The answer is that you shouldn't have a `node` object to
# access, because it's not going to produce a very efficient # access, because it's not going to produce a very efficient
# Cython function!) # Cython function!)
outs[idx][0] = node.outputs[idx].type.value_zeros(0) outer_outputs[idx][0] = node.outputs[idx].type.value_zeros(0)
else: else:
outs[idx][0] = None outer_outputs[idx][0] = None
return return
for idx in range(n_outs + n_nit_sot): for idx in range(n_outs + n_nit_sot):
pos[idx] = -mintaps[idx] % store_steps[idx] pos[idx] = -mintaps[idx] % store_steps[idx]
offset = nit_sot_arg_offset + n_nit_sot offset = nit_sot_arg_offset + n_nit_sot
other_args = args[offset:] other_args = outer_inputs[offset:]
input_storage = fnct.input_storage
nb_mitmot_in = 0 nb_mitmot_in = 0
for idx in range(n_mit_mot): for idx in range(n_mit_mot):
nb_mitmot_in += tap_array_len[idx] nb_mitmot_in += tap_array_len[idx]
old_mitmot_input_storage = [None] * nb_mitmot_in old_mitmot_input_storage = [None] * nb_mitmot_in
old_mitmot_input_data = [None] * nb_mitmot_in old_mitmot_input_data = [None] * nb_mitmot_in
output_storage = fnct.output_storage
old_output_storage = [None] * len_output_storage old_output_storage = [None] * len_output_storage
old_output_data = [None] * len_output_storage old_output_data = [None] * len_output_storage
offset = n_seqs offset = n_seqs
...@@ -281,7 +274,9 @@ def perform( ...@@ -281,7 +274,9 @@ def perform(
offset += n_shared_outs offset += n_shared_outs
for idx in range(len(other_args)): for idx in range(len(other_args)):
input_storage[<unsigned int>(idx+offset)].storage[0] = other_args[idx] inner_input_storage[<unsigned int>(idx+offset)][0] = other_args[idx]
fn = fnct.fn
i = 0 i = 0
cond = 1 cond = 1
...@@ -292,11 +287,11 @@ def perform( ...@@ -292,11 +287,11 @@ def perform(
# 3. collect input slices # 3. collect input slices
for idx in range(n_seqs): for idx in range(n_seqs):
if vector_seqs[idx] == 1: if vector_seqs[idx] == 1:
input_storage[idx].storage[0] = args[\ inner_input_storage[idx][0] = outer_inputs[\
<unsigned int>(1+idx)][i:<unsigned int>(i+1)].reshape(()) <unsigned int>(1+idx)][i:<unsigned int>(i+1)].reshape(())
else: else:
input_storage[idx].storage[0] = \ inner_input_storage[idx][0] = \
args[<unsigned int>(idx+1)][i] outer_inputs[<unsigned int>(idx+1)][i]
offset = n_seqs offset = n_seqs
for idx in range(n_outs): for idx in range(n_outs):
...@@ -304,14 +299,14 @@ def perform( ...@@ -304,14 +299,14 @@ def perform(
for tdx in range(tap_array_len[idx]): for tdx in range(tap_array_len[idx]):
tap = tap_array[idx,tdx] tap = tap_array[idx,tdx]
_idx = (pos[idx]+tap)%store_steps[idx] _idx = (pos[idx]+tap)%store_steps[idx]
input_storage[offset].storage[0] =\ inner_input_storage[offset][0] =\
outs[idx][0][_idx:<unsigned int>(_idx+1)].reshape(()) outer_outputs[idx][0][_idx:<unsigned int>(_idx+1)].reshape(())
offset += 1 offset += 1
else: else:
for tdx in range(tap_array_len[idx]): for tdx in range(tap_array_len[idx]):
tap = tap_array[idx,tdx] tap = tap_array[idx,tdx]
_idx = (pos[idx]+tap)%store_steps[idx] _idx = (pos[idx]+tap)%store_steps[idx]
input_storage[offset].storage[0] = outs[idx][0][_idx] inner_input_storage[offset][0] = outer_outputs[idx][0][_idx]
offset += 1 offset += 1
...@@ -319,11 +314,11 @@ def perform( ...@@ -319,11 +314,11 @@ def perform(
o_offset = n_outs + n_nit_sot o_offset = n_outs + n_nit_sot
if i == 0: if i == 0:
for j in range(n_shared_outs): for j in range(n_shared_outs):
input_storage[offset].storage[0] = args[<unsigned int>(a_offset+j)] inner_input_storage[offset][0] = outer_inputs[<unsigned int>(a_offset+j)]
offset += 1 offset += 1
else: else:
for j in range(n_shared_outs): for j in range(n_shared_outs):
input_storage[offset].storage[0] = outs[<unsigned int>(o_offset+j)][0] inner_input_storage[offset][0] = outer_outputs[<unsigned int>(o_offset+j)][0]
offset += 1 offset += 1
# 4. collecting slices where the output should be stored # 4. collecting slices where the output should be stored
...@@ -332,7 +327,7 @@ def perform( ...@@ -332,7 +327,7 @@ def perform(
offset = 0 offset = 0
for idx in range(n_mit_mot_outs): for idx in range(n_mit_mot_outs):
if not mitmots_preallocated[<unsigned int>idx]: if not mitmots_preallocated[<unsigned int>idx]:
output_storage[<unsigned int>offset].storage[0] = None inner_output_storage[<unsigned int>offset][0] = None
offset += 1 offset += 1
# 4.2. Collect slices for mitsots, sitsots and nitsots # 4.2. Collect slices for mitsots, sitsots and nitsots
...@@ -340,34 +335,34 @@ def perform( ...@@ -340,34 +335,34 @@ def perform(
for idx in range(n_outs + n_nit_sot - n_mit_mot): for idx in range(n_outs + n_nit_sot - n_mit_mot):
if ( store_steps[<unsigned int>(idx+n_mit_mot)] == 1 or if ( store_steps[<unsigned int>(idx+n_mit_mot)] == 1 or
vector_outs[<unsigned int>(idx+n_mit_mot)] == 1): vector_outs[<unsigned int>(idx+n_mit_mot)] == 1):
output_storage[<unsigned int>(idx+offset)].storage[0] = None inner_output_storage[<unsigned int>(idx+offset)][0] = None
else: else:
output_storage[<unsigned int>(idx+offset)].storage[0] =\ inner_output_storage[<unsigned int>(idx+offset)][0] =\
outs[<unsigned int>(idx+n_mit_mot)][0][pos[\ outer_outputs[<unsigned int>(idx+n_mit_mot)][0][pos[\
<unsigned int>(idx+n_mit_mot)]] <unsigned int>(idx+n_mit_mot)]]
else: else:
for idx in range(n_outs + n_nit_sot - n_mit_mot): for idx in range(n_outs + n_nit_sot - n_mit_mot):
output_storage[<unsigned int>(idx+offset)].storage[0] = None inner_output_storage[<unsigned int>(idx+offset)][0] = None
# 4.3. Collect slices for shared outputs # 4.3. Collect slices for shared outputs
offset += n_outs+n_nit_sot - n_mit_mot offset += n_outs+n_nit_sot - n_mit_mot
for idx in range(n_shared_outs): for idx in range(n_shared_outs):
output_storage[<unsigned int>(idx+offset)].storage[0] = None inner_output_storage[<unsigned int>(idx+offset)][0] = None
# 4.4. If there is a condition add it to the mix # 4.4. If there is a condition add it to the mix
if as_while: if as_while:
pdx = offset + n_shared_outs pdx = offset + n_shared_outs
output_storage[<unsigned int>pdx].storage[0] = None inner_output_storage[<unsigned int>pdx][0] = None
# 4.5. Keep a reference to the variables (ndarrays, GpuArrays, # 4.5. Keep a reference to the variables (ndarrays, GpuArrays,
# etc) currently in the output_storage to be able to compare them # etc) currently in the inner_output_storage to be able to compare them
# with the actual outputs of the inner function after its # with the actual outputs of the inner function after its
# execution. Also keep pointers to their data to be able to detect # execution. Also keep pointers to their data to be able to detect
# cases where outputs reused the allocated object but alter the # cases where outputs reused the allocated object but alter the
# memory region they refer to. # memory region they refer to.
for idx in range(len_output_storage): for idx in range(len_output_storage):
var = output_storage[idx].storage[0] var = inner_output_storage[idx][0]
old_output_storage[idx] = var old_output_storage[idx] = var
if var is None: if var is None:
...@@ -378,13 +373,13 @@ def perform( ...@@ -378,13 +373,13 @@ def perform(
old_output_data[idx] = var.gpudata old_output_data[idx] = var.gpudata
# 4.6. Keep a reference to the variables (ndarrays, GpuArrays, # 4.6. Keep a reference to the variables (ndarrays, GpuArrays,
# etc) associated with mitmot inputs currently in the input_storage to # etc) associated with mitmot inputs currently in the inner_input_storage to
# be able to compare them with the content of the input_storage after # be able to compare them with the content of the inner_input_storage after
# the execution of the function. Also keep pointers to their data to # the execution of the function. Also keep pointers to their data to
# be able to detect cases where outputs reused the allocated object # be able to detect cases where outputs reused the allocated object
# but alter the memory region they refer to. # but alter the memory region they refer to.
for idx in xrange(nb_mitmot_in): for idx in xrange(nb_mitmot_in):
var = input_storage[idx + n_seqs].storage[0] var = inner_input_storage[idx + n_seqs][0]
old_mitmot_input_storage[idx] = var old_mitmot_input_storage[idx] = var
if var is None: if var is None:
...@@ -423,18 +418,18 @@ def perform( ...@@ -423,18 +418,18 @@ def perform(
t_fn += dt_fn t_fn += dt_fn
if self.as_while: if self.as_while:
pdx = offset + n_shared_outs pdx = offset + n_shared_outs
cond = output_storage[pdx].storage[0] == 0 cond = inner_output_storage[pdx][0] == 0
# 5.2. By calling fn() directly instead of calling the aesara # 5.2. By calling fn() directly instead of calling the aesara
# function, it is possible that the updates have not been # function, it is possible that the updates have not been
# performed. Perform the updates if needed. # performed. Perform the updates if needed.
offset_out = len(output_storage) - 1 offset_out = len(inner_output_storage) - 1
if getattr(fn, 'need_update_inputs', True): if getattr(fn, 'need_update_inputs', True):
# Update the inputs that have an update function # Update the inputs that have an update function
for inp, storage in zip(self.fn.maker.expanded_inputs[::-1], for inp, storage in zip(self.fn.maker.expanded_inputs[::-1],
self.fn.input_storage[::-1]): self.fn.input_storage[::-1]):
if inp.update is not None: if inp.update is not None:
storage.data = output_storage[offset_out].data storage.data = inner_output_storage[offset_out][0].data
offset_out -= 1 offset_out -= 1
offset_out = 0 offset_out = 0
...@@ -452,7 +447,7 @@ def perform( ...@@ -452,7 +447,7 @@ def perform(
# Verify whether the input points to the same data as # Verify whether the input points to the same data as
# it did before the execution of the inner function. # it did before the execution of the inner function.
old_var = old_mitmot_input_storage[inp_idx] old_var = old_mitmot_input_storage[inp_idx]
new_var = input_storage[n_seqs + inp_idx].storage[0] new_var = inner_input_storage[n_seqs + inp_idx][0]
if old_var is new_var: if old_var is new_var:
old_data = old_mitmot_input_data[inp_idx] old_data = old_mitmot_input_data[inp_idx]
if inps_is_tensor[n_seqs + inp_idx]: if inps_is_tensor[n_seqs + inp_idx]:
...@@ -466,14 +461,14 @@ def perform( ...@@ -466,14 +461,14 @@ def perform(
# recover the value as usual. Otherwise, the input was # recover the value as usual. Otherwise, the input was
# modified inplace and nothing needs to be done. # modified inplace and nothing needs to be done.
if not same_data: if not same_data:
outs[j][0][<unsigned int>(k + pos[j])] = \ outer_outputs[j][0][<unsigned int>(k + pos[j])] = \
input_storage[<unsigned int>(n_seqs + inp_idx)].storage[0] inner_input_storage[<unsigned int>(n_seqs + inp_idx)][0]
else: else:
# This output tap has not been preallocated, recover # This output tap has not been preallocated, recover
# its value as usual # its value as usual
outs[j][0][<unsigned int>(k + pos[j])] = \ outer_outputs[j][0][<unsigned int>(k + pos[j])] = \
output_storage[<unsigned int>offset_out].storage[0] inner_output_storage[<unsigned int>offset_out][0]
offset_out += 1 offset_out += 1
mitmot_out_idx += 1 mitmot_out_idx += 1
...@@ -487,15 +482,15 @@ def perform( ...@@ -487,15 +482,15 @@ def perform(
for j in range(begin, end): for j in range(begin, end):
# Copy the output value to `outs`, if necessary # Copy the output value to `outer_outputs`, if necessary
if store_steps[j] == 1 or vector_outs[j] == 1: if store_steps[j] == 1 or vector_outs[j] == 1:
outs[j][0][pos[j]] = output_storage[<unsigned int>(offset_out+j)].storage[0] outer_outputs[j][0][pos[j]] = inner_output_storage[<unsigned int>(offset_out+j)][0]
else: else:
# Check whether the initialization of the output storage map # Check whether the initialization of the output storage map
# for this output has been reused. # for this output has been reused.
old_var = old_output_storage[offset_out + j] old_var = old_output_storage[offset_out + j]
old_data = old_output_data[offset_out + j] old_data = old_output_data[offset_out + j]
new_var = output_storage[offset_out + j].storage[0] new_var = inner_output_storage[offset_out + j][0]
if old_var is new_var: if old_var is new_var:
if old_data is None: if old_data is None:
output_reused = False output_reused = False
...@@ -507,8 +502,8 @@ def perform( ...@@ -507,8 +502,8 @@ def perform(
output_reused = False output_reused = False
if not output_reused: if not output_reused:
outs[j][0][pos[j]] = \ outer_outputs[j][0][pos[j]] = \
output_storage[<unsigned int>(offset_out+j)].storage[0] inner_output_storage[<unsigned int>(offset_out+j)][0]
# 5.5 Copy over the values for nit_sot outputs # 5.5 Copy over the values for nit_sot outputs
...@@ -518,24 +513,24 @@ def perform( ...@@ -518,24 +513,24 @@ def perform(
if i == 0: if i == 0:
jout = j+offset_out jout = j+offset_out
shape = (store_steps[j],) + output_storage[jout].storage[0].shape shape = (store_steps[j],) + inner_output_storage[jout][0].shape
dtype = output_storage[jout].storage[0].dtype dtype = inner_output_storage[jout][0].dtype
if (outs[j][0] is None or if (outer_outputs[j][0] is None or
outs[j][0].shape[0] < store_steps[j] or outer_outputs[j][0].shape[0] < store_steps[j] or
outs[j][0].shape[1:] != shape[1:] or outer_outputs[j][0].shape[1:] != shape[1:] or
outs[j][0].dtype != dtype ): outer_outputs[j][0].dtype != dtype ):
outs[j][0] = node.outputs[j].type.value_zeros(shape) outer_outputs[j][0] = node.outputs[j].type.value_zeros(shape)
elif outs[j][0].shape[0] != store_steps[j]: elif outer_outputs[j][0].shape[0] != store_steps[j]:
outs[j][0] = outs[j][0][:store_steps[j]] outer_outputs[j][0] = outer_outputs[j][0][:store_steps[j]]
outs[j][0][pos[j]] = output_storage[jout].storage[0] outer_outputs[j][0][pos[j]] = inner_output_storage[jout][0]
elif store_steps[j] == 1 or vector_outs[j] == 1: elif store_steps[j] == 1 or vector_outs[j] == 1:
outs[j][0][pos[j]] = output_storage[j+offset_out].storage[0] outer_outputs[j][0][pos[j]] = inner_output_storage[j+offset_out][0]
else: else:
# Check whether the initialization of the output storage map # Check whether the initialization of the output storage map
# for this output has been reused. # for this output has been reused.
old_var = old_output_storage[offset_out + j] old_var = old_output_storage[offset_out + j]
old_data = old_output_data[offset_out + j] old_data = old_output_data[offset_out + j]
new_var = output_storage[offset_out + j].storage[0] new_var = inner_output_storage[offset_out + j][0]
if old_var is new_var: if old_var is new_var:
if old_data is None: if old_data is None:
output_reused = False output_reused = False
...@@ -548,7 +543,7 @@ def perform( ...@@ -548,7 +543,7 @@ def perform(
if not output_reused: if not output_reused:
try: try:
outs[j][0][pos[j]] = output_storage[j+offset_out].storage[0] outer_outputs[j][0][pos[j]] = inner_output_storage[j+offset_out][0]
except ValueError as e: except ValueError as e:
if i == 0: if i == 0:
raise raise
...@@ -564,7 +559,7 @@ def perform( ...@@ -564,7 +559,7 @@ def perform(
end += n_shared_outs end += n_shared_outs
for j in range(begin,end): for j in range(begin,end):
jout = j +offset_out jout = j +offset_out
outs[j][0] = output_storage[jout].storage[0] outer_outputs[j][0] = inner_output_storage[jout][0]
for idx in range(lenpos): for idx in range(lenpos):
pos[idx] = (pos[idx]+1)%store_steps[idx] pos[idx] = (pos[idx]+1)%store_steps[idx]
...@@ -585,24 +580,24 @@ def perform( ...@@ -585,24 +580,24 @@ def perform(
# are read and written. # are read and written.
# This way, there will be no information overwritten # This way, there will be no information overwritten
# before it is read (as it used to happen). # before it is read (as it used to happen).
shape = (pdx,)+ outs[idx][0].shape[1:] shape = (pdx,)+ outer_outputs[idx][0].shape[1:]
tmp = node.outputs[idx].type.value_zeros(shape) tmp = node.outputs[idx].type.value_zeros(shape)
tmp[:] = outs[idx][0][:pdx] tmp[:] = outer_outputs[idx][0][:pdx]
outs[idx][0][:store_steps[idx]-pdx] = outs[idx][0][pdx:] outer_outputs[idx][0][:store_steps[idx]-pdx] = outer_outputs[idx][0][pdx:]
outs[idx][0][store_steps[idx]-pdx:] = tmp outer_outputs[idx][0][store_steps[idx]-pdx:] = tmp
else: else:
shape = (store_steps[idx]-pdx,) + outs[idx][0].shape[1:] shape = (store_steps[idx]-pdx,) + outer_outputs[idx][0].shape[1:]
tmp = node.outputs[idx].type.value_zeros(shape) tmp = node.outputs[idx].type.value_zeros(shape)
tmp[:] = outs[idx][0][pdx:] tmp[:] = outer_outputs[idx][0][pdx:]
outs[idx][0][store_steps[idx]-pdx:] = outs[idx][0][:pdx] outer_outputs[idx][0][store_steps[idx]-pdx:] = outer_outputs[idx][0][:pdx]
outs[idx][0][:store_steps[idx]-pdx] = tmp outer_outputs[idx][0][:store_steps[idx]-pdx] = tmp
# This would normally happen only when doing truncated # This would normally happen only when doing truncated
# backpropagation through time. In such a scenario Scan is # backpropagation through time. In such a scenario Scan is
# expected to return 0 for all entries for which the gradient is # expected to return 0 for all entries for which the gradient is
# not actually computed # not actually computed
elif store_steps[idx] > i - self.mintaps[idx]: elif store_steps[idx] > i - self.mintaps[idx]:
outs[idx][0][i-self.mintaps[idx]:] = 0 outer_outputs[idx][0][i-self.mintaps[idx]:] = 0
# This is a fix for a bug introduced by while. If you say # This is a fix for a bug introduced by while. If you say
# you want to loop up to a condition, you expect the output # you want to loop up to a condition, you expect the output
...@@ -618,15 +613,15 @@ def perform( ...@@ -618,15 +613,15 @@ def perform(
# to do boundschecks). The directive is used to make the # to do boundschecks). The directive is used to make the
# code faster, so this workaround is better then removing # code faster, so this workaround is better then removing
# the directive. # the directive.
sh0 = outs[idx][0].shape[0] sh0 = outer_outputs[idx][0].shape[0]
outs[idx][0] = outs[idx][0][:sh0-(n_steps - i)] outer_outputs[idx][0] = outer_outputs[idx][0][:sh0-(n_steps - i)]
# We never reuse the input or output storage of the # We never reuse the input or output storage of the
# inner function so we clear it. # inner function so we clear it.
for i_s in input_storage: for s in inner_input_storage:
i_s.storage[0] = None s[0] = None
for o_s in output_storage: for s in inner_output_storage:
o_s.storage[0] = None s[0] = None
t_call = time.time() - t0_call t_call = time.time() - t0_call
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论