提交 5043ac8e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Update CheckAndRaise to handle ScalarType inputs

上级 0b9bba1a
...@@ -10,6 +10,7 @@ from aesara.graph.basic import Apply, Variable ...@@ -10,6 +10,7 @@ from aesara.graph.basic import Apply, Variable
from aesara.link.c.op import COp from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType from aesara.link.c.params_type import ParamsType
from aesara.link.c.type import Generic from aesara.link.c.type import Generic
from aesara.scalar.basic import ScalarType
from aesara.tensor.type import DenseTensorType from aesara.tensor.type import DenseTensorType
...@@ -78,7 +79,10 @@ class CheckAndRaise(COp): ...@@ -78,7 +79,10 @@ class CheckAndRaise(COp):
if not isinstance(value, Variable): if not isinstance(value, Variable):
value = at.as_tensor_variable(value) value = at.as_tensor_variable(value)
conds = [at.as_tensor_variable(c) for c in conds] conds = [
at.as_tensor_variable(c) if not isinstance(c, Variable) else c
for c in conds
]
assert all(c.type.ndim == 0 for c in conds) assert all(c.type.ndim == 0 for c in conds)
...@@ -102,7 +106,7 @@ class CheckAndRaise(COp): ...@@ -102,7 +106,7 @@ class CheckAndRaise(COp):
return [[1]] + [[0]] * (len(node.inputs) - 1) return [[1]] + [[0]] * (len(node.inputs) - 1)
def c_code(self, node, name, inames, onames, props): def c_code(self, node, name, inames, onames, props):
if not isinstance(node.inputs[0].type, DenseTensorType): if not isinstance(node.inputs[0].type, (DenseTensorType, ScalarType)):
raise NotImplementedError( raise NotImplementedError(
f"CheckAndRaise c_code not implemented for input type {node.inputs[0].type}" f"CheckAndRaise c_code not implemented for input type {node.inputs[0].type}"
) )
...@@ -112,25 +116,47 @@ class CheckAndRaise(COp): ...@@ -112,25 +116,47 @@ class CheckAndRaise(COp):
fail_code = props["fail"] fail_code = props["fail"]
param_struct_name = props["params"] param_struct_name = props["params"]
msg = self.msg.replace('"', '\\"').replace("\n", "\\n") msg = self.msg.replace('"', '\\"').replace("\n", "\\n")
for idx, cond_name in enumerate(cond_names): for idx, cond_name in enumerate(cond_names):
check.append( if isinstance(node.inputs[0].type, DenseTensorType):
f""" check.append(
if(PyObject_IsTrue((PyObject *){cond_name}) == 0) {{ f"""
PyObject * exc_type = {param_struct_name}->exc_type; if(PyObject_IsTrue((PyObject *){cond_name}) == 0) {{
Py_INCREF(exc_type); PyObject * exc_type = {param_struct_name}->exc_type;
PyErr_SetString(exc_type, "{msg}"); Py_INCREF(exc_type);
Py_XDECREF(exc_type); PyErr_SetString(exc_type, "{msg}");
{indent(fail_code, " " * 4)} Py_XDECREF(exc_type);
}} {indent(fail_code, " " * 4)}
""" }}
) """
)
else:
check.append(
f"""
if({cond_name} == 0) {{
PyObject * exc_type = {param_struct_name}->exc_type;
Py_INCREF(exc_type);
PyErr_SetString(exc_type, "{msg}");
Py_XDECREF(exc_type);
{indent(fail_code, " " * 4)}
}}
"""
)
check = "\n".join(check) check = "\n".join(check)
res = f"""
{check} if isinstance(node.inputs[0].type, DenseTensorType):
Py_XDECREF({out_name}); res = f"""
{out_name} = {value_name}; {check}
Py_INCREF({value_name}); Py_XDECREF({out_name});
""" {out_name} = {value_name};
Py_INCREF({value_name});
"""
else:
res = f"""
{check}
{out_name} = {value_name};
"""
return res return res
def c_code_cache_version(self): def c_code_cache_version(self):
......
...@@ -729,6 +729,12 @@ class _scalar_py_operators: ...@@ -729,6 +729,12 @@ class _scalar_py_operators:
dtype = property(lambda self: self.type.dtype) dtype = property(lambda self: self.type.dtype)
"""The dtype of this scalar.""" """The dtype of this scalar."""
@property
def shape(self):
from aesara.tensor.basic import as_tensor_variable
return as_tensor_variable([], ndim=1, dtype=np.int64)
# UNARY # UNARY
def __abs__(self): def __abs__(self):
return abs(self) return abs(self)
......
...@@ -488,3 +488,17 @@ def test_mean(mode): ...@@ -488,3 +488,17 @@ def test_mean(mode):
z = mean() z = mean()
z_fn = aesara.function([], z, mode=mode) z_fn = aesara.function([], z, mode=mode)
assert z_fn() == 0 assert z_fn() == 0
def test_shape():
a = float32("a")
assert isinstance(a.type, ScalarType)
assert a.shape.type.ndim == 1
assert a.shape.type.shape == (0,)
assert a.shape.type.dtype == "int64"
b = constant(2, name="b")
assert isinstance(b.type, ScalarType)
assert b.shape.type.ndim == 1
assert b.shape.type.shape == (0,)
assert b.shape.type.dtype == "int64"
...@@ -7,6 +7,7 @@ import aesara.tensor as at ...@@ -7,6 +7,7 @@ import aesara.tensor as at
from aesara.compile.mode import OPT_FAST_RUN, Mode from aesara.compile.mode import OPT_FAST_RUN, Mode
from aesara.graph.basic import Constant, equal_computations from aesara.graph.basic import Constant, equal_computations
from aesara.raise_op import Assert, CheckAndRaise, assert_op from aesara.raise_op import Assert, CheckAndRaise, assert_op
from aesara.scalar.basic import ScalarType, float64
from aesara.sparse import as_sparse_variable from aesara.sparse import as_sparse_variable
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -94,6 +95,50 @@ def test_CheckAndRaise_basic_c(linker): ...@@ -94,6 +95,50 @@ def test_CheckAndRaise_basic_c(linker):
assert np.array_equal(y_fn(x_val), [x_val]) assert np.array_equal(y_fn(x_val), [x_val])
@pytest.mark.parametrize(
"linker",
[
pytest.param(
"cvm",
marks=pytest.mark.skipif(
not aesara.config.cxx,
reason="G++ not available, so we need to skip this test.",
),
),
"py",
],
)
def test_perform_CheckAndRaise_scalar(linker):
exc_msg = "this is the exception"
check_and_raise = CheckAndRaise(CustomException, exc_msg)
val = float64("val")
conds = (val > 0, val > 3)
y = check_and_raise(val, *conds)
assert all(isinstance(i.type, ScalarType) for i in y.owner.inputs)
assert isinstance(y.type, ScalarType)
mode = Mode(linker=linker)
y_fn = aesara.function([val], y, mode=mode)
with pytest.raises(CustomException, match=exc_msg):
y_fn(0.0)
assert y_fn(4.0) == 4.0
if linker == "cvm":
assert isinstance(
y_fn.maker.fgraph.outputs[0].owner.inputs[0].owner.op, CheckAndRaise
)
assert hasattr(y_fn.vm.thunks[-2], "cthunk")
(y_grad,) = aesara.grad(y, [val])
y_fn = aesara.function([val], y_grad, mode=Mode(linker, OPT_FAST_RUN))
assert np.array_equal(y_fn(4.0), 1.0)
class TestCheckAndRaiseInferShape(utt.InferShapeTester): class TestCheckAndRaiseInferShape(utt.InferShapeTester):
def setup_method(self): def setup_method(self):
super().setup_method() super().setup_method()
...@@ -117,6 +162,16 @@ class TestCheckAndRaiseInferShape(utt.InferShapeTester): ...@@ -117,6 +162,16 @@ class TestCheckAndRaiseInferShape(utt.InferShapeTester):
[admat, adscal, bdscal], [out], [admat_val, adscal_val, bdscal_val], Assert [admat, adscal, bdscal], [out], [admat_val, adscal_val, bdscal_val], Assert
) )
def test_infer_shape_scalar(self):
adscal = float64("adscal")
bdscal = float64("bdscal")
adscal_val = np.random.random()
bdscal_val = np.random.random() + 1
out = assert_op(adscal, bdscal)
self._compile_and_check(
[adscal, bdscal], [out], [adscal_val, bdscal_val], Assert
)
def test_CheckAndRaise_sparse_variable(): def test_CheckAndRaise_sparse_variable():
check_and_raise = CheckAndRaise(ValueError, "sparse_check") check_and_raise = CheckAndRaise(ValueError, "sparse_check")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论