提交 74f80840 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move error and profiler handling to the thunk in Scan's Cython implementation

上级 c997333d
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -43,7 +43,6 @@ relies on the following elements to work properly : ...@@ -43,7 +43,6 @@ relies on the following elements to work properly :
""" """
import dataclasses import dataclasses
import logging import logging
import time import time
...@@ -1401,39 +1400,67 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1401,39 +1400,67 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
getattr(out, "ndim", None) for out in node.outputs getattr(out, "ndim", None) for out in node.outputs
) )
from aesara.scan.utils import InnerFunctionError
# TODO: Extract `Capsule` object and use that
# c_thunk = getattr(self.fn.fn.thunks[0], "cthunk", None)
# if len(self.fn.fn.thunks) == 1 and c_thunk:
# thunk_capsule = c_thunk.cthunk
# # We need to perform the following after calling
# # the thunk function:
# # for o in node.outputs:
# # compute_map[o][0] = True
def p(node, inputs, outputs): def p(node, inputs, outputs):
t0_call = time.perf_counter() t0_call = time.perf_counter()
t_fn = scan_perform_ext.perform( try:
self.n_shared_outs, t_fn, n_steps = scan_perform_ext.perform(
self.n_mit_mot_outs, self.n_shared_outs,
self.n_seqs, self.n_mit_mot_outs,
self.n_mit_mot, self.n_seqs,
self.n_mit_sot, self.n_mit_mot,
self.n_sit_sot, self.n_mit_sot,
self.n_nit_sot, self.n_sit_sot,
self.as_while, self.n_nit_sot,
cython_mintaps, self.as_while,
self.tap_array, cython_mintaps,
tap_array_len, self.tap_array,
cython_vector_seqs, tap_array_len,
cython_vector_outs, cython_vector_seqs,
self.mit_mot_out_slices, cython_vector_outs,
cython_mitmots_preallocated, self.mit_mot_out_slices,
cython_inps_is_tensor, cython_mitmots_preallocated,
cython_outs_is_tensor, cython_inps_is_tensor,
inner_input_storage, cython_outs_is_tensor,
inner_output_storage, inner_input_storage,
getattr(self.fn.fn, "need_update_inputs", True), inner_output_storage,
inner_input_needs_update, getattr(self.fn.fn, "need_update_inputs", True),
self.fn, inner_input_needs_update,
cython_destroy_map, cython_destroy_map,
inputs, inputs,
outputs, outputs,
outer_output_dtypes, outer_output_dtypes,
outer_output_ndims, outer_output_ndims,
) self.fn.fn,
)
except InnerFunctionError as exc:
exc_type = type(exc.args[0])
exc_value = exc.args[0]
exc_trace = exc.args[1]
if hasattr(self.fn.fn, "position_of_error") and hasattr(
self.fn.fn, "thunks"
):
raise_with_op(
self.fn.maker.fgraph,
self.fn.fn.nodes[self.fn.fn.position_of_error],
self.fn.fn.thunks[self.fn.fn.position_of_error],
exc_info=(exc_type, exc_value, exc_trace),
)
else:
raise exc_value.with_traceback(exc_trace)
t_call = time.perf_counter() - t0_call t_call = time.perf_counter() - t0_call
...@@ -1442,7 +1469,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1442,7 +1469,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if type(profile) is not bool and profile: if type(profile) is not bool and profile:
profile.vm_call_time += t_fn profile.vm_call_time += t_fn
profile.callcount += 1 profile.callcount += 1
profile.nbsteps += outputs[0] profile.nbsteps += n_steps
profile.call_time += t_call profile.call_time += t_call
if hasattr(self.fn.fn, "update_profile"): if hasattr(self.fn.fn, "update_profile"):
self.fn.fn.update_profile(profile) self.fn.fn.update_profile(profile)
......
...@@ -53,12 +53,13 @@ cimport numpy ...@@ -53,12 +53,13 @@ cimport numpy
import copy import copy
import time import time
import sys
from aesara.link.utils import raise_with_op from aesara.scan.utils import InnerFunctionError
def get_version(): def get_version():
return 0.302 return 0.312
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
...@@ -83,13 +84,13 @@ def perform( ...@@ -83,13 +84,13 @@ def perform(
list inner_output_storage, list inner_output_storage,
bint need_update_inputs, bint need_update_inputs,
tuple inner_input_needs_update, tuple inner_input_needs_update,
fnct,
numpy.ndarray[numpy.int32_t,ndim=1] destroy_map, numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
list outer_inputs, list outer_inputs,
list outer_outputs, list outer_outputs,
tuple outer_output_dtypes, tuple outer_output_dtypes,
tuple outer_output_ndims, tuple outer_output_ndims,
): fn,
) -> (float, int):
""" """
Parameters Parameters
---------- ----------
...@@ -160,6 +161,8 @@ def perform( ...@@ -160,6 +161,8 @@ def perform(
The dtypes for each outer output. The dtypes for each outer output.
outer_output_ndims outer_output_ndims
The number of dimensions for each outer output. The number of dimensions for each outer output.
fn
The inner function thunk.
""" """
# 1. Unzip the number of steps and sequences. If number of steps is # 1. Unzip the number of steps and sequences. If number of steps is
...@@ -258,7 +261,7 @@ def perform( ...@@ -258,7 +261,7 @@ def perform(
outer_outputs[idx][0] = numpy.empty((0,) * outer_output_ndims[idx], dtype=outer_output_dtypes[idx]) outer_outputs[idx][0] = numpy.empty((0,) * outer_output_ndims[idx], dtype=outer_output_dtypes[idx])
else: else:
outer_outputs[idx][0] = None outer_outputs[idx][0] = None
return return 0.0, 0
for idx in range(n_outs + n_nit_sot): for idx in range(n_outs + n_nit_sot):
pos[idx] = -mintaps[idx] % store_steps[idx] pos[idx] = -mintaps[idx] % store_steps[idx]
...@@ -282,8 +285,6 @@ def perform( ...@@ -282,8 +285,6 @@ def perform(
for idx in range(len(other_args)): for idx in range(len(other_args)):
inner_input_storage[<unsigned int>(idx+offset)][0] = other_args[idx] inner_input_storage[<unsigned int>(idx+offset)][0] = other_args[idx]
fn = fnct.fn
i = 0 i = 0
cond = 1 cond = 1
############## THE MAIN LOOP ######################### ############## THE MAIN LOOP #########################
...@@ -398,25 +399,8 @@ def perform( ...@@ -398,25 +399,8 @@ def perform(
try: try:
fn() fn()
except Exception: except Exception as exc:
if hasattr(fn, 'position_of_error'): raise InnerFunctionError(exc, sys.exc_info()[-1])
# this is a new vm-provided function
# the C VM needs this because the exception manipulation
# done by raise_with_op is not implemented in C.
if hasattr(fn, 'thunks'):
# For the CVM
raise_with_op(fnct.maker.fgraph,
fn.nodes[fn.position_of_error],
fn.thunks[fn.position_of_error])
else:
# For the c linker
# We don't have access from python to all the
# temps values So for now, we just don't print
# the extra shapes/strides info
raise_with_op(fnct.maker.fgraph, fn.nodes[fn.position_of_error])
else:
# old-style linkers raise their own exceptions
raise
dt_fn = time.time() - t0_fn dt_fn = time.time() - t0_fn
t_fn += dt_fn t_fn += dt_fn
...@@ -625,4 +609,4 @@ def perform( ...@@ -625,4 +609,4 @@ def perform(
for s in inner_output_storage: for s in inner_output_storage:
s[0] = None s[0] = None
return t_fn return t_fn, i
...@@ -21,7 +21,7 @@ if not config.cxx: ...@@ -21,7 +21,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform") _logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.302 # must match constant returned in function get_version() version = 0.312 # must match constant returned in function get_version()
need_reload = False need_reload = False
......
...@@ -35,6 +35,10 @@ if TYPE_CHECKING: ...@@ -35,6 +35,10 @@ if TYPE_CHECKING:
_logger = logging.getLogger("aesara.scan.utils") _logger = logging.getLogger("aesara.scan.utils")
class InnerFunctionError(Exception):
    """An exception indicating that an error occurred in `Scan`'s inner function.

    The Cython ``perform`` implementation raises this as
    ``InnerFunctionError(exc, traceback)``: ``args[0]`` is the exception
    raised by the inner function thunk and ``args[1]`` is the traceback
    obtained from ``sys.exc_info()`` at the point it was caught.

    The handler in `Scan` reads both entries back out of ``args`` so that it
    can re-raise the original exception with its original traceback.
    """
def safe_new( def safe_new(
x: Variable, tag: str = "", dtype: Optional[Union[str, np.dtype]] = None x: Variable, tag: str = "", dtype: Optional[Union[str, np.dtype]] = None
) -> Variable: ) -> Variable:
...@@ -126,8 +130,8 @@ class until: ...@@ -126,8 +130,8 @@ class until:
class ScanProfileStats(ProfileStats): class ScanProfileStats(ProfileStats):
show_sum = False show_sum = False
callcount = 0.0 callcount = 0
nbsteps = 0.0 nbsteps = 0
call_time = 0.0 call_time = 0.0
def __init__(self, atexit_print=True, name=None, **kwargs): def __init__(self, atexit_print=True, name=None, **kwargs):
......
...@@ -4635,7 +4635,10 @@ class TestScan: ...@@ -4635,7 +4635,10 @@ class TestScan:
@pytest.mark.skipif( @pytest.mark.skipif(
not config.cxx, reason="G++ not available, so we need to skip this test." not config.cxx, reason="G++ not available, so we need to skip this test."
) )
def test_cvm_exception_handling(): @pytest.mark.parametrize(
"mode", [Mode(linker="c|py", optimizer=None), Mode(linker="cvm", optimizer=None)]
)
def test_cvm_exception_handling(mode):
class MyOp(Op): class MyOp(Op):
def make_node(self, input): def make_node(self, input):
return Apply(self, [input], [vector()]) return Apply(self, [input], [vector()])
...@@ -4643,13 +4646,18 @@ def test_cvm_exception_handling(): ...@@ -4643,13 +4646,18 @@ def test_cvm_exception_handling():
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
raise Exception("blah") raise Exception("blah")
# def c_code(self, node, name, inputs, outputs, sub):
# fail = sub["fail"]
# return f"""
# PyErr_SetString(PyExc_Exception, "blah");
# {fail};
# """
myop = MyOp() myop = MyOp()
def scan_fn(): def scan_fn():
return myop(at.as_tensor(1)) return myop(at.as_tensor(1))
mode = Mode(optimizer=None, linker="cvm")
res, _ = scan(scan_fn, n_steps=4, mode=mode) res, _ = scan(scan_fn, n_steps=4, mode=mode)
res_fn = function([], res, mode=mode) res_fn = function([], res, mode=mode)
...@@ -5198,3 +5206,58 @@ def test_inner_get_vector_length(): ...@@ -5198,3 +5206,58 @@ def test_inner_get_vector_length():
res_fn = function([], res.shape) res_fn = function([], res.shape)
assert np.array_equal(res_fn(), (10, 3)) assert np.array_equal(res_fn(), (10, 3))
@config.change_flags(mode=Mode("cvm", None))
def test_profile_info():
    """Check the profile object attached to a `Scan` and the stats it records.

    Three ways of passing ``profile`` to `scan` are exercised (``True``, a
    name string, and an existing profile object), and the counters on the
    shared profile are compared before and after one call of the compiled
    outer function.
    """
    from aesara.scan.utils import ScanProfileStats

    def make_scan(profile_arg):
        # Build the same ten-step `Scan` each time; only `profile` varies.
        out, _ = scan(
            fn=lambda u: u + 1, sequences=[at.arange(10)], profile=profile_arg
        )
        scan_op = out.owner.op
        assert isinstance(scan_op, Scan)
        return out, scan_op.fn

    # `profile=True` yields a `ScanProfileStats` with the name asserted below
    _, inner_fn = make_scan(True)
    assert isinstance(inner_fn.profile, ScanProfileStats)
    assert inner_fn.profile.name == "scan_fn"

    # Set the `ScanProfileStats` name
    _, inner_fn = make_scan("profile_name")
    assert isinstance(inner_fn.profile, ScanProfileStats)
    assert inner_fn.profile.name == "profile_name"

    # Use an existing profile object
    shared_profile = inner_fn.profile
    z, inner_fn = make_scan(shared_profile)
    assert inner_fn.profile is shared_profile

    # Nothing has been executed yet, so every counter is at its initial value
    assert not shared_profile.apply_time
    assert shared_profile.callcount == 0
    assert shared_profile.nbsteps == 0
    assert shared_profile.call_time == 0.0
    assert inner_fn.fn.call_times == [0.0]
    assert inner_fn.fn.call_counts == [0]

    z_fn = function([], z)
    z_fn()

    # assert profile.vm_call_time > 0
    # One outer call over a length-10 sequence: one call, ten steps
    assert shared_profile.callcount == 1
    assert shared_profile.nbsteps == 10
    assert shared_profile.call_time > 0

    # Confirm that `VM.update_profile` was called
    assert shared_profile.apply_time
    assert inner_fn.fn.call_times == [0.0]
    assert inner_fn.fn.call_counts == [0]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论