提交 2bf1676e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use cdivision in scan_perform.pyx

上级 42db0775
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -1543,8 +1543,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1543,8 +1543,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_mintaps = np.asarray(self.mintaps, dtype="int32") cython_mintaps = np.asarray(self.mintaps, dtype="int32")
n_outs = self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot n_outs = self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot
cython_pos = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.int32) cython_pos = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.uint32)
cython_store_steps = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.int32) cython_store_steps = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.uint32)
tap_array_len = np.array( tap_array_len = np.array(
[ [
......
...@@ -62,10 +62,17 @@ numpy.import_array() ...@@ -62,10 +62,17 @@ numpy.import_array()
def get_version(): def get_version():
return 0.321 return 0.322
@cython.cdivision(True)
cdef inline unsigned int pymod(int a, unsigned int b):
return (a % (<int>b) + <int>b) % b
# TODO: We need to get rid of the negative indexing performed with `pos` and `l`. # TODO: We need to get rid of the negative indexing performed with `pos` and `l`.
# @cython.wraparound(False) # @cython.wraparound(False)
@cython.cdivision(True)
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
const unsigned int n_shared_outs, const unsigned int n_shared_outs,
...@@ -77,8 +84,8 @@ def perform( ...@@ -77,8 +84,8 @@ def perform(
const unsigned int n_nit_sot, const unsigned int n_nit_sot,
const bint as_while, const bint as_while,
const int[:] mintaps not None, const int[:] mintaps not None,
int[:] pos not None, unsigned int[:] pos not None,
int[:] store_steps not None, unsigned int[:] store_steps not None,
tuple tap_array not None, tuple tap_array not None,
const unsigned int[:] tap_array_len not None, const unsigned int[:] tap_array_len not None,
const numpy.npy_bool[:] vector_seqs not None, const numpy.npy_bool[:] vector_seqs not None,
...@@ -204,8 +211,8 @@ def perform( ...@@ -204,8 +211,8 @@ def perform(
cdef unsigned int store_steps_idx cdef unsigned int store_steps_idx
cdef int mintaps_idx cdef int mintaps_idx
cdef unsigned int sh0 cdef unsigned int sh0
cdef long pos_j cdef unsigned int pos_j
cdef long pos_idx cdef unsigned int pos_idx
if n_steps < 0: if n_steps < 0:
# History, in the past, this was used for backward # History, in the past, this was used for backward
...@@ -276,7 +283,7 @@ def perform( ...@@ -276,7 +283,7 @@ def perform(
for idx in range(lenpos): for idx in range(lenpos):
# TODO FIXME: Do not use wrapped indexing! # TODO FIXME: Do not use wrapped indexing!
pos[idx] = -mintaps[idx] % store_steps[idx] pos[idx] = pymod(-mintaps[idx], store_steps[idx])
offset = nit_sot_arg_offset + n_nit_sot offset = nit_sot_arg_offset + n_nit_sot
other_args = outer_inputs[offset:] other_args = outer_inputs[offset:]
...@@ -320,13 +327,13 @@ def perform( ...@@ -320,13 +327,13 @@ def perform(
if vector_outs[idx] == 1: if vector_outs[idx] == 1:
for tap in tap_array[idx]: for tap in tap_array[idx]:
_idx = (pos_idx + tap) % store_steps_idx _idx = pymod(pos_idx + tap, store_steps_idx)
inner_input_storage[offset][0] =\ inner_input_storage[offset][0] =\
outer_outputs_idx[0][_idx:<unsigned int>(_idx + 1)].reshape(()) outer_outputs_idx[0][_idx:<unsigned int>(_idx + 1)].reshape(())
offset += 1 offset += 1
else: else:
for tap in tap_array[idx]: for tap in tap_array[idx]:
_idx = (pos_idx + tap) % store_steps_idx _idx = pymod(pos_idx + tap, store_steps_idx)
inner_input_storage[offset][0] = outer_outputs_idx[0][_idx] inner_input_storage[offset][0] = outer_outputs_idx[0][_idx]
offset += 1 offset += 1
...@@ -533,7 +540,8 @@ def perform( ...@@ -533,7 +540,8 @@ def perform(
outer_outputs[j][0] = inner_output_storage[jout][0] outer_outputs[j][0] = inner_output_storage[jout][0]
for idx in range(lenpos): for idx in range(lenpos):
pos[idx] = (pos[idx]+1)%store_steps[idx] pos[idx] = pymod(pos[idx] + 1, store_steps[idx])
i = i + 1 i = i + 1
# 6. Check if you need to re-order output buffers # 6. Check if you need to re-order output buffers
......
...@@ -23,7 +23,7 @@ if not config.cxx: ...@@ -23,7 +23,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform") _logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.321 # must match constant returned in function get_version() version = 0.322 # must match constant returned in function get_version()
need_reload = False need_reload = False
scan_perform: Optional[ModuleType] = None scan_perform: Optional[ModuleType] = None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论