提交 5043ac8e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Update CheckAndRaise to handle ScalarType inputs

上级 0b9bba1a
......@@ -10,6 +10,7 @@ from aesara.graph.basic import Apply, Variable
from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType
from aesara.link.c.type import Generic
from aesara.scalar.basic import ScalarType
from aesara.tensor.type import DenseTensorType
......@@ -78,7 +79,10 @@ class CheckAndRaise(COp):
if not isinstance(value, Variable):
value = at.as_tensor_variable(value)
conds = [at.as_tensor_variable(c) for c in conds]
conds = [
at.as_tensor_variable(c) if not isinstance(c, Variable) else c
for c in conds
]
assert all(c.type.ndim == 0 for c in conds)
......@@ -102,7 +106,7 @@ class CheckAndRaise(COp):
return [[1]] + [[0]] * (len(node.inputs) - 1)
def c_code(self, node, name, inames, onames, props):
if not isinstance(node.inputs[0].type, DenseTensorType):
if not isinstance(node.inputs[0].type, (DenseTensorType, ScalarType)):
raise NotImplementedError(
f"CheckAndRaise c_code not implemented for input type {node.inputs[0].type}"
)
......@@ -112,25 +116,47 @@ class CheckAndRaise(COp):
fail_code = props["fail"]
param_struct_name = props["params"]
msg = self.msg.replace('"', '\\"').replace("\n", "\\n")
for idx, cond_name in enumerate(cond_names):
check.append(
f"""
if(PyObject_IsTrue((PyObject *){cond_name}) == 0) {{
PyObject * exc_type = {param_struct_name}->exc_type;
Py_INCREF(exc_type);
PyErr_SetString(exc_type, "{msg}");
Py_XDECREF(exc_type);
{indent(fail_code, " " * 4)}
}}
"""
)
if isinstance(node.inputs[0].type, DenseTensorType):
check.append(
f"""
if(PyObject_IsTrue((PyObject *){cond_name}) == 0) {{
PyObject * exc_type = {param_struct_name}->exc_type;
Py_INCREF(exc_type);
PyErr_SetString(exc_type, "{msg}");
Py_XDECREF(exc_type);
{indent(fail_code, " " * 4)}
}}
"""
)
else:
check.append(
f"""
if({cond_name} == 0) {{
PyObject * exc_type = {param_struct_name}->exc_type;
Py_INCREF(exc_type);
PyErr_SetString(exc_type, "{msg}");
Py_XDECREF(exc_type);
{indent(fail_code, " " * 4)}
}}
"""
)
check = "\n".join(check)
res = f"""
{check}
Py_XDECREF({out_name});
{out_name} = {value_name};
Py_INCREF({value_name});
"""
if isinstance(node.inputs[0].type, DenseTensorType):
res = f"""
{check}
Py_XDECREF({out_name});
{out_name} = {value_name};
Py_INCREF({value_name});
"""
else:
res = f"""
{check}
{out_name} = {value_name};
"""
return res
def c_code_cache_version(self):
......
......@@ -729,6 +729,12 @@ class _scalar_py_operators:
dtype = property(lambda self: self.type.dtype)
"""The dtype of this scalar."""
@property
def shape(self):
from aesara.tensor.basic import as_tensor_variable
return as_tensor_variable([], ndim=1, dtype=np.int64)
# UNARY
def __abs__(self):
return abs(self)
......
......@@ -488,3 +488,17 @@ def test_mean(mode):
z = mean()
z_fn = aesara.function([], z, mode=mode)
assert z_fn() == 0
def test_shape():
a = float32("a")
assert isinstance(a.type, ScalarType)
assert a.shape.type.ndim == 1
assert a.shape.type.shape == (0,)
assert a.shape.type.dtype == "int64"
b = constant(2, name="b")
assert isinstance(b.type, ScalarType)
assert b.shape.type.ndim == 1
assert b.shape.type.shape == (0,)
assert b.shape.type.dtype == "int64"
......@@ -7,6 +7,7 @@ import aesara.tensor as at
from aesara.compile.mode import OPT_FAST_RUN, Mode
from aesara.graph.basic import Constant, equal_computations
from aesara.raise_op import Assert, CheckAndRaise, assert_op
from aesara.scalar.basic import ScalarType, float64
from aesara.sparse import as_sparse_variable
from tests import unittest_tools as utt
......@@ -94,6 +95,50 @@ def test_CheckAndRaise_basic_c(linker):
assert np.array_equal(y_fn(x_val), [x_val])
@pytest.mark.parametrize(
"linker",
[
pytest.param(
"cvm",
marks=pytest.mark.skipif(
not aesara.config.cxx,
reason="G++ not available, so we need to skip this test.",
),
),
"py",
],
)
def test_perform_CheckAndRaise_scalar(linker):
exc_msg = "this is the exception"
check_and_raise = CheckAndRaise(CustomException, exc_msg)
val = float64("val")
conds = (val > 0, val > 3)
y = check_and_raise(val, *conds)
assert all(isinstance(i.type, ScalarType) for i in y.owner.inputs)
assert isinstance(y.type, ScalarType)
mode = Mode(linker=linker)
y_fn = aesara.function([val], y, mode=mode)
with pytest.raises(CustomException, match=exc_msg):
y_fn(0.0)
assert y_fn(4.0) == 4.0
if linker == "cvm":
assert isinstance(
y_fn.maker.fgraph.outputs[0].owner.inputs[0].owner.op, CheckAndRaise
)
assert hasattr(y_fn.vm.thunks[-2], "cthunk")
(y_grad,) = aesara.grad(y, [val])
y_fn = aesara.function([val], y_grad, mode=Mode(linker, OPT_FAST_RUN))
assert np.array_equal(y_fn(4.0), 1.0)
class TestCheckAndRaiseInferShape(utt.InferShapeTester):
def setup_method(self):
super().setup_method()
......@@ -117,6 +162,16 @@ class TestCheckAndRaiseInferShape(utt.InferShapeTester):
[admat, adscal, bdscal], [out], [admat_val, adscal_val, bdscal_val], Assert
)
def test_infer_shape_scalar(self):
adscal = float64("adscal")
bdscal = float64("bdscal")
adscal_val = np.random.random()
bdscal_val = np.random.random() + 1
out = assert_op(adscal, bdscal)
self._compile_and_check(
[adscal, bdscal], [out], [adscal_val, bdscal_val], Assert
)
def test_CheckAndRaise_sparse_variable():
check_and_raise = CheckAndRaise(ValueError, "sparse_check")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论