提交 7d0edb87 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make SparseTensorType.filter match TensorType in error behavior

上级 3e5ad997
......@@ -6,6 +6,7 @@ from typing_extensions import Literal
import aesara
from aesara import scalar as aes
from aesara.graph.basic import Variable
from aesara.graph.type import HasDataType
from aesara.tensor.type import DenseTensorType, TensorType
......@@ -98,6 +99,13 @@ class SparseTensorType(TensorType, HasDataType):
return type(self)(format, dtype, shape=shape, **kwargs)
def filter(self, value, strict=False, allow_downcast=None):
if isinstance(value, Variable):
raise TypeError(
"Expected an array-like object, but found a Variable: "
"maybe you are trying to call a function on a (possibly "
"shared) variable instead of a numeric array?"
)
if (
isinstance(value, self.format_cls[self.format])
and value.dtype == self.dtype
......@@ -117,13 +125,10 @@ class SparseTensorType(TensorType, HasDataType):
data = self.format_cls[self.format](value)
up_dtype = aes.upcast(self.dtype, data.dtype)
if up_dtype != self.dtype:
raise NotImplementedError(
f"Expected {self.dtype} dtype but got {data.dtype}"
)
raise TypeError(f"Expected {self.dtype} dtype but got {data.dtype}")
sp = data.astype(up_dtype)
if sp.format != self.format:
raise NotImplementedError()
assert sp.format == self.format
return sp
......
import pytest
import scipy as sp
from aesara.sparse import matrix as sp_matrix
from aesara.sparse.type import SparseTensorType
......@@ -52,3 +53,31 @@ def test_SparseTensorType_convert_variable():
# we would need to added sparse/dense logic to `TensorType`, and we don't
# want to do that.
assert x.type.convert_variable(y) is y
def test_SparseTensorType_filter():
y = sp_matrix("csc", dtype="float64", name="y")
z = sp_matrix("csr", dtype="float64", name="z")
w = sp_matrix("csr", dtype="float32", name="z")
with pytest.raises(TypeError, match="Expected an array-like"):
y.type.filter(dmatrix())
x = sp.sparse.csc_matrix(sp.sparse.eye(5, 3))
x_res = y.type.filter(x)
assert x is x_res
x_res = z.type.filter(x)
assert x_res.format == "csr"
with pytest.raises(TypeError):
x_res = z.type.filter(x, strict=True)
x_res = w.type.filter(x, allow_downcast=True)
assert x_res.dtype == "float32"
x_res = z.type.filter(x.astype("float32"), allow_downcast=True)
assert x_res.dtype == "float64"
with pytest.raises(TypeError, match=".*dtype but got.*"):
w.type.filter(x)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论