提交 174117f9 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove use of list.index in scan_perform.pyx

上级 77c7836e
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -1546,17 +1546,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_pos = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.uint32)
cython_store_steps = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.uint32)
tap_array_len = np.array(
[
len(x)
for x in chain(
self.info.mit_mot_in_slices,
self.info.mit_sot_in_slices,
self.info.sit_sot_in_slices,
)
],
dtype=np.uint32,
tap_array = (
self.info.mit_mot_in_slices
+ self.info.mit_sot_in_slices
+ self.info.sit_sot_in_slices
)
tap_array_len = np.array([len(x) for x in tap_array], dtype=np.uint32)
cython_vector_seqs = np.asarray(self.vector_seqs, dtype=bool)
cython_vector_outs = np.asarray(self.vector_outs, dtype=bool)
......@@ -1585,6 +1580,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
[getattr(out, "ndim", 0) for out in node.outputs], dtype=np.uint32
)
# The input index for each mit-mot output
mit_mot_out_to_tap_idx = ()
for j in range(self.info.n_mit_mot):
for k in self.info.mit_mot_out_slices[j]:
mit_mot_out_to_tap_idx += (tap_array[j].index(k),)
mit_mot_out_to_tap_idx = np.asarray(mit_mot_out_to_tap_idx, dtype=np.uint32)
from aesara.scan.utils import InnerFunctionError
# TODO: Extract `Capsule` object and use that
......@@ -1621,6 +1623,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_vector_outs,
self.info.mit_mot_out_slices,
cython_mitmots_preallocated,
mit_mot_out_to_tap_idx,
cython_outs_is_tensor,
inner_input_storage,
inner_output_storage,
......
......@@ -62,7 +62,7 @@ numpy.import_array()
def get_version():
return 0.324
return 0.325
@cython.cdivision(True)
......@@ -91,6 +91,7 @@ def perform(
const numpy.npy_bool[:] vector_outs not None,
tuple mit_mot_out_slices not None,
const numpy.npy_bool[:] mitmots_preallocated not None,
const unsigned int [:] mit_mot_out_to_tap_idx not None,
const numpy.npy_bool[:] outs_is_tensor not None,
list inner_input_storage not None,
list inner_output_storage not None,
......@@ -426,14 +427,12 @@ def perform(
mitmot_out_idx = 0
for j in range(n_mit_mot):
tap_array_j = tap_array[j]
pos_j = pos[j]
outer_outputs_j_0 = outer_outputs[j][0]
for k in mit_mot_out_slices[j]:
if mitmots_preallocated[mitmot_out_idx]:
# This output tap has been preallocated.
inp_idx = mitmot_inp_offset + tap_array_j.index(k)
inp_idx = mitmot_inp_offset + mit_mot_out_to_tap_idx[mitmot_out_idx]
inner_inp_idx = n_seqs + inp_idx
# Verify whether the input points to the same data as
......
......@@ -23,7 +23,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.324 # must match constant returned in function get_version()
version = 0.325 # must match constant returned in function get_version()
need_reload = False
scan_perform: Optional[ModuleType] = None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论