提交 104dc037 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix the type of t_fn in scan_perform.pyx

上级 e2e23668
...@@ -59,7 +59,7 @@ from aesara.scan.utils import InnerFunctionError ...@@ -59,7 +59,7 @@ from aesara.scan.utils import InnerFunctionError
def get_version(): def get_version():
return 0.314 return 0.315
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
...@@ -157,7 +157,7 @@ def perform( ...@@ -157,7 +157,7 @@ def perform(
""" """
# 1. Unzip the number of steps and sequences. If number of steps is # 1. Unzip the number of steps and sequences. If number of steps is
# negative flip sequences around, and make n_steps positive # negative flip sequences around, and make n_steps positive
cdef unsigned int t_fn = 0 cdef float t_fn = 0
cdef unsigned int n_steps = outer_inputs[0].item() cdef unsigned int n_steps = outer_inputs[0].item()
cdef unsigned int n_outs = n_mit_mot + n_mit_sot + n_sit_sot cdef unsigned int n_outs = n_mit_mot + n_mit_sot + n_sit_sot
cdef unsigned int seqs_arg_offset = n_seqs + 1 cdef unsigned int seqs_arg_offset = n_seqs + 1
......
...@@ -23,7 +23,7 @@ if not config.cxx: ...@@ -23,7 +23,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform") _logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.314 # must match constant returned in function get_version() version = 0.315 # must match constant returned in function get_version()
need_reload = False need_reload = False
scan_perform: Optional[ModuleType] = None scan_perform: Optional[ModuleType] = None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论