提交 bb999046 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use C version of time in scan_perform.pyx

上级 deca03cf
......@@ -46,15 +46,15 @@
ones (where applicable). All this information is described (more or less)
by describing the arguments of this function)
"""
import sys
from libc.time cimport time, time_t
import cython
import numpy
import numpy
cimport numpy
import copy
import time
import sys
from aesara.scan.utils import InnerFunctionError
......@@ -62,7 +62,7 @@ numpy.import_array()
def get_version():
return 0.319
return 0.320
# TODO: We need to get rid of the negative indexing performed with `pos` and `l`.
# @cython.wraparound(False)
......@@ -94,7 +94,7 @@ def perform(
tuple outer_output_dtypes not None,
tuple outer_output_ndims not None,
fn,
) -> (float, int):
) -> (time_t, int):
"""
Parameters
----------
......@@ -164,9 +164,9 @@ def perform(
"""
# 1. Unzip the number of steps and sequences. If number of steps is
# negative flip sequences around, and make n_steps positive
cdef float t_fn = 0
cdef float t0_fn
cdef float dt_fn
cdef time_t t_fn = 0
cdef time_t t0_fn
cdef time_t dt_fn
cdef unsigned int n_steps = outer_inputs[0].item()
cdef unsigned int n_outs = n_mit_mot + n_mit_sot + n_sit_sot
cdef unsigned int seqs_arg_offset = n_seqs + 1
......@@ -407,14 +407,14 @@ def perform(
old_mitmot_input_data[idx] = var.data
# 5.1 compute outputs
t0_fn = time.time()
t0_fn = time(NULL)
try:
fn()
except Exception as exc:
raise InnerFunctionError(exc, sys.exc_info()[2])
dt_fn = time.time() - t0_fn
dt_fn = time(NULL) - t0_fn
t_fn += dt_fn
if as_while:
pdx = offset + n_shared_outs
......
......@@ -23,7 +23,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.319 # must match constant returned in function get_version()
version = 0.320 # must match constant returned in function get_version()
need_reload = False
scan_perform: Optional[ModuleType] = None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论