提交 75b8b833 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove redundant indexing in scan_perform.pyx

上级 4af12379
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -58,9 +58,14 @@ import sys ...@@ -58,9 +58,14 @@ import sys
from aesara.scan.utils import InnerFunctionError from aesara.scan.utils import InnerFunctionError
numpy.import_array()
def get_version(): def get_version():
return 0.317 return 0.318
# TODO: We need to get rid of the negative indexing performed with `pos` and `l`.
# @cython.wraparound(False)
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
const unsigned int n_shared_outs, const unsigned int n_shared_outs,
...@@ -160,6 +165,8 @@ def perform( ...@@ -160,6 +165,8 @@ def perform(
# 1. Unzip the number of steps and sequences. If number of steps is # 1. Unzip the number of steps and sequences. If number of steps is
# negative flip sequences around, and make n_steps positive # negative flip sequences around, and make n_steps positive
cdef float t_fn = 0 cdef float t_fn = 0
cdef float t0_fn
cdef float dt_fn
cdef unsigned int n_steps = outer_inputs[0].item() cdef unsigned int n_steps = outer_inputs[0].item()
cdef unsigned int n_outs = n_mit_mot + n_mit_sot + n_sit_sot cdef unsigned int n_outs = n_mit_mot + n_mit_sot + n_sit_sot
cdef unsigned int seqs_arg_offset = n_seqs + 1 cdef unsigned int seqs_arg_offset = n_seqs + 1
...@@ -189,6 +196,16 @@ def perform( ...@@ -189,6 +196,16 @@ def perform(
cdef unsigned int len_output_storage = (n_mit_mot_outs + n_mit_sot + cdef unsigned int len_output_storage = (n_mit_mot_outs + n_mit_sot +
n_sit_sot + n_nit_sot + n_sit_sot + n_nit_sot +
n_shared_outs) n_shared_outs)
cdef unsigned int mitmot_inp_offset
cdef unsigned int mitmot_out_idx
cdef unsigned int inp_idx
cdef unsigned int inner_inp_idx
cdef unsigned int store_steps_j
cdef unsigned int store_steps_idx
cdef int mintaps_idx
cdef unsigned int sh0
cdef long pos_j
cdef long pos_idx
if n_steps < 0: if n_steps < 0:
# History, in the past, this was used for backward # History, in the past, this was used for backward
...@@ -220,23 +237,29 @@ def perform( ...@@ -220,23 +237,29 @@ def perform(
# 2.1 Create storage space for outputs # 2.1 Create storage space for outputs
for idx in range(n_outs): for idx in range(n_outs):
outer_outputs_idx = outer_outputs[idx]
if destroy_map[idx] != 0: if destroy_map[idx] != 0:
# ^ Case 1. Outputs should be computed inplace of their # ^ Case 1. Outputs should be computed inplace of their
# initial state # initial state
outer_outputs[idx][0] = outer_inputs[ <unsigned int>(1+ n_seqs + idx)] outer_outputs_idx[0] = outer_inputs[ <unsigned int>(1+ n_seqs + idx)]
elif ( outer_outputs[idx][0] is not None and continue
outer_outputs[idx][0].shape[1:] == outer_inputs[<unsigned int>(1+ n_seqs + idx)].shape[1:]
and outer_outputs[idx][0].shape[0] >= store_steps[idx] ): outer_outputs_idx_0 = outer_outputs_idx[0]
if ( outer_outputs_idx_0 is not None and
outer_outputs_idx_0.shape[1:] == outer_inputs[<unsigned int>(1+ n_seqs + idx)].shape[1:]
and outer_outputs_idx_0.shape[0] >= store_steps[idx] ):
# Put in the values of the initial state # Put in the values of the initial state
outer_outputs[idx][0] = outer_outputs[idx][0][:store_steps[idx]] outer_outputs_idx[0] = outer_outputs_idx_0[:store_steps[idx]]
if idx > n_mit_mot: if idx > n_mit_mot:
# TODO FIXME: Do not use wrapped indexing!
l = - mintaps[idx] l = - mintaps[idx]
outer_outputs[idx][0][:l] = outer_inputs[<unsigned int>(seqs_arg_offset + outer_outputs_idx_0[:l] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)][:l]
idx)][:l]
else: else:
outer_outputs[idx][0][:] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)] outer_outputs_idx_0[:] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)]
else: else:
outer_outputs[idx][0] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)].copy() outer_outputs_idx[0] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)].copy()
if n_steps == 0: if n_steps == 0:
for idx in range(n_outs, n_outs + n_nit_sot): for idx in range(n_outs, n_outs + n_nit_sot):
...@@ -252,12 +275,13 @@ def perform( ...@@ -252,12 +275,13 @@ def perform(
return 0.0, 0 return 0.0, 0
for idx in range(lenpos): for idx in range(lenpos):
# TODO FIXME: Do not use wrapped indexing!
pos[idx] = -mintaps[idx] % store_steps[idx] pos[idx] = -mintaps[idx] % store_steps[idx]
offset = nit_sot_arg_offset + n_nit_sot offset = nit_sot_arg_offset + n_nit_sot
other_args = outer_inputs[offset:] other_args = outer_inputs[offset:]
nb_mitmot_in = 0 cdef unsigned int nb_mitmot_in = 0
for idx in range(n_mit_mot): for idx in range(n_mit_mot):
nb_mitmot_in += tap_array_len[idx] nb_mitmot_in += tap_array_len[idx]
...@@ -290,16 +314,20 @@ def perform( ...@@ -290,16 +314,20 @@ def perform(
offset = n_seqs offset = n_seqs
for idx in range(n_outs): for idx in range(n_outs):
pos_idx = pos[idx]
store_steps_idx = store_steps[idx]
outer_outputs_idx = outer_outputs[idx]
if vector_outs[idx] == 1: if vector_outs[idx] == 1:
for tap in tap_array[idx]: for tap in tap_array[idx]:
_idx = (pos[idx]+tap)%store_steps[idx] _idx = (pos_idx + tap) % store_steps_idx
inner_input_storage[offset][0] =\ inner_input_storage[offset][0] =\
outer_outputs[idx][0][_idx:<unsigned int>(_idx+1)].reshape(()) outer_outputs_idx[0][_idx:<unsigned int>(_idx + 1)].reshape(())
offset += 1 offset += 1
else: else:
for tap in tap_array[idx]: for tap in tap_array[idx]:
_idx = (pos[idx]+tap)%store_steps[idx] _idx = (pos_idx + tap) % store_steps_idx
inner_input_storage[offset][0] = outer_outputs[idx][0][_idx] inner_input_storage[offset][0] = outer_outputs_idx[0][_idx]
offset += 1 offset += 1
...@@ -384,7 +412,7 @@ def perform( ...@@ -384,7 +412,7 @@ def perform(
try: try:
fn() fn()
except Exception as exc: except Exception as exc:
raise InnerFunctionError(exc, sys.exc_info()[-1]) raise InnerFunctionError(exc, sys.exc_info()[2])
dt_fn = time.time() - t0_fn dt_fn = time.time() - t0_fn
t_fn += dt_fn t_fn += dt_fn
...@@ -398,34 +426,33 @@ def perform( ...@@ -398,34 +426,33 @@ def perform(
mitmot_inp_offset = 0 mitmot_inp_offset = 0
mitmot_out_idx = 0 mitmot_out_idx = 0
for j in range(n_mit_mot): for j in range(n_mit_mot):
tap_array_j = tap_array[j]
pos_j = pos[j]
outer_outputs_j_0 = outer_outputs[j][0]
for k in mit_mot_out_slices[j]: for k in mit_mot_out_slices[j]:
if mitmots_preallocated[<unsigned int>mitmot_out_idx]: if mitmots_preallocated[mitmot_out_idx]:
# This output tap has been preallocated. # This output tap has been preallocated.
inp_idx = (mitmot_inp_offset + tap_array[j].index(k)) inp_idx = mitmot_inp_offset + tap_array_j.index(k)
inner_inp_idx = n_seqs + inp_idx inner_inp_idx = n_seqs + inp_idx
# Verify whether the input points to the same data as # Verify whether the input points to the same data as
# it did before the execution of the inner function. # it did before the execution of the inner function.
old_var = old_mitmot_input_storage[inp_idx] old_var = old_mitmot_input_storage[inp_idx]
new_var = inner_input_storage[inner_inp_idx][0] new_var = inner_input_storage[inner_inp_idx][0]
if old_var is new_var:
old_data = old_mitmot_input_data[inp_idx]
same_data = (new_var.data == old_data)
else:
same_data = False
# If the corresponding input storage has been replaced, # If the corresponding input storage has been replaced,
# recover the value as usual. Otherwise, the input was # recover the value as usual. Otherwise, the input was
# modified inplace and nothing needs to be done. # modified inplace and nothing needs to be done.
if not same_data: if old_var is not new_var or old_mitmot_input_data[inp_idx] != new_var.data:
outer_outputs[j][0][<unsigned int>(k + pos[j])] = \ outer_outputs_j_0[<unsigned int>(k + pos_j)] = \
inner_input_storage[<unsigned int>(inner_inp_idx)][0] inner_input_storage[inner_inp_idx][0]
else: else:
# This output tap has not been preallocated, recover # This output tap has not been preallocated, recover
# its value as usual # its value as usual
outer_outputs[j][0][<unsigned int>(k + pos[j])] = \ outer_outputs_j_0[<unsigned int>(k + pos_j)] = \
inner_output_storage[<unsigned int>offset_out][0] inner_output_storage[offset_out][0]
offset_out += 1 offset_out += 1
mitmot_out_idx += 1 mitmot_out_idx += 1
...@@ -439,72 +466,63 @@ def perform( ...@@ -439,72 +466,63 @@ def perform(
for j in range(begin, end): for j in range(begin, end):
jout = j + offset_out
outer_outputs_j = outer_outputs[j]
# Copy the output value to `outer_outputs`, if necessary # Copy the output value to `outer_outputs`, if necessary
if store_steps[j] == 1 or vector_outs[j] == 1: if store_steps[j] == 1 or vector_outs[j] == 1:
outer_outputs[j][0][pos[j]] = inner_output_storage[<unsigned int>(offset_out+j)][0] outer_outputs_j[0][pos[j]] = inner_output_storage[jout][0]
else: else:
# Check whether the initialization of the output storage map # Check whether the initialization of the output storage map
# for this output has been reused. # for this output has been reused.
old_var = old_output_storage[offset_out + j] old_var = old_output_storage[jout]
old_data = old_output_data[offset_out + j] old_data = old_output_data[jout]
new_var = inner_output_storage[offset_out + j][0] new_var = inner_output_storage[jout][0]
if old_var is new_var:
if old_data is None:
output_reused = False
else:
output_reused = (new_var.data == old_data)
else:
output_reused = False
if not output_reused:
outer_outputs[j][0][pos[j]] = \
inner_output_storage[<unsigned int>(offset_out+j)][0]
if old_var is not new_var or old_data is None:
outer_outputs_j[0][pos[j]] = new_var
# 5.5 Copy over the values for nit_sot outputs # 5.5 Copy over the values for nit_sot outputs
begin = end begin = end
end += n_nit_sot end += n_nit_sot
for j in range(begin,end): for j in range(begin,end):
jout = j + offset_out
if i == 0: if i == 0:
jout = j+offset_out store_steps_j = store_steps[j]
shape = (store_steps[j],) + inner_output_storage[jout][0].shape inner_output_storage_jout_0 = inner_output_storage[jout][0]
dtype = inner_output_storage[jout][0].dtype shape = (store_steps_j,) + inner_output_storage_jout_0.shape
if (outer_outputs[j][0] is None or dtype = inner_output_storage_jout_0.dtype
outer_outputs[j][0].shape[0] < store_steps[j] or outer_outputs_j = outer_outputs[j]
outer_outputs[j][0].shape[1:] != shape[1:] or outer_outputs_j_0 = outer_outputs_j[0]
outer_outputs[j][0].dtype != dtype ):
outer_outputs[j][0] = numpy.empty(shape, dtype=outer_output_dtypes[j]) if (
elif outer_outputs[j][0].shape[0] != store_steps[j]: outer_outputs_j_0 is None or
outer_outputs[j][0] = outer_outputs[j][0][:store_steps[j]] outer_outputs_j_0.shape[0] < store_steps_j or
outer_outputs[j][0][pos[j]] = inner_output_storage[jout][0] outer_outputs_j_0.shape[1:] != shape[1:] or
outer_outputs_j_0.dtype != dtype
):
new_outer_outputs_j_0 = numpy.empty(shape, dtype=outer_output_dtypes[j])
elif outer_outputs_j_0.shape[0] != store_steps_j:
new_outer_outputs_j_0 = outer_outputs_j_0[:store_steps_j]
else:
new_outer_outputs_j_0 = outer_outputs_j_0
new_outer_outputs_j_0[pos[j]] = inner_output_storage_jout_0
outer_outputs_j[0] = new_outer_outputs_j_0
elif store_steps[j] == 1 or vector_outs[j] == 1: elif store_steps[j] == 1 or vector_outs[j] == 1:
outer_outputs[j][0][pos[j]] = inner_output_storage[j+offset_out][0] outer_outputs[j][0][pos[j]] = inner_output_storage[jout][0]
else: else:
# Check whether the initialization of the output storage map # Check whether the initialization of the output storage map
# for this output has been reused. # for this output has been reused.
old_var = old_output_storage[offset_out + j] old_var = old_output_storage[jout]
old_data = old_output_data[offset_out + j] old_data = old_output_data[jout]
new_var = inner_output_storage[offset_out + j][0] new_var = inner_output_storage[jout][0]
if old_var is new_var:
if old_data is None: if old_var is not new_var or old_data is None:
output_reused = False outer_outputs[j][0][pos[j]] = new_var
else:
output_reused = (new_var.data == old_data)
else:
output_reused = False
if not output_reused:
try:
outer_outputs[j][0][pos[j]] = inner_output_storage[j+offset_out][0]
except ValueError as e:
if i == 0:
raise
raise ValueError(
"An output of the Scan has changed shape. "
"This may be caused by a push-out optimization."
" Try adding 'optimizer_excluding=scan_pushout'"
" to your Aesara flags.")
# 5.6 Copy over the values for outputs corresponding to shared # 5.6 Copy over the values for outputs corresponding to shared
# variables # variables
...@@ -522,35 +540,40 @@ def perform( ...@@ -522,35 +540,40 @@ def perform(
begin = n_mit_mot begin = n_mit_mot
end = n_outs + n_nit_sot end = n_outs + n_nit_sot
for idx in range(begin, end): for idx in range(begin, end):
if ( store_steps[idx] < i-mintaps[idx] and outer_outputs_idx = outer_outputs[idx]
pos[idx] < store_steps[idx] ): outer_outputs_idx_0 = outer_outputs_idx[0]
store_steps_idx = store_steps[idx]
mintaps_idx = mintaps[idx]
pdx = pos[idx]
pdx = pos[idx] if (store_steps_idx < i - mintaps_idx and pdx < store_steps_idx ):
if pdx >= store_steps[idx]//2 : if pdx >= store_steps_idx // 2 :
# It seems inefficient to copy the bigger part of the # It seems inefficient to copy the bigger part of the
# array over, and back, but it is the only way that # array over, and back, but it is the only way that
# there is no overlap in the areas of out[idx][0] that # there is no overlap in the areas of out[idx][0] that
# are read and written. # are read and written.
# This way, there will be no information overwritten # This way, there will be no information overwritten
# before it is read (as it used to happen). # before it is read (as it used to happen).
shape = (pdx,)+ outer_outputs[idx][0].shape[1:] shape = (pdx,)+ outer_outputs_idx_0.shape[1:]
tmp = numpy.empty(shape, dtype=outer_output_dtypes[idx]) tmp = numpy.empty(shape, dtype=outer_output_dtypes[idx])
tmp[:] = outer_outputs[idx][0][:pdx] tmp[:] = outer_outputs_idx_0[:pdx]
outer_outputs[idx][0][:store_steps[idx]-pdx] = outer_outputs[idx][0][pdx:] outer_outputs_idx_0[:store_steps_idx - pdx] = outer_outputs_idx_0[pdx:]
outer_outputs[idx][0][store_steps[idx]-pdx:] = tmp outer_outputs_idx_0[store_steps_idx - pdx:] = tmp
else: else:
shape = (store_steps[idx]-pdx,) + outer_outputs[idx][0].shape[1:] shape = (store_steps_idx - pdx,) + outer_outputs_idx_0.shape[1:]
tmp = numpy.empty(shape, dtype=outer_output_dtypes[idx]) tmp = numpy.empty(shape, dtype=outer_output_dtypes[idx])
tmp[:] = outer_outputs[idx][0][pdx:] tmp[:] = outer_outputs_idx_0[pdx:]
outer_outputs[idx][0][store_steps[idx]-pdx:] = outer_outputs[idx][0][:pdx] outer_outputs_idx_0[store_steps_idx - pdx:] = outer_outputs_idx_0[:pdx]
outer_outputs[idx][0][:store_steps[idx]-pdx] = tmp outer_outputs_idx_0[:store_steps_idx - pdx] = tmp
# This would normally happen only when doing truncated # This would normally happen only when doing truncated
# backpropagation through time. In such a scenario Scan is # backpropagation through time. In such a scenario Scan is
# expected to return 0 for all entries for which the gradient is # expected to return 0 for all entries for which the gradient is
# not actually computed # not actually computed
elif store_steps[idx] > i - mintaps[idx]: elif store_steps_idx > i - mintaps_idx:
outer_outputs[idx][0][i - mintaps[idx]:] = 0 outer_outputs_idx_0[i - mintaps_idx:] = 0
# This is a fix for a bug introduced by while. If you say # This is a fix for a bug introduced by while. If you say
# you want to loop up to a condition, you expect the output # you want to loop up to a condition, you expect the output
...@@ -566,8 +589,8 @@ def perform( ...@@ -566,8 +589,8 @@ def perform(
# to do boundschecks). The directive is used to make the # to do boundschecks). The directive is used to make the
# code faster, so this workaround is better then removing # code faster, so this workaround is better then removing
# the directive. # the directive.
sh0 = outer_outputs[idx][0].shape[0] sh0 = outer_outputs_idx_0.shape[0]
outer_outputs[idx][0] = outer_outputs[idx][0][:sh0-(n_steps - i)] outer_outputs_idx[0] = outer_outputs_idx_0[:sh0-(n_steps - i)]
# We never reuse the input or output storage of the # We never reuse the input or output storage of the
# inner function so we clear it. # inner function so we clear it.
......
...@@ -23,7 +23,7 @@ if not config.cxx: ...@@ -23,7 +23,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform") _logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.317 # must match constant returned in function get_version() version = 0.318 # must match constant returned in function get_version()
need_reload = False need_reload = False
scan_perform: Optional[ModuleType] = None scan_perform: Optional[ModuleType] = None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论