提交 42db0775 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Convert more arguments to Numpy arrays in scan_perform.pyx

上级 bb999046
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -1546,13 +1546,16 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_pos = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.int32)
cython_store_steps = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.int32)
tap_array_len = tuple(
tap_array_len = np.array(
[
len(x)
for x in chain(
self.info.mit_mot_in_slices,
self.info.mit_sot_in_slices,
self.info.sit_sot_in_slices,
)
],
dtype=np.uint32,
)
cython_vector_seqs = np.asarray(self.vector_seqs, dtype=bool)
......@@ -1575,10 +1578,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_output_storage = [s.storage for s in self.fn.output_storage]
outer_output_dtypes = tuple(
getattr(out, "dtype", None) for out in node.outputs
getattr(out, "dtype", object) for out in node.outputs
)
outer_output_ndims = tuple(
getattr(out, "ndim", None) for out in node.outputs
outer_output_ndims = np.array(
[getattr(out, "ndim", 0) for out in node.outputs], dtype=np.uint32
)
from aesara.scan.utils import InnerFunctionError
......
......@@ -62,7 +62,7 @@ numpy.import_array()
def get_version():
return 0.320
return 0.321
# TODO: We need to get rid of the negative indexing performed with `pos` and `l`.
# @cython.wraparound(False)
......@@ -80,7 +80,7 @@ def perform(
int[:] pos not None,
int[:] store_steps not None,
tuple tap_array not None,
tuple tap_array_len not None,
const unsigned int[:] tap_array_len not None,
const numpy.npy_bool[:] vector_seqs not None,
const numpy.npy_bool[:] vector_outs not None,
tuple mit_mot_out_slices not None,
......@@ -92,7 +92,7 @@ def perform(
list outer_inputs not None,
list outer_outputs not None,
tuple outer_output_dtypes not None,
tuple outer_output_ndims not None,
const unsigned int[:] outer_output_ndims not None,
fn,
) -> (time_t, int):
"""
......
......@@ -23,7 +23,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.320 # must match constant returned in function get_version()
version = 0.321 # must match constant returned in function get_version()
need_reload = False
scan_perform: Optional[ModuleType] = None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论