提交 7d0edb87 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make SparseTensorType.filter match TensorType in error behavior

上级 3e5ad997
...@@ -6,6 +6,7 @@ from typing_extensions import Literal ...@@ -6,6 +6,7 @@ from typing_extensions import Literal
import aesara import aesara
from aesara import scalar as aes from aesara import scalar as aes
from aesara.graph.basic import Variable
from aesara.graph.type import HasDataType from aesara.graph.type import HasDataType
from aesara.tensor.type import DenseTensorType, TensorType from aesara.tensor.type import DenseTensorType, TensorType
...@@ -98,6 +99,13 @@ class SparseTensorType(TensorType, HasDataType): ...@@ -98,6 +99,13 @@ class SparseTensorType(TensorType, HasDataType):
return type(self)(format, dtype, shape=shape, **kwargs) return type(self)(format, dtype, shape=shape, **kwargs)
def filter(self, value, strict=False, allow_downcast=None): def filter(self, value, strict=False, allow_downcast=None):
if isinstance(value, Variable):
raise TypeError(
"Expected an array-like object, but found a Variable: "
"maybe you are trying to call a function on a (possibly "
"shared) variable instead of a numeric array?"
)
if ( if (
isinstance(value, self.format_cls[self.format]) isinstance(value, self.format_cls[self.format])
and value.dtype == self.dtype and value.dtype == self.dtype
...@@ -117,13 +125,10 @@ class SparseTensorType(TensorType, HasDataType): ...@@ -117,13 +125,10 @@ class SparseTensorType(TensorType, HasDataType):
data = self.format_cls[self.format](value) data = self.format_cls[self.format](value)
up_dtype = aes.upcast(self.dtype, data.dtype) up_dtype = aes.upcast(self.dtype, data.dtype)
if up_dtype != self.dtype: if up_dtype != self.dtype:
raise NotImplementedError( raise TypeError(f"Expected {self.dtype} dtype but got {data.dtype}")
f"Expected {self.dtype} dtype but got {data.dtype}"
)
sp = data.astype(up_dtype) sp = data.astype(up_dtype)
if sp.format != self.format: assert sp.format == self.format
raise NotImplementedError()
return sp return sp
......
import pytest import pytest
import scipy as sp
from aesara.sparse import matrix as sp_matrix from aesara.sparse import matrix as sp_matrix
from aesara.sparse.type import SparseTensorType from aesara.sparse.type import SparseTensorType
...@@ -52,3 +53,31 @@ def test_SparseTensorType_convert_variable(): ...@@ -52,3 +53,31 @@ def test_SparseTensorType_convert_variable():
# we would need to added sparse/dense logic to `TensorType`, and we don't # we would need to added sparse/dense logic to `TensorType`, and we don't
# want to do that. # want to do that.
assert x.type.convert_variable(y) is y assert x.type.convert_variable(y) is y
def test_SparseTensorType_filter():
y = sp_matrix("csc", dtype="float64", name="y")
z = sp_matrix("csr", dtype="float64", name="z")
w = sp_matrix("csr", dtype="float32", name="z")
with pytest.raises(TypeError, match="Expected an array-like"):
y.type.filter(dmatrix())
x = sp.sparse.csc_matrix(sp.sparse.eye(5, 3))
x_res = y.type.filter(x)
assert x is x_res
x_res = z.type.filter(x)
assert x_res.format == "csr"
with pytest.raises(TypeError):
x_res = z.type.filter(x, strict=True)
x_res = w.type.filter(x, allow_downcast=True)
assert x_res.dtype == "float32"
x_res = z.type.filter(x.astype("float32"), allow_downcast=True)
assert x_res.dtype == "float64"
with pytest.raises(TypeError, match=".*dtype but got.*"):
w.type.filter(x)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论