提交 74f80840 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move error and profiler handling to the thunk in Scan's Cython implementation

上级 c997333d
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -43,7 +43,6 @@ relies on the following elements to work properly :
"""
import dataclasses
import logging
import time
......@@ -1401,39 +1400,67 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
getattr(out, "ndim", None) for out in node.outputs
)
from aesara.scan.utils import InnerFunctionError
# TODO: Extract `Capsule` object and use that
# c_thunk = getattr(self.fn.fn.thunks[0], "cthunk", None)
# if len(self.fn.fn.thunks) == 1 and c_thunk:
# thunk_capsule = c_thunk.cthunk
# # We need to perform the following after calling
# # the thunk function:
# # for o in node.outputs:
# # compute_map[o][0] = True
def p(node, inputs, outputs):
t0_call = time.perf_counter()
t_fn = scan_perform_ext.perform(
self.n_shared_outs,
self.n_mit_mot_outs,
self.n_seqs,
self.n_mit_mot,
self.n_mit_sot,
self.n_sit_sot,
self.n_nit_sot,
self.as_while,
cython_mintaps,
self.tap_array,
tap_array_len,
cython_vector_seqs,
cython_vector_outs,
self.mit_mot_out_slices,
cython_mitmots_preallocated,
cython_inps_is_tensor,
cython_outs_is_tensor,
inner_input_storage,
inner_output_storage,
getattr(self.fn.fn, "need_update_inputs", True),
inner_input_needs_update,
self.fn,
cython_destroy_map,
inputs,
outputs,
outer_output_dtypes,
outer_output_ndims,
)
try:
t_fn, n_steps = scan_perform_ext.perform(
self.n_shared_outs,
self.n_mit_mot_outs,
self.n_seqs,
self.n_mit_mot,
self.n_mit_sot,
self.n_sit_sot,
self.n_nit_sot,
self.as_while,
cython_mintaps,
self.tap_array,
tap_array_len,
cython_vector_seqs,
cython_vector_outs,
self.mit_mot_out_slices,
cython_mitmots_preallocated,
cython_inps_is_tensor,
cython_outs_is_tensor,
inner_input_storage,
inner_output_storage,
getattr(self.fn.fn, "need_update_inputs", True),
inner_input_needs_update,
cython_destroy_map,
inputs,
outputs,
outer_output_dtypes,
outer_output_ndims,
self.fn.fn,
)
except InnerFunctionError as exc:
exc_type = type(exc.args[0])
exc_value = exc.args[0]
exc_trace = exc.args[1]
if hasattr(self.fn.fn, "position_of_error") and hasattr(
self.fn.fn, "thunks"
):
raise_with_op(
self.fn.maker.fgraph,
self.fn.fn.nodes[self.fn.fn.position_of_error],
self.fn.fn.thunks[self.fn.fn.position_of_error],
exc_info=(exc_type, exc_value, exc_trace),
)
else:
raise exc_value.with_traceback(exc_trace)
t_call = time.perf_counter() - t0_call
......@@ -1442,7 +1469,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if type(profile) is not bool and profile:
profile.vm_call_time += t_fn
profile.callcount += 1
profile.nbsteps += outputs[0]
profile.nbsteps += n_steps
profile.call_time += t_call
if hasattr(self.fn.fn, "update_profile"):
self.fn.fn.update_profile(profile)
......
......@@ -53,12 +53,13 @@ cimport numpy
import copy
import time
import sys
from aesara.link.utils import raise_with_op
from aesara.scan.utils import InnerFunctionError
def get_version():
return 0.302
return 0.312
@cython.boundscheck(False)
def perform(
......@@ -83,13 +84,13 @@ def perform(
list inner_output_storage,
bint need_update_inputs,
tuple inner_input_needs_update,
fnct,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
list outer_inputs,
list outer_outputs,
tuple outer_output_dtypes,
tuple outer_output_ndims,
):
fn,
) -> (float, int):
"""
Parameters
----------
......@@ -160,6 +161,8 @@ def perform(
The dtypes for each outer output.
outer_output_ndims
The number of dimensions for each outer output.
fn
The inner function thunk.
"""
# 1. Unzip the number of steps and sequences. If number of steps is
......@@ -258,7 +261,7 @@ def perform(
outer_outputs[idx][0] = numpy.empty((0,) * outer_output_ndims[idx], dtype=outer_output_dtypes[idx])
else:
outer_outputs[idx][0] = None
return
return 0.0, 0
for idx in range(n_outs + n_nit_sot):
pos[idx] = -mintaps[idx] % store_steps[idx]
......@@ -282,8 +285,6 @@ def perform(
for idx in range(len(other_args)):
inner_input_storage[<unsigned int>(idx+offset)][0] = other_args[idx]
fn = fnct.fn
i = 0
cond = 1
############## THE MAIN LOOP #########################
......@@ -398,25 +399,8 @@ def perform(
try:
fn()
except Exception:
if hasattr(fn, 'position_of_error'):
# this is a new vm-provided function
# the C VM needs this because the exception manipulation
# done by raise_with_op is not implemented in C.
if hasattr(fn, 'thunks'):
# For the CVM
raise_with_op(fnct.maker.fgraph,
fn.nodes[fn.position_of_error],
fn.thunks[fn.position_of_error])
else:
# For the c linker
# We don't have access from python to all the
# temps values So for now, we just don't print
# the extra shapes/strides info
raise_with_op(fnct.maker.fgraph, fn.nodes[fn.position_of_error])
else:
# old-style linkers raise their own exceptions
raise
except Exception as exc:
raise InnerFunctionError(exc, sys.exc_info()[-1])
dt_fn = time.time() - t0_fn
t_fn += dt_fn
......@@ -625,4 +609,4 @@ def perform(
for s in inner_output_storage:
s[0] = None
return t_fn
return t_fn, i
......@@ -21,7 +21,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.302 # must match constant returned in function get_version()
version = 0.312 # must match constant returned in function get_version()
need_reload = False
......
......@@ -35,6 +35,10 @@ if TYPE_CHECKING:
_logger = logging.getLogger("aesara.scan.utils")
class InnerFunctionError(Exception):
"""An exception indicating that an error occurred in `Scan`'s inner function."""
def safe_new(
x: Variable, tag: str = "", dtype: Optional[Union[str, np.dtype]] = None
) -> Variable:
......@@ -126,8 +130,8 @@ class until:
class ScanProfileStats(ProfileStats):
show_sum = False
callcount = 0.0
nbsteps = 0.0
callcount = 0
nbsteps = 0
call_time = 0.0
def __init__(self, atexit_print=True, name=None, **kwargs):
......
......@@ -4635,7 +4635,10 @@ class TestScan:
@pytest.mark.skipif(
not config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_cvm_exception_handling():
@pytest.mark.parametrize(
"mode", [Mode(linker="c|py", optimizer=None), Mode(linker="cvm", optimizer=None)]
)
def test_cvm_exception_handling(mode):
class MyOp(Op):
def make_node(self, input):
return Apply(self, [input], [vector()])
......@@ -4643,13 +4646,18 @@ def test_cvm_exception_handling():
def perform(self, node, inputs, outputs):
raise Exception("blah")
# def c_code(self, node, name, inputs, outputs, sub):
# fail = sub["fail"]
# return f"""
# PyErr_SetString(PyExc_Exception, "blah");
# {fail};
# """
myop = MyOp()
def scan_fn():
return myop(at.as_tensor(1))
mode = Mode(optimizer=None, linker="cvm")
res, _ = scan(scan_fn, n_steps=4, mode=mode)
res_fn = function([], res, mode=mode)
......@@ -5198,3 +5206,58 @@ def test_inner_get_vector_length():
res_fn = function([], res.shape)
assert np.array_equal(res_fn(), (10, 3))
@config.change_flags(mode=Mode("cvm", None))
def test_profile_info():
from aesara.scan.utils import ScanProfileStats
z, updates = scan(fn=lambda u: u + 1, sequences=[at.arange(10)], profile=True)
assert isinstance(z.owner.op, Scan)
fn = z.owner.op.fn
assert isinstance(fn.profile, ScanProfileStats)
assert fn.profile.name == "scan_fn"
# Set the `ScanProfileStats` name
z, updates = scan(
fn=lambda u: u + 1, sequences=[at.arange(10)], profile="profile_name"
)
assert isinstance(z.owner.op, Scan)
fn = z.owner.op.fn
assert isinstance(fn.profile, ScanProfileStats)
assert fn.profile.name == "profile_name"
# Use an existing profile object
profile = fn.profile
z, updates = scan(fn=lambda u: u + 1, sequences=[at.arange(10)], profile=profile)
assert isinstance(z.owner.op, Scan)
fn = z.owner.op.fn
assert fn.profile is profile
assert not profile.apply_time
assert profile.callcount == 0
assert profile.nbsteps == 0
assert profile.call_time == 0.0
assert fn.fn.call_times == [0.0]
assert fn.fn.call_counts == [0]
z_fn = function([], z)
_ = z_fn()
# assert profile.vm_call_time > 0
assert profile.callcount == 1
assert profile.nbsteps == 10
assert profile.call_time > 0
# Confirm that `VM.update_profile` was called
assert profile.apply_time
assert fn.fn.call_times == [0.0]
assert fn.fn.call_counts == [0]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论