提交 d47ce12b authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Do not use c_code for non-dense inputs in CheckAndRaise

上级 5e7b095c
......@@ -10,6 +10,7 @@ from aesara.graph.basic import Apply, Variable
from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType
from aesara.link.c.type import Generic
from aesara.tensor.type import DenseTensorType
class ExceptionType(Generic):
......@@ -101,6 +102,10 @@ class CheckAndRaise(COp):
return [[1]] + [[0]] * (len(node.inputs) - 1)
def c_code(self, node, name, inames, onames, props):
if not isinstance(node.inputs[0].type, DenseTensorType):
raise NotImplementedError(
f"CheckAndRaise c_code not implemented for input type {node.inputs[0].type}"
)
value_name, *cond_names = inames
out_name = onames[0]
check = []
......@@ -129,7 +134,7 @@ class CheckAndRaise(COp):
return res
def c_code_cache_version(self):
return (1, 0)
return (1, 1)
def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]]
......
import numpy as np
import pytest
import scipy.sparse
import aesara
import aesara.tensor as at
from aesara.compile.mode import OPT_FAST_RUN, Mode
from aesara.graph.basic import Constant, equal_computations
from aesara.raise_op import Assert, CheckAndRaise, assert_op
from aesara.sparse import as_sparse_variable
from tests import unittest_tools as utt
......@@ -114,3 +116,18 @@ class TestCheckAndRaiseInferShape(utt.InferShapeTester):
self._compile_and_check(
[admat, adscal, bdscal], [out], [admat_val, adscal_val, bdscal_val], Assert
)
def test_CheckAndRaise_sparse_variable():
check_and_raise = CheckAndRaise(ValueError, "sparse_check")
spe1 = scipy.sparse.csc_matrix(scipy.sparse.eye(5, 3))
aspe1 = as_sparse_variable(spe1)
a1 = check_and_raise(aspe1, aspe1.sum() > 2)
assert a1.sum().eval() == 3
spe2 = scipy.sparse.csc_matrix(scipy.sparse.eye(5, 1))
aspe2 = as_sparse_variable(spe2)
a2 = check_and_raise(aspe1, aspe2.sum() > 2)
with pytest.raises(ValueError, match="sparse_check"):
a2.sum().eval()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论