提交 142031f4 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a CheckAndRaise COp

上级 8a52fe6e
"""Symbolic Op for raising an exception."""
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from textwrap import indent
from typing import Tuple
import numpy as np
__authors__ = "James Bergstra " "PyMC Dev Team " "Aesara Developers"
__copyright__ = "(c) 2011, Universite de Montreal"
__license__ = "3-clause BSD License"
from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import COp, Op
from aesara.graph.params_type import ParamsType
from aesara.graph.type import Generic
__docformat__ = "restructuredtext en"
class ExceptionType(Generic):
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
exception_type = ExceptionType()
class Raise(Op):
......@@ -32,3 +44,151 @@ class Raise(Op):
def perform(self, node, inputs, out_storage):
raise self.exc(self.msg)
class CheckAndRaise(COp):
"""An `Op` that checks conditions and raises an exception if they fail.
This `Op` returns its "value" argument if its condition arguments are all
``True``; otherwise, it raises a user-specified exception.
"""
_f16_ok = True
__props__ = ("msg", "exc_type")
view_map = {0: [0]}
check_input = False
params_type = ParamsType(exc_type=exception_type)
def __init__(self, exc_type, msg=""):
if not issubclass(exc_type, Exception):
raise ValueError("`exc_type` must be an Exception subclass")
self.exc_type = exc_type
self.msg = msg
def __str__(self):
return f"CheckAndRaise{{{self.exc_type}({self.msg})}}"
def __eq__(self, other):
if type(self) != type(other):
return False
if self.msg == other.msg and self.exc_type == other.exc_type:
return True
return False
def __hash__(self):
return hash((self.msg, self.exc_type))
def make_node(self, value: Variable, *conds: Tuple[Variable]):
"""
Parameters
==========
value
The value to return if `conds` all evaluate to ``True``; otherwise,
`self.exc_type` is raised.
conds
The conditions to evaluate.
"""
import aesara.tensor as at
if not isinstance(value, Variable):
value = at.as_tensor_variable(value)
conds = [at.as_tensor_variable(c) for c in conds]
assert all(c.type.ndim == 0 for c in conds)
return Apply(
self,
[value] + conds,
[value.type()],
)
def perform(self, node, inputs, outputs, params):
(out,) = outputs
val, *conds = inputs
out[0] = val
if not np.all(conds):
raise self.exc_type(self.msg)
def grad(self, input, output_gradients):
return output_gradients + [DisconnectedType()()] * (len(input) - 1)
def connection_pattern(self, node):
return [[1]] + [[0]] * (len(node.inputs) - 1)
def c_code(self, node, name, inames, onames, props):
value_name, *cond_names = inames
out_name = onames[0]
check = []
fail_code = props["fail"]
param_struct_name = props["params"]
msg = self.msg.replace('"', '\\"').replace("\n", "\\n")
for idx, cond_name in enumerate(cond_names):
check.append(
f"""
if(PyObject_IsTrue((PyObject *){cond_name}) == 0) {{
PyObject * exc_type = {param_struct_name}->exc_type;
Py_INCREF(exc_type);
PyErr_SetString(exc_type, "{msg}");
Py_XDECREF(exc_type);
{indent(fail_code, " " * 4)}
}}
"""
)
check = "\n".join(check)
res = f"""
{check}
Py_XDECREF({out_name});
{out_name} = {value_name};
Py_INCREF({value_name});
"""
return res
def c_code_cache_version(self):
return (1, 0)
def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]]
class Assert(CheckAndRaise):
"""Implements assertion in a computational graph.
Returns the first parameter if the condition is ``True``; otherwise,
triggers `AssertionError`.
Notes
-----
This `Op` is a debugging feature. It can be removed from the graph
because of optimizations, and can hide some possible optimizations to
the optimizer. Specifically, removing happens if it can be determined
that condition will always be true. Also, the output of the Op must be
used in the function computing the graph, but it doesn't have to be
returned.
Examples
--------
>>> import aesara
>>> import aesara.tensor as aet
>>> from aesara.raise_op import Assert
>>> x = aet.vector("x")
>>> assert_op = Assert("This assert failed")
>>> func = aesara.function([x], assert_op(x, x.size < 2))
"""
def __init__(self, msg="Aesara Assert failed!"):
super().__init__(AssertionError, msg)
def __str__(self):
return f"Assert{{msg={self.msg}}}"
assert_op = Assert()
import numpy as np
import pytest
import aesara
import aesara.tensor as at
from aesara.compile.mode import OPT_FAST_RUN, Mode
from aesara.graph.basic import Constant, equal_computations
from aesara.raise_op import CheckAndRaise, assert_op
class CustomException(ValueError):
"""A custom user-created exception to throw."""
def test_CheckAndRaise_str():
exc_msg = "this is the exception"
check_and_raise = CheckAndRaise(CustomException, exc_msg)
assert (
str(check_and_raise)
== f"CheckAndRaise{{{CustomException}(this is the exception)}}"
)
def test_CheckAndRaise_pickle():
import pickle
exc_msg = "this is the exception"
check_and_raise = CheckAndRaise(CustomException, exc_msg)
y = check_and_raise(at.as_tensor(1), at.as_tensor(0))
y_str = pickle.dumps(y)
new_y = pickle.loads(y_str)
assert y.owner.op == new_y.owner.op
assert y.owner.op.msg == new_y.owner.op.msg
assert y.owner.op.exc_type == new_y.owner.op.exc_type
def test_CheckAndRaise_equal():
x, y = at.vectors("xy")
g1 = assert_op(x, (x > y).all())
g2 = assert_op(x, (x > y).all())
assert equal_computations([g1], [g2])
def test_CheckAndRaise_validation():
with pytest.raises(ValueError):
CheckAndRaise(str)
g1 = assert_op(np.array(1.0))
assert isinstance(g1.owner.inputs[0], Constant)
@pytest.mark.parametrize(
"linker",
[
pytest.param(
"cvm",
marks=pytest.mark.skipif(
not aesara.config.cxx,
reason="G++ not available, so we need to skip this test.",
),
),
"py",
],
)
def test_CheckAndRaise_basic_c(linker):
exc_msg = "this is the exception"
check_and_raise = CheckAndRaise(CustomException, exc_msg)
x = at.scalar()
conds = at.scalar()
y = check_and_raise(at.as_tensor(1), conds)
y_fn = aesara.function([conds], y, mode=Mode(linker))
with pytest.raises(CustomException, match=exc_msg):
y_fn(0)
y = check_and_raise(at.as_tensor(1), conds)
y_fn = aesara.function([conds], y.shape, mode=Mode(linker, OPT_FAST_RUN))
assert np.array_equal(y_fn(0), [])
y = check_and_raise(x, at.as_tensor(0))
y_grad = aesara.grad(y, [x])
y_fn = aesara.function([x], y_grad, mode=Mode(linker, OPT_FAST_RUN))
assert np.array_equal(y_fn(1.0), [1.0])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论