提交 2bf1676e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use cdivision in scan_perform.pyx

上级 42db0775
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -1543,8 +1543,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_mintaps = np.asarray(self.mintaps, dtype="int32")
n_outs = self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot
cython_pos = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.int32)
cython_store_steps = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.int32)
cython_pos = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.uint32)
cython_store_steps = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.uint32)
tap_array_len = np.array(
[
......
......@@ -62,10 +62,17 @@ numpy.import_array()
def get_version():
return 0.321
return 0.322
@cython.cdivision(True)
cdef inline unsigned int pymod(int a, unsigned int b):
return (a % (<int>b) + <int>b) % b
# TODO: We need to get rid of the negative indexing performed with `pos` and `l`.
# @cython.wraparound(False)
@cython.cdivision(True)
@cython.boundscheck(False)
def perform(
const unsigned int n_shared_outs,
......@@ -77,8 +84,8 @@ def perform(
const unsigned int n_nit_sot,
const bint as_while,
const int[:] mintaps not None,
int[:] pos not None,
int[:] store_steps not None,
unsigned int[:] pos not None,
unsigned int[:] store_steps not None,
tuple tap_array not None,
const unsigned int[:] tap_array_len not None,
const numpy.npy_bool[:] vector_seqs not None,
......@@ -204,8 +211,8 @@ def perform(
cdef unsigned int store_steps_idx
cdef int mintaps_idx
cdef unsigned int sh0
cdef long pos_j
cdef long pos_idx
cdef unsigned int pos_j
cdef unsigned int pos_idx
if n_steps < 0:
# History, in the past, this was used for backward
......@@ -276,7 +283,7 @@ def perform(
for idx in range(lenpos):
# TODO FIXME: Do not use wrapped indexing!
pos[idx] = -mintaps[idx] % store_steps[idx]
pos[idx] = pymod(-mintaps[idx], store_steps[idx])
offset = nit_sot_arg_offset + n_nit_sot
other_args = outer_inputs[offset:]
......@@ -320,13 +327,13 @@ def perform(
if vector_outs[idx] == 1:
for tap in tap_array[idx]:
_idx = (pos_idx + tap) % store_steps_idx
_idx = pymod(pos_idx + tap, store_steps_idx)
inner_input_storage[offset][0] =\
outer_outputs_idx[0][_idx:<unsigned int>(_idx + 1)].reshape(())
offset += 1
else:
for tap in tap_array[idx]:
_idx = (pos_idx + tap) % store_steps_idx
_idx = pymod(pos_idx + tap, store_steps_idx)
inner_input_storage[offset][0] = outer_outputs_idx[0][_idx]
offset += 1
......@@ -533,7 +540,8 @@ def perform(
outer_outputs[j][0] = inner_output_storage[jout][0]
for idx in range(lenpos):
pos[idx] = (pos[idx]+1)%store_steps[idx]
pos[idx] = pymod(pos[idx] + 1, store_steps[idx])
i = i + 1
# 6. Check if you need to re-order output buffers
......
......@@ -23,7 +23,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.321 # must match constant returned in function get_version()
version = 0.322 # must match constant returned in function get_version()
need_reload = False
scan_perform: Optional[ModuleType] = None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论