提交 174117f9 作者: Brandon T. Willard 提交者: Brandon T. Willard

Remove use of list.index in scan_perform.pyx

上级 77c7836e
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -1546,17 +1546,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1546,17 +1546,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_pos = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.uint32) cython_pos = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.uint32)
cython_store_steps = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.uint32) cython_store_steps = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.uint32)
tap_array_len = np.array( tap_array = (
[ self.info.mit_mot_in_slices
len(x) + self.info.mit_sot_in_slices
for x in chain( + self.info.sit_sot_in_slices
self.info.mit_mot_in_slices,
self.info.mit_sot_in_slices,
self.info.sit_sot_in_slices,
)
],
dtype=np.uint32,
) )
tap_array_len = np.array([len(x) for x in tap_array], dtype=np.uint32)
cython_vector_seqs = np.asarray(self.vector_seqs, dtype=bool) cython_vector_seqs = np.asarray(self.vector_seqs, dtype=bool)
cython_vector_outs = np.asarray(self.vector_outs, dtype=bool) cython_vector_outs = np.asarray(self.vector_outs, dtype=bool)
...@@ -1585,6 +1580,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1585,6 +1580,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
[getattr(out, "ndim", 0) for out in node.outputs], dtype=np.uint32 [getattr(out, "ndim", 0) for out in node.outputs], dtype=np.uint32
) )
# The input index for each mit-mot output
mit_mot_out_to_tap_idx = ()
for j in range(self.info.n_mit_mot):
for k in self.info.mit_mot_out_slices[j]:
mit_mot_out_to_tap_idx += (tap_array[j].index(k),)
mit_mot_out_to_tap_idx = np.asarray(mit_mot_out_to_tap_idx, dtype=np.uint32)
from aesara.scan.utils import InnerFunctionError from aesara.scan.utils import InnerFunctionError
# TODO: Extract `Capsule` object and use that # TODO: Extract `Capsule` object and use that
...@@ -1621,6 +1623,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1621,6 +1623,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_vector_outs, cython_vector_outs,
self.info.mit_mot_out_slices, self.info.mit_mot_out_slices,
cython_mitmots_preallocated, cython_mitmots_preallocated,
mit_mot_out_to_tap_idx,
cython_outs_is_tensor, cython_outs_is_tensor,
inner_input_storage, inner_input_storage,
inner_output_storage, inner_output_storage,
......
...@@ -62,7 +62,7 @@ numpy.import_array() ...@@ -62,7 +62,7 @@ numpy.import_array()
def get_version(): def get_version():
return 0.324 return 0.325
@cython.cdivision(True) @cython.cdivision(True)
...@@ -91,6 +91,7 @@ def perform( ...@@ -91,6 +91,7 @@ def perform(
const numpy.npy_bool[:] vector_outs not None, const numpy.npy_bool[:] vector_outs not None,
tuple mit_mot_out_slices not None, tuple mit_mot_out_slices not None,
const numpy.npy_bool[:] mitmots_preallocated not None, const numpy.npy_bool[:] mitmots_preallocated not None,
const unsigned int [:] mit_mot_out_to_tap_idx not None,
const numpy.npy_bool[:] outs_is_tensor not None, const numpy.npy_bool[:] outs_is_tensor not None,
list inner_input_storage not None, list inner_input_storage not None,
list inner_output_storage not None, list inner_output_storage not None,
...@@ -426,14 +427,12 @@ def perform( ...@@ -426,14 +427,12 @@ def perform(
mitmot_out_idx = 0 mitmot_out_idx = 0
for j in range(n_mit_mot): for j in range(n_mit_mot):
tap_array_j = tap_array[j]
pos_j = pos[j] pos_j = pos[j]
outer_outputs_j_0 = outer_outputs[j][0] outer_outputs_j_0 = outer_outputs[j][0]
for k in mit_mot_out_slices[j]: for k in mit_mot_out_slices[j]:
if mitmots_preallocated[mitmot_out_idx]: if mitmots_preallocated[mitmot_out_idx]:
# This output tap has been preallocated. inp_idx = mitmot_inp_offset + mit_mot_out_to_tap_idx[mitmot_out_idx]
inp_idx = mitmot_inp_offset + tap_array_j.index(k)
inner_inp_idx = n_seqs + inp_idx inner_inp_idx = n_seqs + inp_idx
# Verify whether the input points to the same data as # Verify whether the input points to the same data as
......
...@@ -23,7 +23,7 @@ if not config.cxx: ...@@ -23,7 +23,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform") _logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.324 # must match constant returned in function get_version() version = 0.325 # must match constant returned in function get_version()
need_reload = False need_reload = False
scan_perform: Optional[ModuleType] = None scan_perform: Optional[ModuleType] = None
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论