提交 42db0775 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Convert more arguments to Numpy arrays in scan_perform.pyx

上级 bb999046
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -1546,13 +1546,16 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1546,13 +1546,16 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_pos = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.int32) cython_pos = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.int32)
cython_store_steps = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.int32) cython_store_steps = np.zeros(n_outs + self.info.n_nit_sot, dtype=np.int32)
tap_array_len = tuple( tap_array_len = np.array(
[
len(x) len(x)
for x in chain( for x in chain(
self.info.mit_mot_in_slices, self.info.mit_mot_in_slices,
self.info.mit_sot_in_slices, self.info.mit_sot_in_slices,
self.info.sit_sot_in_slices, self.info.sit_sot_in_slices,
) )
],
dtype=np.uint32,
) )
cython_vector_seqs = np.asarray(self.vector_seqs, dtype=bool) cython_vector_seqs = np.asarray(self.vector_seqs, dtype=bool)
...@@ -1575,10 +1578,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1575,10 +1578,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_output_storage = [s.storage for s in self.fn.output_storage] inner_output_storage = [s.storage for s in self.fn.output_storage]
outer_output_dtypes = tuple( outer_output_dtypes = tuple(
getattr(out, "dtype", None) for out in node.outputs getattr(out, "dtype", object) for out in node.outputs
) )
outer_output_ndims = tuple(
getattr(out, "ndim", None) for out in node.outputs outer_output_ndims = np.array(
[getattr(out, "ndim", 0) for out in node.outputs], dtype=np.uint32
) )
from aesara.scan.utils import InnerFunctionError from aesara.scan.utils import InnerFunctionError
......
...@@ -62,7 +62,7 @@ numpy.import_array() ...@@ -62,7 +62,7 @@ numpy.import_array()
def get_version(): def get_version():
return 0.320 return 0.321
# TODO: We need to get rid of the negative indexing performed with `pos` and `l`. # TODO: We need to get rid of the negative indexing performed with `pos` and `l`.
# @cython.wraparound(False) # @cython.wraparound(False)
...@@ -80,7 +80,7 @@ def perform( ...@@ -80,7 +80,7 @@ def perform(
int[:] pos not None, int[:] pos not None,
int[:] store_steps not None, int[:] store_steps not None,
tuple tap_array not None, tuple tap_array not None,
tuple tap_array_len not None, const unsigned int[:] tap_array_len not None,
const numpy.npy_bool[:] vector_seqs not None, const numpy.npy_bool[:] vector_seqs not None,
const numpy.npy_bool[:] vector_outs not None, const numpy.npy_bool[:] vector_outs not None,
tuple mit_mot_out_slices not None, tuple mit_mot_out_slices not None,
...@@ -92,7 +92,7 @@ def perform( ...@@ -92,7 +92,7 @@ def perform(
list outer_inputs not None, list outer_inputs not None,
list outer_outputs not None, list outer_outputs not None,
tuple outer_output_dtypes not None, tuple outer_output_dtypes not None,
tuple outer_output_ndims not None, const unsigned int[:] outer_output_ndims not None,
fn, fn,
) -> (time_t, int): ) -> (time_t, int):
""" """
......
...@@ -23,7 +23,7 @@ if not config.cxx: ...@@ -23,7 +23,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform") _logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.320 # must match constant returned in function get_version() version = 0.321 # must match constant returned in function get_version()
need_reload = False need_reload = False
scan_perform: Optional[ModuleType] = None scan_perform: Optional[ModuleType] = None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论