提交 3e5ad997 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor SparseTensorType constructor and clone interface

上级 373eab13
from typing import Iterable, Optional, Union
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
from typing_extensions import Literal
import aesara import aesara
from aesara import scalar as aes from aesara import scalar as aes
...@@ -7,6 +10,9 @@ from aesara.graph.type import HasDataType ...@@ -7,6 +10,9 @@ from aesara.graph.type import HasDataType
from aesara.tensor.type import DenseTensorType, TensorType from aesara.tensor.type import DenseTensorType, TensorType
SparsityTypes = Literal["csr", "csc", "bsr"]
def _is_sparse(x): def _is_sparse(x):
""" """
...@@ -57,38 +63,39 @@ class SparseTensorType(TensorType, HasDataType): ...@@ -57,38 +63,39 @@ class SparseTensorType(TensorType, HasDataType):
} }
ndim = 2 ndim = 2
# Will be set to SparseVariable SparseConstant later. def __init__(
variable_type = None self,
Constant = None format: SparsityTypes,
dtype: Union[str, np.dtype],
def __init__(self, format, dtype, shape=None, broadcastable=None, name=None): shape: Optional[Iterable[Optional[Union[bool, int]]]] = None,
if shape is None: name: Optional[str] = None,
broadcastable: Optional[Iterable[bool]] = None,
):
if shape is None and broadcastable is None:
shape = (None, None) shape = (None, None)
self.shape = shape if format not in self.format_cls:
raise ValueError(
if not isinstance(format, str):
raise TypeError("The sparse format parameter must be a string")
if format in self.format_cls:
self.format = format
else:
raise NotImplementedError(
f'unsupported format "{format}" not in list', f'unsupported format "{format}" not in list',
) )
if broadcastable is None:
broadcastable = [False, False]
super().__init__(dtype, shape, name=name) self.format = format
super().__init__(dtype, shape=shape, name=name, broadcastable=broadcastable)
def clone(self, format=None, dtype=None, shape=None, **kwargs): def clone(
if format is None: self,
format = self.format dtype=None,
shape=None,
broadcastable=None,
**kwargs,
):
format: Optional[SparsityTypes] = kwargs.pop("format", self.format)
if dtype is None: if dtype is None:
dtype = self.dtype dtype = self.dtype
if shape is None: if shape is None:
shape = self.shape shape = self.shape
return type(self)(format, dtype, shape) return type(self)(format, dtype, shape=shape, **kwargs)
def filter(self, value, strict=False, allow_downcast=None): def filter(self, value, strict=False, allow_downcast=None):
if ( if (
......
import pytest
from aesara.sparse import matrix as sp_matrix from aesara.sparse import matrix as sp_matrix
from aesara.sparse.type import SparseTensorType from aesara.sparse.type import SparseTensorType
from aesara.tensor import dmatrix from aesara.tensor import dmatrix
def test_clone(): def test_SparseTensorType_constructor():
st = SparseTensorType("csr", "float64") st = SparseTensorType("csc", "float64")
assert st.format == "csc"
assert st.shape == (None, None)
st = SparseTensorType("bsr", "float64", shape=(None, 1))
assert st.format == "bsr"
assert st.shape == (None, 1)
with pytest.raises(ValueError):
SparseTensorType("blah", "float64")
def test_SparseTensorType_clone():
st = SparseTensorType("csr", "float64", shape=(3, None))
assert st == st.clone() assert st == st.clone()
st_clone = st.clone(format="csc")
assert st_clone.format == "csc"
assert st_clone.dtype == st.dtype
assert st_clone.shape == st.shape
st_clone = st.clone(shape=(2, 1))
assert st_clone.format == st.format
assert st_clone.dtype == st.dtype
assert st_clone.shape == (2, 1)
def test_Sparse_convert_variable(): def test_SparseTensorType_convert_variable():
x = dmatrix(name="x") x = dmatrix(name="x")
y = sp_matrix("csc", dtype="float64", name="y") y = sp_matrix("csc", dtype="float64", name="y")
z = sp_matrix("csr", dtype="float64", name="z") z = sp_matrix("csr", dtype="float64", name="z")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论