提交 fdf2a23a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make sparse tensor types extend the existing tensor types

上级 31ab8fdd
......@@ -49,6 +49,7 @@ from aesara.tensor.type import TensorType
from aesara.tensor.type import continuous_dtypes as tensor_continuous_dtypes
from aesara.tensor.type import discrete_dtypes as tensor_discrete_dtypes
from aesara.tensor.type import iscalar, ivector, scalar, tensor, vector
from aesara.tensor.var import TensorConstant, TensorVariable, _tensor_py_operators
sparse_formats = ["csc", "csr"]
......@@ -126,8 +127,7 @@ def _is_dense(x):
return isinstance(x, np.ndarray)
# Wrapper type
def as_sparse_variable(x, name=None):
def as_sparse_variable(x, name=None, ndim=None, **kwargs):
"""
Wrapper around SparseVariable constructor to construct
a Variable with a sparse matrix with the same dtype and
......@@ -250,7 +250,7 @@ def sp_zeros_like(x):
)
class _sparse_py_operators:
class _sparse_py_operators(_tensor_py_operators):
T = property(
lambda self: transpose(self), doc="Return aliased transpose of self (read-only)"
)
......@@ -361,8 +361,7 @@ class _sparse_py_operators:
return ret
class SparseVariable(_sparse_py_operators, Variable):
dtype = property(lambda self: self.type.dtype)
class SparseVariable(_sparse_py_operators, TensorVariable):
format = property(lambda self: self.type.format)
def __str__(self):
......@@ -395,8 +394,7 @@ class SparseConstantSignature(tuple):
return hash_from_sparse(d)
class SparseConstant(Constant, _sparse_py_operators):
dtype = property(lambda self: self.type.dtype)
class SparseConstant(TensorConstant, _sparse_py_operators):
format = property(lambda self: self.type.format)
def signature(self):
......@@ -448,7 +446,7 @@ csc_fmatrix = SparseType(format="csc", dtype="float32")
csr_fmatrix = SparseType(format="csr", dtype="float32")
bsr_fmatrix = SparseType(format="bsr", dtype="float32")
all_dtypes = SparseType.dtype_set
all_dtypes = list(SparseType.dtype_specs_map.keys())
complex_dtypes = [t for t in all_dtypes if t[:7] == "complex"]
float_dtypes = [t for t in all_dtypes if t[:5] == "float"]
int_dtypes = [t for t in all_dtypes if t[:3] == "int"]
......@@ -926,6 +924,12 @@ class DenseFromSparse(Op):
def __str__(self):
return f"{self.__class__.__name__}{{structured_grad={self.sparse_grad}}}"
def __call__(self, x):
if not isinstance(x.type, SparseType):
return x
return super().__call__(x)
def make_node(self, x):
x = as_sparse_variable(x)
return Apply(
......@@ -1003,6 +1007,12 @@ class SparseFromDense(Op):
def __str__(self):
return f"{self.__class__.__name__}{{{self.format}}}"
def __call__(self, x):
if isinstance(x.type, SparseType):
return x
return super().__call__(x)
def make_node(self, x):
x = at.as_tensor_variable(x)
if x.ndim > 2:
......
......@@ -23,6 +23,7 @@ from aesara.tensor import blas
from aesara.tensor.basic import as_tensor_variable, cast, patternbroadcast
from aesara.tensor.basic_opt import register_canonicalize, register_specialize
from aesara.tensor.math import mul, neg, sub
from aesara.tensor.shape import shape, specify_shape
from aesara.tensor.type import TensorType, tensor
......@@ -2070,8 +2071,19 @@ def local_sampling_dot_csr(fgraph, node):
z_data, z_ind, z_ptr = sampling_dot_csr(
x, y, p_data, p_ind, p_ptr, p_shape[1]
)
return [sparse.CSR(z_data, z_ind, z_ptr, p_shape)]
# This is a hack that works around some missing `Type`-related
# static shape narrowing. More specifically,
# `TensorType.convert_variable` currently won't combine the static
# shape information from `old_out.type` and `new_out.type`, only
# the broadcast patterns, and, since `CSR.make_node` doesn't do
# that either, we use `specify_shape` to produce an output `Type`
# with the same level of static shape information as the original
# `old_out`.
old_out = node.outputs[0]
new_out = specify_shape(
sparse.CSR(z_data, z_ind, z_ptr, p_shape), shape(old_out)
)
return [new_out]
return False
......
......@@ -2,7 +2,8 @@ import numpy as np
import scipy.sparse
import aesara
from aesara.graph.type import HasDataType, Type
from aesara.graph.type import HasDataType
from aesara.tensor.type import TensorType
def _is_sparse(x):
......@@ -24,7 +25,7 @@ def _is_sparse(x):
return isinstance(x, scipy.sparse.spmatrix)
class SparseType(Type, HasDataType):
class SparseType(TensorType, HasDataType):
"""
Fundamental way to create a sparse node.
......@@ -52,19 +53,19 @@ class SparseType(Type, HasDataType):
"csc": scipy.sparse.csc_matrix,
"bsr": scipy.sparse.bsr_matrix,
}
dtype_set = {
"int8",
"int16",
"int32",
"int64",
"float32",
"uint8",
"uint16",
"uint32",
"uint64",
"float64",
"complex64",
"complex128",
dtype_specs_map = {
"float32": (float, "npy_float32", "NPY_FLOAT32"),
"float64": (float, "npy_float64", "NPY_FLOAT64"),
"uint8": (int, "npy_uint8", "NPY_UINT8"),
"int8": (int, "npy_int8", "NPY_INT8"),
"uint16": (int, "npy_uint16", "NPY_UINT16"),
"int16": (int, "npy_int16", "NPY_INT16"),
"uint32": (int, "npy_uint32", "NPY_UINT32"),
"int32": (int, "npy_int32", "NPY_INT32"),
"uint64": (int, "npy_uint64", "NPY_UINT64"),
"int64": (int, "npy_int64", "NPY_INT64"),
"complex128": (complex, "aesara_complex128", "NPY_COMPLEX128"),
"complex64": (complex, "aesara_complex64", "NPY_COMPLEX64"),
}
ndim = 2
......@@ -72,28 +73,25 @@ class SparseType(Type, HasDataType):
variable_type = None
Constant = None
def __init__(self, format, dtype, shape=None):
dtype = str(dtype)
if dtype in self.dtype_set:
self.dtype = dtype
else:
raise NotImplementedError(
f'unsupported dtype "{dtype}" not in list', list(self.dtype_set)
)
def __init__(self, format, dtype, shape=None, broadcastable=None, name=None):
if shape is None:
shape = (None, None)
self.shape = shape
assert isinstance(format, str)
if not isinstance(format, str):
raise TypeError("The sparse format parameter must be a string")
if format in self.format_cls:
self.format = format
else:
raise NotImplementedError(
f'unsupported format "{format}" not in list',
list(self.format_cls.keys()),
)
if broadcastable is None:
broadcastable = [False, False]
super().__init__(dtype, shape, name=name)
def clone(self, format=None, dtype=None, shape=None, **kwargs):
if format is None:
......@@ -153,21 +151,11 @@ class SparseType(Type, HasDataType):
def make_variable(self, name=None):
return self.variable_type(self, name=name)
def __eq__(self, other):
return (
type(self) == type(other)
and other.dtype == self.dtype
and other.format == self.format
)
def __hash__(self):
return hash(self.dtype) ^ hash(self.format)
def __str__(self):
return f"Sparse[{self.dtype}, {self.format}]"
return super().__hash__() ^ hash(self.format)
def __repr__(self):
return f"Sparse[{self.dtype}, {self.format}]"
return f"Sparse({self.dtype}, {self.shape}, {self.format})"
def values_eq_approx(self, a, b, eps=1e-6):
# WARNING: equality comparison of sparse matrices is not fast or easy
......@@ -210,6 +198,31 @@ class SparseType(Type, HasDataType):
+ (shape_info[2] + shape_info[3]) * np.dtype("int32").itemsize
)
def value_zeros(self, shape):
matrix_constructor = self.format_cls.get(self.format)
if matrix_constructor is None:
raise ValueError(f"Sparse matrix type {self.format} not found in SciPy")
return matrix_constructor(shape, dtype=self.dtype)
def __eq__(self, other):
res = super().__eq__(other)
if isinstance(res, bool):
return res and other.format == self.format
return res
def is_super(self, otype):
if not super().is_super(otype):
return False
if self.format == otype.format:
return True
return False
# Register SparseType's C code for ViewOp.
aesara.compile.register_view_op_c_code(
......
......@@ -313,8 +313,13 @@ def get_scalar_constant_value(
return np.array(data.item(), dtype=v.dtype)
except ValueError:
raise NotScalarConstantError()
else:
return data
from aesara.sparse.type import SparseType
if isinstance(v.type, SparseType):
raise NotScalarConstantError()
return data
if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
max_recur -= 1
......
......@@ -78,7 +78,12 @@ from aesara.tensor.math import eq
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft
from aesara.tensor.sort import TopKOp
from aesara.tensor.subtensor import Subtensor, get_idx_list
from aesara.tensor.type import TensorType, discrete_dtypes, integer_dtypes
from aesara.tensor.type import (
DenseTensorType,
TensorType,
discrete_dtypes,
integer_dtypes,
)
from aesara.tensor.var import TensorConstant
from aesara.utils import NoDuplicateOptWarningFilter
......@@ -2954,7 +2959,8 @@ def constant_folding(fgraph, node):
# TODO: `Type` itself should provide an interface for constructing
# instances appropriate for a given constant.
if isinstance(output.type, TensorType):
# TODO: Add handling for sparse types.
if isinstance(output.type, DenseTensorType):
output_type = TensorType(
output.type.dtype,
tuple(s == 1 for s in data.shape),
......
......@@ -167,7 +167,12 @@ from aesara.tensor.blas_headers import blas_header_text, blas_header_version
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, add, mul, neg, sub
from aesara.tensor.type import integer_dtypes, tensor, values_eq_approx_remove_inf_nan
from aesara.tensor.type import (
DenseTensorType,
integer_dtypes,
tensor,
values_eq_approx_remove_inf_nan,
)
from aesara.utils import memoize
......@@ -264,7 +269,13 @@ class Gemv(Op):
raise TypeError("gemv requires vector for x", x.type)
if y.ndim != 1:
raise TypeError("gemv requires vector for y", y.type)
return Apply(self, [y, alpha, A, x, beta], [y.type()])
inputs = [y, alpha, A, x, beta]
if any(not isinstance(i.type, DenseTensorType) for i in inputs):
raise NotImplementedError("Only dense tensor types are supported")
return Apply(self, inputs, [y.type()])
def perform(self, node, inputs, out_storage, params=None):
y, alpha, A, x, beta = inputs
......@@ -361,7 +372,12 @@ class Ger(Op):
if x.dtype not in ("float32", "float64", "complex64", "complex128"):
raise TypeError("only float and complex types supported", x.dtype)
return Apply(self, [A, alpha, x, y], [A.type()])
inputs = [A, alpha, x, y]
if any(not isinstance(i.type, DenseTensorType) for i in inputs):
raise NotImplementedError("Only dense tensor types are supported")
return Apply(self, inputs, [A.type()])
def perform(self, node, inp, out, params=None):
cA, calpha, cx, cy = inp
......@@ -899,6 +915,10 @@ class Gemm(GemmRelated):
def make_node(self, *inputs):
inputs = list(map(at.as_tensor_variable, inputs))
if any(not isinstance(i.type, DenseTensorType) for i in inputs):
raise NotImplementedError("Only dense tensor types are supported")
if len(inputs) != 5:
raise TypeError(
f"Wrong number of inputs for {self} (expected 5, got {len(inputs)})"
......@@ -1580,6 +1600,10 @@ class Dot22(GemmRelated):
def make_node(self, x, y):
x = at.as_tensor_variable(x)
y = at.as_tensor_variable(y)
if any(not isinstance(i.type, DenseTensorType) for i in (x, y)):
raise NotImplementedError("Only dense tensor types are supported")
dtypes = ("float16", "float32", "float64", "complex64", "complex128")
if x.type.ndim != 2 or x.type.dtype not in dtypes:
raise TypeError(x)
......@@ -1665,6 +1689,9 @@ def local_dot_to_dot22(fgraph, node):
if not isinstance(node.op, Dot):
return
if any(not isinstance(i.type, DenseTensorType) for i in node.inputs):
return False
x, y = node.inputs
if y.type.dtype != x.type.dtype:
# TODO: upcast one so the types match
......@@ -1869,6 +1896,10 @@ class Dot22Scalar(GemmRelated):
check_input = False
def make_node(self, x, y, a):
if any(not isinstance(i.type, DenseTensorType) for i in (x, y, a)):
raise NotImplementedError("Only dense tensor types are supported")
if a.ndim != 0:
raise TypeError(Gemm.E_scalar, a)
if x.ndim != 2:
......@@ -2089,6 +2120,9 @@ class BatchedDot(COp):
def make_node(self, *inputs):
inputs = list(map(at.as_tensor_variable, inputs))
if any(not isinstance(i.type, DenseTensorType) for i in inputs):
raise NotImplementedError("Only dense tensor types are supported")
if len(inputs) != 2:
raise TypeError(f"Two arguments required, but {len(inputs)} given.")
if inputs[0].ndim not in (2, 3):
......
......@@ -34,6 +34,7 @@ from aesara.tensor.elemwise import (
)
from aesara.tensor.shape import shape
from aesara.tensor.type import (
DenseTensorType,
complex_dtypes,
continuous_dtypes,
discrete_dtypes,
......@@ -2076,6 +2077,11 @@ def dense_dot(a, b):
"""
a, b = as_tensor_variable(a), as_tensor_variable(b)
if not isinstance(a.type, DenseTensorType) or not isinstance(
b.type, DenseTensorType
):
raise TypeError("The dense dot product is only supported for dense types")
if a.ndim == 0 or b.ndim == 0:
return a * b
elif a.ndim > 2 or b.ndim > 2:
......
......@@ -431,7 +431,7 @@ class SpecifyShape(COp):
)
if isinstance(x.type, TensorType) and all(isinstance(s, Number) for s in shape):
out_var = TensorType(x.type.dtype, shape)()
out_var = x.type.clone(shape=shape)()
else:
out_var = x.type()
......
import logging
import warnings
from typing import Iterable, Optional, Union
from typing import Iterable, Optional, Tuple, Union
import numpy as np
......@@ -9,6 +9,7 @@ from aesara import scalar as aes
from aesara.configdefaults import config
from aesara.graph.basic import Variable
from aesara.graph.type import HasDataType
from aesara.graph.utils import MetaType
from aesara.link.c.type import CType
from aesara.misc.safe_asarray import _asarray
from aesara.utils import apply_across_args
......@@ -50,8 +51,9 @@ dtype_specs_map = {
class TensorType(CType, HasDataType):
r"""Symbolic `Type` representing `numpy.ndarray`\s."""
__props__ = ("dtype", "shape")
__props__: Tuple[str, ...] = ("dtype", "shape")
dtype_specs_map = dtype_specs_map
context_name = "cpu"
filter_checks_isfinite = False
"""
......@@ -271,7 +273,7 @@ class TensorType(CType, HasDataType):
"""
try:
return dtype_specs_map[self.dtype]
return self.dtype_specs_map[self.dtype]
except KeyError:
raise TypeError(
f"Unsupported dtype for {self.__class__.__name__}: {self.dtype}"
......@@ -613,6 +615,20 @@ class TensorType(CType, HasDataType):
return ()
class DenseTypeMeta(MetaType):
def __instancecheck__(self, o):
if type(o) == TensorType or isinstance(o, DenseTypeMeta):
return True
return False
class DenseTensorType(TensorType, metaclass=DenseTypeMeta):
r"""A `Type` for dense tensors.
Instances of this class and `TensorType`\s are considered dense `Type`\s.
"""
def values_eq_approx(
a, b, allow_remove_inf=False, allow_remove_nan=False, rtol=None, atol=None
):
......
......@@ -10,6 +10,7 @@ import numpy as np
from aesara import tensor as at
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
from aesara.graph.utils import MetaType
from aesara.scalar import ComplexError, IntegerDivisionError
from aesara.tensor import _get_vector_length, as_tensor_variable
from aesara.tensor.exceptions import AdvancedIndexingError
......@@ -1040,3 +1041,33 @@ class TensorConstant(TensorVariable, Constant):
TensorType.constant_type = TensorConstant
class DenseVariableMeta(MetaType):
def __instancecheck__(self, o):
if type(o) == TensorVariable or isinstance(o, DenseVariableMeta):
return True
return False
class DenseTensorVariable(TensorType, metaclass=DenseVariableMeta):
r"""A `Variable` for dense tensors.
Instances of this class and `TensorVariable`\s are considered dense
`Variable`\s.
"""
class DenseConstantMeta(MetaType):
def __instancecheck__(self, o):
if type(o) == TensorConstant or isinstance(o, DenseConstantMeta):
return True
return False
class DenseTensorConstant(TensorType, metaclass=DenseConstantMeta):
r"""A `Constant` for dense tensors.
Instances of this class and `TensorConstant`\s are considered dense
`Constant`\s.
"""
......@@ -12,7 +12,7 @@ from aesara.compile.function import function
from aesara.compile.io import In, Out
from aesara.configdefaults import config
from aesara.gradient import GradientError
from aesara.graph.basic import Apply, Constant
from aesara.graph.basic import Apply, Constant, applys_between
from aesara.graph.op import Op
from aesara.misc.safe_asarray import _asarray
from aesara.sparse import (
......@@ -78,6 +78,7 @@ from aesara.sparse import (
true_dot,
)
from aesara.sparse.basic import (
SparseConstant,
_is_dense_variable,
_is_sparse,
_is_sparse_variable,
......@@ -1017,22 +1018,45 @@ class TestComparison:
class TestConversion:
@pytest.mark.skip
def test_basic(self):
a = at.as_tensor_variable(np.random.random((5)))
test_val = np.random.rand(5).astype(config.floatX)
a = at.as_tensor_variable(test_val)
s = csc_from_dense(a)
val = eval_outputs([s])
assert str(val.dtype) == "float64"
assert str(val.dtype) == config.floatX
assert val.format == "csc"
@pytest.mark.skip
def test_basic_1(self):
a = at.as_tensor_variable(np.random.random((5)))
a = at.as_tensor_variable(test_val)
s = csr_from_dense(a)
val = eval_outputs([s])
assert str(val.dtype) == "float64"
assert str(val.dtype) == config.floatX
assert val.format == "csr"
test_val = np.eye(3).astype(config.floatX)
a = sp.sparse.csr_matrix(test_val)
s = as_sparse_or_tensor_variable(a)
res = at.as_tensor_variable(s)
assert isinstance(res, SparseConstant)
a = sp.sparse.csr_matrix(test_val)
s = as_sparse_or_tensor_variable(a)
from aesara.tensor.exceptions import NotScalarConstantError
with pytest.raises(NotScalarConstantError):
at.get_scalar_constant_value(s, only_process_constants=True)
# TODO:
# def test_sparse_as_tensor_variable(self):
# csr = sp.sparse.csr_matrix(np.eye(3))
# val = aet.as_tensor_variable(csr)
# assert str(val.dtype) == config.floatX
# assert val.format == "csr"
#
# csr = sp.sparse.csc_matrix(np.eye(3))
# val = aet.as_tensor_variable(csr)
# assert str(val.dtype) == config.floatX
# assert val.format == "csc"
def test_dense_from_sparse(self):
# call dense_from_sparse
for t in _mtypes:
......@@ -1591,6 +1615,32 @@ class TestDots(utt.InferShapeTester):
)
f(i, a)
def test_tensor_dot_types(self):
x = sparse.csc_matrix("x")
x_d = at.matrix("x_d")
y = sparse.csc_matrix("y")
res = at.dot(x, y)
op_types = set(type(n.op) for n in applys_between([x, y], [res]))
assert sparse.basic.StructuredDot in op_types
assert at.math.Dot not in op_types
res = at.dot(x_d, y)
op_types = set(type(n.op) for n in applys_between([x, y], [res]))
assert sparse.basic.StructuredDot in op_types
assert at.math.Dot not in op_types
res = at.dot(x, x_d)
op_types = set(type(n.op) for n in applys_between([x, y], [res]))
assert sparse.basic.StructuredDot in op_types
assert at.math.Dot not in op_types
res = at.dot(at.second(1, x), y)
op_types = set(type(n.op) for n in applys_between([x, y], [res]))
assert sparse.basic.StructuredDot in op_types
assert at.math.Dot not in op_types
def test_csr_dense_grad(self):
# shortcut: testing csc in float32, testing csr in float64
......
import pytest
from aesara.sparse import matrix as sp_matrix
from aesara.sparse.type import SparseType
from aesara.tensor import dmatrix
def test_clone():
st = SparseType("csr", "float64")
assert st == st.clone()
def test_Sparse_convert_variable():
x = dmatrix(name="x")
y = sp_matrix("csc", dtype="float64", name="y")
z = sp_matrix("csr", dtype="float64", name="z")
assert y.type.convert_variable(z) is None
# TODO FIXME: This is a questionable result, because `x.type` is associated
# with a dense `Type`, but, since `TensorType` is a base class of `Sparse`,
# we would need to added sparse/dense logic to `TensorType`, and we don't
# want to do that.
assert x.type.convert_variable(y) is y
# TODO FIXME: We should be able to do this.
with pytest.raises(NotImplementedError):
y.type.convert_variable(x)
......@@ -6,6 +6,7 @@ import aesara
import tests.unittest_tools as utt
from aesara.graph.basic import Constant, equal_computations
from aesara.tensor import get_vector_length
from aesara.tensor.basic import constant
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.math import dot
from aesara.tensor.subtensor import AdvancedSubtensor, Subtensor
......@@ -21,7 +22,12 @@ from aesara.tensor.type import (
tensor3,
)
from aesara.tensor.type_other import MakeSlice
from aesara.tensor.var import TensorConstant, TensorVariable
from aesara.tensor.var import (
DenseTensorConstant,
DenseTensorVariable,
TensorConstant,
TensorVariable,
)
@pytest.mark.parametrize(
......@@ -247,3 +253,14 @@ def test_get_vector_length():
x = TensorVariable(TensorType("int64", (None,)))
with pytest.raises(ValueError):
get_vector_length(x)
def test_dense_types():
x = matrix()
assert isinstance(x, DenseTensorVariable)
assert not isinstance(x, DenseTensorConstant)
x = constant(1)
assert not isinstance(x, DenseTensorVariable)
assert isinstance(x, DenseTensorConstant)
......@@ -335,13 +335,13 @@ class TestIfelse(utt.OptimizationTestMixin):
z = aesara.sparse.matrix("csr", dtype=self.dtype, name="z")
cond = iscalar("cond")
with pytest.raises(TypeError):
with pytest.raises(NotImplementedError):
ifelse(cond, x, y)
with pytest.raises(TypeError):
with pytest.raises(NotImplementedError):
ifelse(cond, y, x)
with pytest.raises(TypeError):
with pytest.raises(NotImplementedError):
ifelse(cond, x, z)
with pytest.raises(TypeError):
with pytest.raises(NotImplementedError):
ifelse(cond, z, x)
with pytest.raises(TypeError):
ifelse(cond, y, z)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论