提交 3e5ad997 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor SparseTensorType constructor and clone interface

上级 373eab13
from typing import Iterable, Optional, Union
import numpy as np
import scipy.sparse
from typing_extensions import Literal
import aesara
from aesara import scalar as aes
......@@ -7,6 +10,9 @@ from aesara.graph.type import HasDataType
from aesara.tensor.type import DenseTensorType, TensorType
SparsityTypes = Literal["csr", "csc", "bsr"]
def _is_sparse(x):
"""
......@@ -57,38 +63,39 @@ class SparseTensorType(TensorType, HasDataType):
}
ndim = 2
# Will be set to SparseVariable SparseConstant later.
variable_type = None
Constant = None
def __init__(self, format, dtype, shape=None, broadcastable=None, name=None):
if shape is None:
def __init__(
self,
format: SparsityTypes,
dtype: Union[str, np.dtype],
shape: Optional[Iterable[Optional[Union[bool, int]]]] = None,
name: Optional[str] = None,
broadcastable: Optional[Iterable[bool]] = None,
):
if shape is None and broadcastable is None:
shape = (None, None)
self.shape = shape
if not isinstance(format, str):
raise TypeError("The sparse format parameter must be a string")
if format in self.format_cls:
self.format = format
else:
raise NotImplementedError(
if format not in self.format_cls:
raise ValueError(
f'unsupported format "{format}" not in list',
)
if broadcastable is None:
broadcastable = [False, False]
super().__init__(dtype, shape, name=name)
self.format = format
super().__init__(dtype, shape=shape, name=name, broadcastable=broadcastable)
def clone(self, format=None, dtype=None, shape=None, **kwargs):
if format is None:
format = self.format
def clone(
self,
dtype=None,
shape=None,
broadcastable=None,
**kwargs,
):
format: Optional[SparsityTypes] = kwargs.pop("format", self.format)
if dtype is None:
dtype = self.dtype
if shape is None:
shape = self.shape
return type(self)(format, dtype, shape)
return type(self)(format, dtype, shape=shape, **kwargs)
def filter(self, value, strict=False, allow_downcast=None):
if (
......
import pytest
from aesara.sparse import matrix as sp_matrix
from aesara.sparse.type import SparseTensorType
from aesara.tensor import dmatrix
def test_clone():
st = SparseTensorType("csr", "float64")
def test_SparseTensorType_constructor():
st = SparseTensorType("csc", "float64")
assert st.format == "csc"
assert st.shape == (None, None)
st = SparseTensorType("bsr", "float64", shape=(None, 1))
assert st.format == "bsr"
assert st.shape == (None, 1)
with pytest.raises(ValueError):
SparseTensorType("blah", "float64")
def test_SparseTensorType_clone():
st = SparseTensorType("csr", "float64", shape=(3, None))
assert st == st.clone()
st_clone = st.clone(format="csc")
assert st_clone.format == "csc"
assert st_clone.dtype == st.dtype
assert st_clone.shape == st.shape
st_clone = st.clone(shape=(2, 1))
assert st_clone.format == st.format
assert st_clone.dtype == st.dtype
assert st_clone.shape == (2, 1)
def test_Sparse_convert_variable():
def test_SparseTensorType_convert_variable():
x = dmatrix(name="x")
y = sp_matrix("csc", dtype="float64", name="y")
z = sp_matrix("csr", dtype="float64", name="z")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论