提交 7890a34f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refine Scan's Cython implementation type information and pre-allocations

上级 74fb5433
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -1542,6 +1542,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_mintaps = np.asarray(self.mintaps, dtype="int32")
n_outs = self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot
cython_pos = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.int32)
cython_store_steps = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.int32)
tap_array_len = tuple(
len(x)
for x in chain(
......@@ -1551,22 +1555,21 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
)
cython_vector_seqs = np.asarray(self.vector_seqs, dtype="int32")
cython_vector_outs = np.asarray(self.vector_outs, dtype="int32")
cython_vector_seqs = np.asarray(self.vector_seqs, dtype=bool)
cython_vector_outs = np.asarray(self.vector_outs, dtype=bool)
cython_mitmots_preallocated = np.asarray(
self.mitmots_preallocated, dtype="int32"
self.mitmots_preallocated, dtype=bool
)
cython_outs_is_tensor = np.asarray(outs_is_tensor, dtype="int32")
cython_outs_is_tensor = np.asarray(outs_is_tensor, dtype=bool)
if self.destroy_map:
cython_destroy_map = [
x in self.destroy_map for x in range(len(node.outputs))
]
else:
cython_destroy_map = [0 for x in range(len(node.outputs))]
cython_destroy_map = [False for x in range(len(node.outputs))]
cython_destroy_map = np.asarray(cython_destroy_map, dtype="int32")
cython_destroy_map = np.asarray(cython_destroy_map, dtype=bool)
inner_input_storage = [s.storage for s in self.fn.input_storage]
inner_output_storage = [s.storage for s in self.fn.output_storage]
......@@ -1604,6 +1607,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.info.n_nit_sot,
self.info.as_while,
cython_mintaps,
cython_pos,
cython_store_steps,
self.info.mit_mot_in_slices
+ self.info.mit_sot_in_slices
+ self.info.sit_sot_in_slices,
......
......@@ -59,58 +59,64 @@ from aesara.scan.utils import InnerFunctionError
def get_version():
return 0.315
return 0.316
@cython.boundscheck(False)
def perform(
unsigned int n_shared_outs,
unsigned int n_mit_mot_outs,
unsigned int n_seqs,
unsigned int n_mit_mot,
unsigned int n_mit_sot,
unsigned int n_sit_sot,
unsigned int n_nit_sot,
bint as_while,
numpy.ndarray[numpy.int32_t,ndim=1] mintaps,
tuple tap_array,
tuple tap_array_len,
numpy.ndarray[numpy.int32_t,ndim=1] vector_seqs,
numpy.ndarray[numpy.int32_t,ndim=1] vector_outs,
tuple mit_mot_out_slices,
numpy.ndarray[numpy.int32_t,ndim=1] mitmots_preallocated,
numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor,
list inner_input_storage,
list inner_output_storage,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
list outer_inputs,
list outer_outputs,
tuple outer_output_dtypes,
tuple outer_output_ndims,
fn,
const unsigned int n_shared_outs,
const unsigned int n_mit_mot_outs,
const unsigned int n_seqs,
const unsigned int n_mit_mot,
const unsigned int n_mit_sot,
const unsigned int n_sit_sot,
const unsigned int n_nit_sot,
const bint as_while,
const int[:] mintaps,
int[:] pos,
int[:] store_steps,
tuple tap_array,
tuple tap_array_len,
const numpy.npy_bool[:] vector_seqs,
const numpy.npy_bool[:] vector_outs,
tuple mit_mot_out_slices,
const numpy.npy_bool[:] mitmots_preallocated,
const numpy.npy_bool[:] outs_is_tensor,
list inner_input_storage,
list inner_output_storage,
const numpy.npy_bool[:] destroy_map,
list outer_inputs,
list outer_outputs,
tuple outer_output_dtypes,
tuple outer_output_ndims,
fn,
) -> (float, int):
"""
Parameters
----------
n_shared_outs: unsigned int
n_shared_outs
Number of arguments that correspond to shared variables with
updates
n_mit_mot_outs: unsigned int
n_mit_mot_outs
Sum over the number of output taps for each mit_mot sequence
n_seqs: unsigned int
n_seqs
Number of sequences provided as input
n_mit_mot : unsigned int
n_mit_mot
Number of mit_mot arguments
n_mit_sot: unsigned int
n_mit_sot
Number of mit_sot arguments
n_sit_sot: unsigned int
n_sit_sot
Number of sit sot arguments
n_nit_sot: unsigned int
n_nit_sot
Number of nit_sot arguments
mintaps: int32 ndarray (can also be a simple python list if that is better !)
mintaps
For any of the mit_mot, mit_sot, sit_sot says which is the furtherst
away input tap from current position. For example, if the taps where [-2,
-5, -9], the mintap would be -9. For sit_sot this is always -1 since
-5, -9], the mintap would be -9. For sit_sot this is always -1, since it
is the only allowed tap.
pos
Storage for positions.
store_steps
The length of each output.
tap_array
For each of the mit_mot, mit_sot, sit_sot (the first dimension) says
which are the corresponding input taps. While this is a matrix, not all
......@@ -120,33 +126,29 @@ def perform(
tap_array_len
For each of the mit_mot, mit_sot, sit_sot says how many input taps
each has. For sit_sot this will always be 1.
vector_seqs: int32 ndarray (can be replaced by a list of bools if better)
vector_seqs
For each sequence the corresponding entry is either a 1, is the
sequence is a vector or 0 if it has more than 1 dimension
vector_outs: int32 ndarray( can be replaced by list of bools if better)
For each output ( mit_mot, mit_sot, sit_sot, nit_sot in this order)
vector_outs
For each output (i.e. mit_mot, mit_sot, sit_sot, nit_sot in this order)
the entry is 1 if the corresponding argument is a 1 dimensional
tensor, 0 otherwise.
mit_mot_out_slices
Same as tap_array, but for the output taps of mit_mot sequences
outs_is_tensor : int32 ndarray (Can be replaced by a list)
outs_is_tensor
Array of boolean indicating, for every output, whether it is a tensor
or not
or not.
inner_input_storage
The storage locations for the inner-function's inputs.
inner_output_storage
The storage locations for the inner-function's outputs.
fnct: Function
The compiled Aesara inner-function object.
destroy_map
Array of boolean saying if an output is computed inplace
outer_inputs: list of ndarrays (and random states)
outer_inputs
The inputs of scan in a given order ( n_steps, sequences, mit_mot,
mit_sot, sit_sot, nit_sot, shared_outs, other_args)
outer_outputs: list of 1 element list ( or storage objects?)
This is where we need to copy our outputs ( we don't return the
results, though we can change the code such that we return, and
figure things out on the outside - python)
outer_outputs
This is where we need to copy the new outputs.
outer_output_dtypes
The dtypes for each outer output.
outer_output_ndims
......@@ -167,14 +169,6 @@ def perform(
n_shared_outs)
cdef unsigned int offset_out
cdef unsigned int lenpos = n_outs + n_nit_sot
# TODO: See how this is being converted and whether or not we can remove
# fixed allocations caused by this.
cdef int pos[500] # put a maximum of 500 outputs
cdef unsigned int len_store_steps = n_mit_mot + n_mit_sot + n_sit_sot + n_nit_sot
# The length of each output
# TODO: See how this is being converted and whether or not we can remove
# fixed allocations caused by this.
cdef int store_steps[500]
cdef unsigned int l
cdef unsigned int offset
cdef int tap
......@@ -257,7 +251,7 @@ def perform(
outer_outputs[idx][0] = None
return 0.0, 0
for idx in range(n_outs + n_nit_sot):
for idx in range(lenpos):
pos[idx] = -mintaps[idx] % store_steps[idx]
offset = nit_sot_arg_offset + n_nit_sot
......
......@@ -23,7 +23,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.315 # must match constant returned in function get_version()
version = 0.316 # must match constant returned in function get_version()
need_reload = False
scan_perform: Optional[ModuleType] = None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论