提交 7890a34f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refine Scan's Cython implementation type information and pre-allocations

上级 74fb5433
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -1542,6 +1542,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1542,6 +1542,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_mintaps = np.asarray(self.mintaps, dtype="int32") cython_mintaps = np.asarray(self.mintaps, dtype="int32")
n_outs = self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot
cython_pos = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.int32)
cython_store_steps = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.int32)
tap_array_len = tuple( tap_array_len = tuple(
len(x) len(x)
for x in chain( for x in chain(
...@@ -1551,22 +1555,21 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1551,22 +1555,21 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
) )
) )
cython_vector_seqs = np.asarray(self.vector_seqs, dtype="int32") cython_vector_seqs = np.asarray(self.vector_seqs, dtype=bool)
cython_vector_outs = np.asarray(self.vector_outs, dtype="int32") cython_vector_outs = np.asarray(self.vector_outs, dtype=bool)
cython_mitmots_preallocated = np.asarray( cython_mitmots_preallocated = np.asarray(
self.mitmots_preallocated, dtype="int32" self.mitmots_preallocated, dtype=bool
) )
cython_outs_is_tensor = np.asarray(outs_is_tensor, dtype=bool)
cython_outs_is_tensor = np.asarray(outs_is_tensor, dtype="int32")
if self.destroy_map: if self.destroy_map:
cython_destroy_map = [ cython_destroy_map = [
x in self.destroy_map for x in range(len(node.outputs)) x in self.destroy_map for x in range(len(node.outputs))
] ]
else: else:
cython_destroy_map = [0 for x in range(len(node.outputs))] cython_destroy_map = [False for x in range(len(node.outputs))]
cython_destroy_map = np.asarray(cython_destroy_map, dtype="int32") cython_destroy_map = np.asarray(cython_destroy_map, dtype=bool)
inner_input_storage = [s.storage for s in self.fn.input_storage] inner_input_storage = [s.storage for s in self.fn.input_storage]
inner_output_storage = [s.storage for s in self.fn.output_storage] inner_output_storage = [s.storage for s in self.fn.output_storage]
...@@ -1604,6 +1607,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1604,6 +1607,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.info.n_nit_sot, self.info.n_nit_sot,
self.info.as_while, self.info.as_while,
cython_mintaps, cython_mintaps,
cython_pos,
cython_store_steps,
self.info.mit_mot_in_slices self.info.mit_mot_in_slices
+ self.info.mit_sot_in_slices + self.info.mit_sot_in_slices
+ self.info.sit_sot_in_slices, + self.info.sit_sot_in_slices,
......
...@@ -59,58 +59,64 @@ from aesara.scan.utils import InnerFunctionError ...@@ -59,58 +59,64 @@ from aesara.scan.utils import InnerFunctionError
def get_version(): def get_version():
return 0.315 return 0.316
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
unsigned int n_shared_outs, const unsigned int n_shared_outs,
unsigned int n_mit_mot_outs, const unsigned int n_mit_mot_outs,
unsigned int n_seqs, const unsigned int n_seqs,
unsigned int n_mit_mot, const unsigned int n_mit_mot,
unsigned int n_mit_sot, const unsigned int n_mit_sot,
unsigned int n_sit_sot, const unsigned int n_sit_sot,
unsigned int n_nit_sot, const unsigned int n_nit_sot,
bint as_while, const bint as_while,
numpy.ndarray[numpy.int32_t,ndim=1] mintaps, const int[:] mintaps,
tuple tap_array, int[:] pos,
tuple tap_array_len, int[:] store_steps,
numpy.ndarray[numpy.int32_t,ndim=1] vector_seqs, tuple tap_array,
numpy.ndarray[numpy.int32_t,ndim=1] vector_outs, tuple tap_array_len,
tuple mit_mot_out_slices, const numpy.npy_bool[:] vector_seqs,
numpy.ndarray[numpy.int32_t,ndim=1] mitmots_preallocated, const numpy.npy_bool[:] vector_outs,
numpy.ndarray[numpy.int32_t,ndim=1] outs_is_tensor, tuple mit_mot_out_slices,
list inner_input_storage, const numpy.npy_bool[:] mitmots_preallocated,
list inner_output_storage, const numpy.npy_bool[:] outs_is_tensor,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map, list inner_input_storage,
list outer_inputs, list inner_output_storage,
list outer_outputs, const numpy.npy_bool[:] destroy_map,
tuple outer_output_dtypes, list outer_inputs,
tuple outer_output_ndims, list outer_outputs,
fn, tuple outer_output_dtypes,
tuple outer_output_ndims,
fn,
) -> (float, int): ) -> (float, int):
""" """
Parameters Parameters
---------- ----------
n_shared_outs: unsigned int n_shared_outs
Number of arguments that correspond to shared variables with Number of arguments that correspond to shared variables with
updates updates
n_mit_mot_outs: unsigned int n_mit_mot_outs
Sum over the number of output taps for each mit_mot sequence Sum over the number of output taps for each mit_mot sequence
n_seqs: unsigned int n_seqs
Number of sequences provided as input Number of sequences provided as input
n_mit_mot : unsigned int n_mit_mot
Number of mit_mot arguments Number of mit_mot arguments
n_mit_sot: unsigned int n_mit_sot
Number of mit_sot arguments Number of mit_sot arguments
n_sit_sot: unsigned int n_sit_sot
Number of sit sot arguments Number of sit sot arguments
n_nit_sot: unsigned int n_nit_sot
Number of nit_sot arguments Number of nit_sot arguments
mintaps: int32 ndarray (can also be a simple python list if that is better !) mintaps
For any of the mit_mot, mit_sot, sit_sot says which is the furtherst For any of the mit_mot, mit_sot, sit_sot says which is the furtherst
away input tap from current position. For example, if the taps where [-2, away input tap from current position. For example, if the taps where [-2,
-5, -9], the mintap would be -9. For sit_sot this is always -1 since -5, -9], the mintap would be -9. For sit_sot this is always -1, since it
is the only allowed tap. is the only allowed tap.
pos
Storage for positions.
store_steps
The length of each output.
tap_array tap_array
For each of the mit_mot, mit_sot, sit_sot (the first dimension) says For each of the mit_mot, mit_sot, sit_sot (the first dimension) says
which are the corresponding input taps. While this is a matrix, not all which are the corresponding input taps. While this is a matrix, not all
...@@ -120,33 +126,29 @@ def perform( ...@@ -120,33 +126,29 @@ def perform(
tap_array_len tap_array_len
For each of the mit_mot, mit_sot, sit_sot says how many input taps For each of the mit_mot, mit_sot, sit_sot says how many input taps
each has. For sit_sot this will always be 1. each has. For sit_sot this will always be 1.
vector_seqs: int32 ndarray (can be replaced by a list of bools if better) vector_seqs
For each sequence the corresponding entry is either a 1, is the For each sequence the corresponding entry is either a 1, is the
sequence is a vector or 0 if it has more than 1 dimension sequence is a vector or 0 if it has more than 1 dimension
vector_outs: int32 ndarray( can be replaced by list of bools if better) vector_outs
For each output ( mit_mot, mit_sot, sit_sot, nit_sot in this order) For each output (i.e. mit_mot, mit_sot, sit_sot, nit_sot in this order)
the entry is 1 if the corresponding argument is a 1 dimensional the entry is 1 if the corresponding argument is a 1 dimensional
tensor, 0 otherwise. tensor, 0 otherwise.
mit_mot_out_slices mit_mot_out_slices
Same as tap_array, but for the output taps of mit_mot sequences Same as tap_array, but for the output taps of mit_mot sequences
outs_is_tensor : int32 ndarray (Can be replaced by a list) outs_is_tensor
Array of boolean indicating, for every output, whether it is a tensor Array of boolean indicating, for every output, whether it is a tensor
or not or not.
inner_input_storage inner_input_storage
The storage locations for the inner-function's inputs. The storage locations for the inner-function's inputs.
inner_output_storage inner_output_storage
The storage locations for the inner-function's outputs. The storage locations for the inner-function's outputs.
fnct: Function
The compiled Aesara inner-function object.
destroy_map destroy_map
Array of boolean saying if an output is computed inplace Array of boolean saying if an output is computed inplace
outer_inputs: list of ndarrays (and random states) outer_inputs
The inputs of scan in a given order ( n_steps, sequences, mit_mot, The inputs of scan in a given order ( n_steps, sequences, mit_mot,
mit_sot, sit_sot, nit_sot, shared_outs, other_args) mit_sot, sit_sot, nit_sot, shared_outs, other_args)
outer_outputs: list of 1 element list ( or storage objects?) outer_outputs
This is where we need to copy our outputs ( we don't return the This is where we need to copy the new outputs.
results, though we can change the code such that we return, and
figure things out on the outside - python)
outer_output_dtypes outer_output_dtypes
The dtypes for each outer output. The dtypes for each outer output.
outer_output_ndims outer_output_ndims
...@@ -167,14 +169,6 @@ def perform( ...@@ -167,14 +169,6 @@ def perform(
n_shared_outs) n_shared_outs)
cdef unsigned int offset_out cdef unsigned int offset_out
cdef unsigned int lenpos = n_outs + n_nit_sot cdef unsigned int lenpos = n_outs + n_nit_sot
# TODO: See how this is being converted and whether or not we can remove
# fixed allocations caused by this.
cdef int pos[500] # put a maximum of 500 outputs
cdef unsigned int len_store_steps = n_mit_mot + n_mit_sot + n_sit_sot + n_nit_sot
# The length of each output
# TODO: See how this is being converted and whether or not we can remove
# fixed allocations caused by this.
cdef int store_steps[500]
cdef unsigned int l cdef unsigned int l
cdef unsigned int offset cdef unsigned int offset
cdef int tap cdef int tap
...@@ -257,7 +251,7 @@ def perform( ...@@ -257,7 +251,7 @@ def perform(
outer_outputs[idx][0] = None outer_outputs[idx][0] = None
return 0.0, 0 return 0.0, 0
for idx in range(n_outs + n_nit_sot): for idx in range(lenpos):
pos[idx] = -mintaps[idx] % store_steps[idx] pos[idx] = -mintaps[idx] % store_steps[idx]
offset = nit_sot_arg_offset + n_nit_sot offset = nit_sot_arg_offset + n_nit_sot
......
...@@ -23,7 +23,7 @@ if not config.cxx: ...@@ -23,7 +23,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform") _logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.315 # must match constant returned in function get_version() version = 0.316 # must match constant returned in function get_version()
need_reload = False need_reload = False
scan_perform: Optional[ModuleType] = None scan_perform: Optional[ModuleType] = None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论