提交 fdf2a23a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make sparse tensor types extend the existing tensor types

上级 31ab8fdd
...@@ -49,6 +49,7 @@ from aesara.tensor.type import TensorType ...@@ -49,6 +49,7 @@ from aesara.tensor.type import TensorType
from aesara.tensor.type import continuous_dtypes as tensor_continuous_dtypes from aesara.tensor.type import continuous_dtypes as tensor_continuous_dtypes
from aesara.tensor.type import discrete_dtypes as tensor_discrete_dtypes from aesara.tensor.type import discrete_dtypes as tensor_discrete_dtypes
from aesara.tensor.type import iscalar, ivector, scalar, tensor, vector from aesara.tensor.type import iscalar, ivector, scalar, tensor, vector
from aesara.tensor.var import TensorConstant, TensorVariable, _tensor_py_operators
sparse_formats = ["csc", "csr"] sparse_formats = ["csc", "csr"]
...@@ -126,8 +127,7 @@ def _is_dense(x): ...@@ -126,8 +127,7 @@ def _is_dense(x):
return isinstance(x, np.ndarray) return isinstance(x, np.ndarray)
# Wrapper type def as_sparse_variable(x, name=None, ndim=None, **kwargs):
def as_sparse_variable(x, name=None):
""" """
Wrapper around SparseVariable constructor to construct Wrapper around SparseVariable constructor to construct
a Variable with a sparse matrix with the same dtype and a Variable with a sparse matrix with the same dtype and
...@@ -250,7 +250,7 @@ def sp_zeros_like(x): ...@@ -250,7 +250,7 @@ def sp_zeros_like(x):
) )
class _sparse_py_operators: class _sparse_py_operators(_tensor_py_operators):
T = property( T = property(
lambda self: transpose(self), doc="Return aliased transpose of self (read-only)" lambda self: transpose(self), doc="Return aliased transpose of self (read-only)"
) )
...@@ -361,8 +361,7 @@ class _sparse_py_operators: ...@@ -361,8 +361,7 @@ class _sparse_py_operators:
return ret return ret
class SparseVariable(_sparse_py_operators, Variable): class SparseVariable(_sparse_py_operators, TensorVariable):
dtype = property(lambda self: self.type.dtype)
format = property(lambda self: self.type.format) format = property(lambda self: self.type.format)
def __str__(self): def __str__(self):
...@@ -395,8 +394,7 @@ class SparseConstantSignature(tuple): ...@@ -395,8 +394,7 @@ class SparseConstantSignature(tuple):
return hash_from_sparse(d) return hash_from_sparse(d)
class SparseConstant(Constant, _sparse_py_operators): class SparseConstant(TensorConstant, _sparse_py_operators):
dtype = property(lambda self: self.type.dtype)
format = property(lambda self: self.type.format) format = property(lambda self: self.type.format)
def signature(self): def signature(self):
...@@ -448,7 +446,7 @@ csc_fmatrix = SparseType(format="csc", dtype="float32") ...@@ -448,7 +446,7 @@ csc_fmatrix = SparseType(format="csc", dtype="float32")
csr_fmatrix = SparseType(format="csr", dtype="float32") csr_fmatrix = SparseType(format="csr", dtype="float32")
bsr_fmatrix = SparseType(format="bsr", dtype="float32") bsr_fmatrix = SparseType(format="bsr", dtype="float32")
all_dtypes = SparseType.dtype_set all_dtypes = list(SparseType.dtype_specs_map.keys())
complex_dtypes = [t for t in all_dtypes if t[:7] == "complex"] complex_dtypes = [t for t in all_dtypes if t[:7] == "complex"]
float_dtypes = [t for t in all_dtypes if t[:5] == "float"] float_dtypes = [t for t in all_dtypes if t[:5] == "float"]
int_dtypes = [t for t in all_dtypes if t[:3] == "int"] int_dtypes = [t for t in all_dtypes if t[:3] == "int"]
...@@ -926,6 +924,12 @@ class DenseFromSparse(Op): ...@@ -926,6 +924,12 @@ class DenseFromSparse(Op):
def __str__(self): def __str__(self):
return f"{self.__class__.__name__}{{structured_grad={self.sparse_grad}}}" return f"{self.__class__.__name__}{{structured_grad={self.sparse_grad}}}"
def __call__(self, x):
if not isinstance(x.type, SparseType):
return x
return super().__call__(x)
def make_node(self, x): def make_node(self, x):
x = as_sparse_variable(x) x = as_sparse_variable(x)
return Apply( return Apply(
...@@ -1003,6 +1007,12 @@ class SparseFromDense(Op): ...@@ -1003,6 +1007,12 @@ class SparseFromDense(Op):
def __str__(self): def __str__(self):
return f"{self.__class__.__name__}{{{self.format}}}" return f"{self.__class__.__name__}{{{self.format}}}"
def __call__(self, x):
if isinstance(x.type, SparseType):
return x
return super().__call__(x)
def make_node(self, x): def make_node(self, x):
x = at.as_tensor_variable(x) x = at.as_tensor_variable(x)
if x.ndim > 2: if x.ndim > 2:
......
...@@ -23,6 +23,7 @@ from aesara.tensor import blas ...@@ -23,6 +23,7 @@ from aesara.tensor import blas
from aesara.tensor.basic import as_tensor_variable, cast, patternbroadcast from aesara.tensor.basic import as_tensor_variable, cast, patternbroadcast
from aesara.tensor.basic_opt import register_canonicalize, register_specialize from aesara.tensor.basic_opt import register_canonicalize, register_specialize
from aesara.tensor.math import mul, neg, sub from aesara.tensor.math import mul, neg, sub
from aesara.tensor.shape import shape, specify_shape
from aesara.tensor.type import TensorType, tensor from aesara.tensor.type import TensorType, tensor
...@@ -2070,8 +2071,19 @@ def local_sampling_dot_csr(fgraph, node): ...@@ -2070,8 +2071,19 @@ def local_sampling_dot_csr(fgraph, node):
z_data, z_ind, z_ptr = sampling_dot_csr( z_data, z_ind, z_ptr = sampling_dot_csr(
x, y, p_data, p_ind, p_ptr, p_shape[1] x, y, p_data, p_ind, p_ptr, p_shape[1]
) )
# This is a hack that works around some missing `Type`-related
return [sparse.CSR(z_data, z_ind, z_ptr, p_shape)] # static shape narrowing. More specifically,
# `TensorType.convert_variable` currently won't combine the static
# shape information from `old_out.type` and `new_out.type`, only
# the broadcast patterns, and, since `CSR.make_node` doesn't do
# that either, we use `specify_shape` to produce an output `Type`
# with the same level of static shape information as the original
# `old_out`.
old_out = node.outputs[0]
new_out = specify_shape(
sparse.CSR(z_data, z_ind, z_ptr, p_shape), shape(old_out)
)
return [new_out]
return False return False
......
...@@ -2,7 +2,8 @@ import numpy as np ...@@ -2,7 +2,8 @@ import numpy as np
import scipy.sparse import scipy.sparse
import aesara import aesara
from aesara.graph.type import HasDataType, Type from aesara.graph.type import HasDataType
from aesara.tensor.type import TensorType
def _is_sparse(x): def _is_sparse(x):
...@@ -24,7 +25,7 @@ def _is_sparse(x): ...@@ -24,7 +25,7 @@ def _is_sparse(x):
return isinstance(x, scipy.sparse.spmatrix) return isinstance(x, scipy.sparse.spmatrix)
class SparseType(Type, HasDataType): class SparseType(TensorType, HasDataType):
""" """
Fundamental way to create a sparse node. Fundamental way to create a sparse node.
...@@ -52,19 +53,19 @@ class SparseType(Type, HasDataType): ...@@ -52,19 +53,19 @@ class SparseType(Type, HasDataType):
"csc": scipy.sparse.csc_matrix, "csc": scipy.sparse.csc_matrix,
"bsr": scipy.sparse.bsr_matrix, "bsr": scipy.sparse.bsr_matrix,
} }
dtype_set = { dtype_specs_map = {
"int8", "float32": (float, "npy_float32", "NPY_FLOAT32"),
"int16", "float64": (float, "npy_float64", "NPY_FLOAT64"),
"int32", "uint8": (int, "npy_uint8", "NPY_UINT8"),
"int64", "int8": (int, "npy_int8", "NPY_INT8"),
"float32", "uint16": (int, "npy_uint16", "NPY_UINT16"),
"uint8", "int16": (int, "npy_int16", "NPY_INT16"),
"uint16", "uint32": (int, "npy_uint32", "NPY_UINT32"),
"uint32", "int32": (int, "npy_int32", "NPY_INT32"),
"uint64", "uint64": (int, "npy_uint64", "NPY_UINT64"),
"float64", "int64": (int, "npy_int64", "NPY_INT64"),
"complex64", "complex128": (complex, "aesara_complex128", "NPY_COMPLEX128"),
"complex128", "complex64": (complex, "aesara_complex64", "NPY_COMPLEX64"),
} }
ndim = 2 ndim = 2
...@@ -72,28 +73,25 @@ class SparseType(Type, HasDataType): ...@@ -72,28 +73,25 @@ class SparseType(Type, HasDataType):
variable_type = None variable_type = None
Constant = None Constant = None
def __init__(self, format, dtype, shape=None): def __init__(self, format, dtype, shape=None, broadcastable=None, name=None):
dtype = str(dtype)
if dtype in self.dtype_set:
self.dtype = dtype
else:
raise NotImplementedError(
f'unsupported dtype "{dtype}" not in list', list(self.dtype_set)
)
if shape is None: if shape is None:
shape = (None, None) shape = (None, None)
self.shape = shape self.shape = shape
assert isinstance(format, str) if not isinstance(format, str):
raise TypeError("The sparse format parameter must be a string")
if format in self.format_cls: if format in self.format_cls:
self.format = format self.format = format
else: else:
raise NotImplementedError( raise NotImplementedError(
f'unsupported format "{format}" not in list', f'unsupported format "{format}" not in list',
list(self.format_cls.keys()),
) )
if broadcastable is None:
broadcastable = [False, False]
super().__init__(dtype, shape, name=name)
def clone(self, format=None, dtype=None, shape=None, **kwargs): def clone(self, format=None, dtype=None, shape=None, **kwargs):
if format is None: if format is None:
...@@ -153,21 +151,11 @@ class SparseType(Type, HasDataType): ...@@ -153,21 +151,11 @@ class SparseType(Type, HasDataType):
def make_variable(self, name=None): def make_variable(self, name=None):
return self.variable_type(self, name=name) return self.variable_type(self, name=name)
def __eq__(self, other):
return (
type(self) == type(other)
and other.dtype == self.dtype
and other.format == self.format
)
def __hash__(self): def __hash__(self):
return hash(self.dtype) ^ hash(self.format) return super().__hash__() ^ hash(self.format)
def __str__(self):
return f"Sparse[{self.dtype}, {self.format}]"
def __repr__(self): def __repr__(self):
return f"Sparse[{self.dtype}, {self.format}]" return f"Sparse({self.dtype}, {self.shape}, {self.format})"
def values_eq_approx(self, a, b, eps=1e-6): def values_eq_approx(self, a, b, eps=1e-6):
# WARNING: equality comparison of sparse matrices is not fast or easy # WARNING: equality comparison of sparse matrices is not fast or easy
...@@ -210,6 +198,31 @@ class SparseType(Type, HasDataType): ...@@ -210,6 +198,31 @@ class SparseType(Type, HasDataType):
+ (shape_info[2] + shape_info[3]) * np.dtype("int32").itemsize + (shape_info[2] + shape_info[3]) * np.dtype("int32").itemsize
) )
def value_zeros(self, shape):
matrix_constructor = self.format_cls.get(self.format)
if matrix_constructor is None:
raise ValueError(f"Sparse matrix type {self.format} not found in SciPy")
return matrix_constructor(shape, dtype=self.dtype)
def __eq__(self, other):
res = super().__eq__(other)
if isinstance(res, bool):
return res and other.format == self.format
return res
def is_super(self, otype):
if not super().is_super(otype):
return False
if self.format == otype.format:
return True
return False
# Register SparseType's C code for ViewOp. # Register SparseType's C code for ViewOp.
aesara.compile.register_view_op_c_code( aesara.compile.register_view_op_c_code(
......
...@@ -313,8 +313,13 @@ def get_scalar_constant_value( ...@@ -313,8 +313,13 @@ def get_scalar_constant_value(
return np.array(data.item(), dtype=v.dtype) return np.array(data.item(), dtype=v.dtype)
except ValueError: except ValueError:
raise NotScalarConstantError() raise NotScalarConstantError()
else:
return data from aesara.sparse.type import SparseType
if isinstance(v.type, SparseType):
raise NotScalarConstantError()
return data
if not only_process_constants and getattr(v, "owner", None) and max_recur > 0: if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
max_recur -= 1 max_recur -= 1
......
...@@ -78,7 +78,12 @@ from aesara.tensor.math import eq ...@@ -78,7 +78,12 @@ from aesara.tensor.math import eq
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft
from aesara.tensor.sort import TopKOp from aesara.tensor.sort import TopKOp
from aesara.tensor.subtensor import Subtensor, get_idx_list from aesara.tensor.subtensor import Subtensor, get_idx_list
from aesara.tensor.type import TensorType, discrete_dtypes, integer_dtypes from aesara.tensor.type import (
DenseTensorType,
TensorType,
discrete_dtypes,
integer_dtypes,
)
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant
from aesara.utils import NoDuplicateOptWarningFilter from aesara.utils import NoDuplicateOptWarningFilter
...@@ -2954,7 +2959,8 @@ def constant_folding(fgraph, node): ...@@ -2954,7 +2959,8 @@ def constant_folding(fgraph, node):
# TODO: `Type` itself should provide an interface for constructing # TODO: `Type` itself should provide an interface for constructing
# instances appropriate for a given constant. # instances appropriate for a given constant.
if isinstance(output.type, TensorType): # TODO: Add handling for sparse types.
if isinstance(output.type, DenseTensorType):
output_type = TensorType( output_type = TensorType(
output.type.dtype, output.type.dtype,
tuple(s == 1 for s in data.shape), tuple(s == 1 for s in data.shape),
......
...@@ -167,7 +167,12 @@ from aesara.tensor.blas_headers import blas_header_text, blas_header_version ...@@ -167,7 +167,12 @@ from aesara.tensor.blas_headers import blas_header_text, blas_header_version
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, add, mul, neg, sub from aesara.tensor.math import Dot, add, mul, neg, sub
from aesara.tensor.type import integer_dtypes, tensor, values_eq_approx_remove_inf_nan from aesara.tensor.type import (
DenseTensorType,
integer_dtypes,
tensor,
values_eq_approx_remove_inf_nan,
)
from aesara.utils import memoize from aesara.utils import memoize
...@@ -264,7 +269,13 @@ class Gemv(Op): ...@@ -264,7 +269,13 @@ class Gemv(Op):
raise TypeError("gemv requires vector for x", x.type) raise TypeError("gemv requires vector for x", x.type)
if y.ndim != 1: if y.ndim != 1:
raise TypeError("gemv requires vector for y", y.type) raise TypeError("gemv requires vector for y", y.type)
return Apply(self, [y, alpha, A, x, beta], [y.type()])
inputs = [y, alpha, A, x, beta]
if any(not isinstance(i.type, DenseTensorType) for i in inputs):
raise NotImplementedError("Only dense tensor types are supported")
return Apply(self, inputs, [y.type()])
def perform(self, node, inputs, out_storage, params=None): def perform(self, node, inputs, out_storage, params=None):
y, alpha, A, x, beta = inputs y, alpha, A, x, beta = inputs
...@@ -361,7 +372,12 @@ class Ger(Op): ...@@ -361,7 +372,12 @@ class Ger(Op):
if x.dtype not in ("float32", "float64", "complex64", "complex128"): if x.dtype not in ("float32", "float64", "complex64", "complex128"):
raise TypeError("only float and complex types supported", x.dtype) raise TypeError("only float and complex types supported", x.dtype)
return Apply(self, [A, alpha, x, y], [A.type()])
inputs = [A, alpha, x, y]
if any(not isinstance(i.type, DenseTensorType) for i in inputs):
raise NotImplementedError("Only dense tensor types are supported")
return Apply(self, inputs, [A.type()])
def perform(self, node, inp, out, params=None): def perform(self, node, inp, out, params=None):
cA, calpha, cx, cy = inp cA, calpha, cx, cy = inp
...@@ -899,6 +915,10 @@ class Gemm(GemmRelated): ...@@ -899,6 +915,10 @@ class Gemm(GemmRelated):
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = list(map(at.as_tensor_variable, inputs)) inputs = list(map(at.as_tensor_variable, inputs))
if any(not isinstance(i.type, DenseTensorType) for i in inputs):
raise NotImplementedError("Only dense tensor types are supported")
if len(inputs) != 5: if len(inputs) != 5:
raise TypeError( raise TypeError(
f"Wrong number of inputs for {self} (expected 5, got {len(inputs)})" f"Wrong number of inputs for {self} (expected 5, got {len(inputs)})"
...@@ -1580,6 +1600,10 @@ class Dot22(GemmRelated): ...@@ -1580,6 +1600,10 @@ class Dot22(GemmRelated):
def make_node(self, x, y): def make_node(self, x, y):
x = at.as_tensor_variable(x) x = at.as_tensor_variable(x)
y = at.as_tensor_variable(y) y = at.as_tensor_variable(y)
if any(not isinstance(i.type, DenseTensorType) for i in (x, y)):
raise NotImplementedError("Only dense tensor types are supported")
dtypes = ("float16", "float32", "float64", "complex64", "complex128") dtypes = ("float16", "float32", "float64", "complex64", "complex128")
if x.type.ndim != 2 or x.type.dtype not in dtypes: if x.type.ndim != 2 or x.type.dtype not in dtypes:
raise TypeError(x) raise TypeError(x)
...@@ -1665,6 +1689,9 @@ def local_dot_to_dot22(fgraph, node): ...@@ -1665,6 +1689,9 @@ def local_dot_to_dot22(fgraph, node):
if not isinstance(node.op, Dot): if not isinstance(node.op, Dot):
return return
if any(not isinstance(i.type, DenseTensorType) for i in node.inputs):
return False
x, y = node.inputs x, y = node.inputs
if y.type.dtype != x.type.dtype: if y.type.dtype != x.type.dtype:
# TODO: upcast one so the types match # TODO: upcast one so the types match
...@@ -1869,6 +1896,10 @@ class Dot22Scalar(GemmRelated): ...@@ -1869,6 +1896,10 @@ class Dot22Scalar(GemmRelated):
check_input = False check_input = False
def make_node(self, x, y, a): def make_node(self, x, y, a):
if any(not isinstance(i.type, DenseTensorType) for i in (x, y, a)):
raise NotImplementedError("Only dense tensor types are supported")
if a.ndim != 0: if a.ndim != 0:
raise TypeError(Gemm.E_scalar, a) raise TypeError(Gemm.E_scalar, a)
if x.ndim != 2: if x.ndim != 2:
...@@ -2089,6 +2120,9 @@ class BatchedDot(COp): ...@@ -2089,6 +2120,9 @@ class BatchedDot(COp):
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = list(map(at.as_tensor_variable, inputs)) inputs = list(map(at.as_tensor_variable, inputs))
if any(not isinstance(i.type, DenseTensorType) for i in inputs):
raise NotImplementedError("Only dense tensor types are supported")
if len(inputs) != 2: if len(inputs) != 2:
raise TypeError(f"Two arguments required, but {len(inputs)} given.") raise TypeError(f"Two arguments required, but {len(inputs)} given.")
if inputs[0].ndim not in (2, 3): if inputs[0].ndim not in (2, 3):
......
...@@ -34,6 +34,7 @@ from aesara.tensor.elemwise import ( ...@@ -34,6 +34,7 @@ from aesara.tensor.elemwise import (
) )
from aesara.tensor.shape import shape from aesara.tensor.shape import shape
from aesara.tensor.type import ( from aesara.tensor.type import (
DenseTensorType,
complex_dtypes, complex_dtypes,
continuous_dtypes, continuous_dtypes,
discrete_dtypes, discrete_dtypes,
...@@ -2076,6 +2077,11 @@ def dense_dot(a, b): ...@@ -2076,6 +2077,11 @@ def dense_dot(a, b):
""" """
a, b = as_tensor_variable(a), as_tensor_variable(b) a, b = as_tensor_variable(a), as_tensor_variable(b)
if not isinstance(a.type, DenseTensorType) or not isinstance(
b.type, DenseTensorType
):
raise TypeError("The dense dot product is only supported for dense types")
if a.ndim == 0 or b.ndim == 0: if a.ndim == 0 or b.ndim == 0:
return a * b return a * b
elif a.ndim > 2 or b.ndim > 2: elif a.ndim > 2 or b.ndim > 2:
......
...@@ -431,7 +431,7 @@ class SpecifyShape(COp): ...@@ -431,7 +431,7 @@ class SpecifyShape(COp):
) )
if isinstance(x.type, TensorType) and all(isinstance(s, Number) for s in shape): if isinstance(x.type, TensorType) and all(isinstance(s, Number) for s in shape):
out_var = TensorType(x.type.dtype, shape)() out_var = x.type.clone(shape=shape)()
else: else:
out_var = x.type() out_var = x.type()
......
import logging import logging
import warnings import warnings
from typing import Iterable, Optional, Union from typing import Iterable, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -9,6 +9,7 @@ from aesara import scalar as aes ...@@ -9,6 +9,7 @@ from aesara import scalar as aes
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Variable from aesara.graph.basic import Variable
from aesara.graph.type import HasDataType from aesara.graph.type import HasDataType
from aesara.graph.utils import MetaType
from aesara.link.c.type import CType from aesara.link.c.type import CType
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.utils import apply_across_args from aesara.utils import apply_across_args
...@@ -50,8 +51,9 @@ dtype_specs_map = { ...@@ -50,8 +51,9 @@ dtype_specs_map = {
class TensorType(CType, HasDataType): class TensorType(CType, HasDataType):
r"""Symbolic `Type` representing `numpy.ndarray`\s.""" r"""Symbolic `Type` representing `numpy.ndarray`\s."""
__props__ = ("dtype", "shape") __props__: Tuple[str, ...] = ("dtype", "shape")
dtype_specs_map = dtype_specs_map
context_name = "cpu" context_name = "cpu"
filter_checks_isfinite = False filter_checks_isfinite = False
""" """
...@@ -271,7 +273,7 @@ class TensorType(CType, HasDataType): ...@@ -271,7 +273,7 @@ class TensorType(CType, HasDataType):
""" """
try: try:
return dtype_specs_map[self.dtype] return self.dtype_specs_map[self.dtype]
except KeyError: except KeyError:
raise TypeError( raise TypeError(
f"Unsupported dtype for {self.__class__.__name__}: {self.dtype}" f"Unsupported dtype for {self.__class__.__name__}: {self.dtype}"
...@@ -613,6 +615,20 @@ class TensorType(CType, HasDataType): ...@@ -613,6 +615,20 @@ class TensorType(CType, HasDataType):
return () return ()
class DenseTypeMeta(MetaType):
def __instancecheck__(self, o):
if type(o) == TensorType or isinstance(o, DenseTypeMeta):
return True
return False
class DenseTensorType(TensorType, metaclass=DenseTypeMeta):
r"""A `Type` for dense tensors.
Instances of this class and `TensorType`\s are considered dense `Type`\s.
"""
def values_eq_approx( def values_eq_approx(
a, b, allow_remove_inf=False, allow_remove_nan=False, rtol=None, atol=None a, b, allow_remove_inf=False, allow_remove_nan=False, rtol=None, atol=None
): ):
......
...@@ -10,6 +10,7 @@ import numpy as np ...@@ -10,6 +10,7 @@ import numpy as np
from aesara import tensor as at from aesara import tensor as at
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable
from aesara.graph.utils import MetaType
from aesara.scalar import ComplexError, IntegerDivisionError from aesara.scalar import ComplexError, IntegerDivisionError
from aesara.tensor import _get_vector_length, as_tensor_variable from aesara.tensor import _get_vector_length, as_tensor_variable
from aesara.tensor.exceptions import AdvancedIndexingError from aesara.tensor.exceptions import AdvancedIndexingError
...@@ -1040,3 +1041,33 @@ class TensorConstant(TensorVariable, Constant): ...@@ -1040,3 +1041,33 @@ class TensorConstant(TensorVariable, Constant):
TensorType.constant_type = TensorConstant TensorType.constant_type = TensorConstant
class DenseVariableMeta(MetaType):
def __instancecheck__(self, o):
if type(o) == TensorVariable or isinstance(o, DenseVariableMeta):
return True
return False
class DenseTensorVariable(TensorType, metaclass=DenseVariableMeta):
r"""A `Variable` for dense tensors.
Instances of this class and `TensorVariable`\s are considered dense
`Variable`\s.
"""
class DenseConstantMeta(MetaType):
def __instancecheck__(self, o):
if type(o) == TensorConstant or isinstance(o, DenseConstantMeta):
return True
return False
class DenseTensorConstant(TensorType, metaclass=DenseConstantMeta):
r"""A `Constant` for dense tensors.
Instances of this class and `TensorConstant`\s are considered dense
`Constant`\s.
"""
...@@ -12,7 +12,7 @@ from aesara.compile.function import function ...@@ -12,7 +12,7 @@ from aesara.compile.function import function
from aesara.compile.io import In, Out from aesara.compile.io import In, Out
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import GradientError from aesara.gradient import GradientError
from aesara.graph.basic import Apply, Constant from aesara.graph.basic import Apply, Constant, applys_between
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.sparse import ( from aesara.sparse import (
...@@ -78,6 +78,7 @@ from aesara.sparse import ( ...@@ -78,6 +78,7 @@ from aesara.sparse import (
true_dot, true_dot,
) )
from aesara.sparse.basic import ( from aesara.sparse.basic import (
SparseConstant,
_is_dense_variable, _is_dense_variable,
_is_sparse, _is_sparse,
_is_sparse_variable, _is_sparse_variable,
...@@ -1017,22 +1018,45 @@ class TestComparison: ...@@ -1017,22 +1018,45 @@ class TestComparison:
class TestConversion: class TestConversion:
@pytest.mark.skip
def test_basic(self): def test_basic(self):
a = at.as_tensor_variable(np.random.random((5))) test_val = np.random.rand(5).astype(config.floatX)
a = at.as_tensor_variable(test_val)
s = csc_from_dense(a) s = csc_from_dense(a)
val = eval_outputs([s]) val = eval_outputs([s])
assert str(val.dtype) == "float64" assert str(val.dtype) == config.floatX
assert val.format == "csc" assert val.format == "csc"
@pytest.mark.skip a = at.as_tensor_variable(test_val)
def test_basic_1(self):
a = at.as_tensor_variable(np.random.random((5)))
s = csr_from_dense(a) s = csr_from_dense(a)
val = eval_outputs([s]) val = eval_outputs([s])
assert str(val.dtype) == "float64" assert str(val.dtype) == config.floatX
assert val.format == "csr" assert val.format == "csr"
test_val = np.eye(3).astype(config.floatX)
a = sp.sparse.csr_matrix(test_val)
s = as_sparse_or_tensor_variable(a)
res = at.as_tensor_variable(s)
assert isinstance(res, SparseConstant)
a = sp.sparse.csr_matrix(test_val)
s = as_sparse_or_tensor_variable(a)
from aesara.tensor.exceptions import NotScalarConstantError
with pytest.raises(NotScalarConstantError):
at.get_scalar_constant_value(s, only_process_constants=True)
# TODO:
# def test_sparse_as_tensor_variable(self):
# csr = sp.sparse.csr_matrix(np.eye(3))
# val = aet.as_tensor_variable(csr)
# assert str(val.dtype) == config.floatX
# assert val.format == "csr"
#
# csr = sp.sparse.csc_matrix(np.eye(3))
# val = aet.as_tensor_variable(csr)
# assert str(val.dtype) == config.floatX
# assert val.format == "csc"
def test_dense_from_sparse(self): def test_dense_from_sparse(self):
# call dense_from_sparse # call dense_from_sparse
for t in _mtypes: for t in _mtypes:
...@@ -1591,6 +1615,32 @@ class TestDots(utt.InferShapeTester): ...@@ -1591,6 +1615,32 @@ class TestDots(utt.InferShapeTester):
) )
f(i, a) f(i, a)
def test_tensor_dot_types(self):
x = sparse.csc_matrix("x")
x_d = at.matrix("x_d")
y = sparse.csc_matrix("y")
res = at.dot(x, y)
op_types = set(type(n.op) for n in applys_between([x, y], [res]))
assert sparse.basic.StructuredDot in op_types
assert at.math.Dot not in op_types
res = at.dot(x_d, y)
op_types = set(type(n.op) for n in applys_between([x, y], [res]))
assert sparse.basic.StructuredDot in op_types
assert at.math.Dot not in op_types
res = at.dot(x, x_d)
op_types = set(type(n.op) for n in applys_between([x, y], [res]))
assert sparse.basic.StructuredDot in op_types
assert at.math.Dot not in op_types
res = at.dot(at.second(1, x), y)
op_types = set(type(n.op) for n in applys_between([x, y], [res]))
assert sparse.basic.StructuredDot in op_types
assert at.math.Dot not in op_types
def test_csr_dense_grad(self): def test_csr_dense_grad(self):
# shortcut: testing csc in float32, testing csr in float64 # shortcut: testing csc in float32, testing csr in float64
......
import pytest
from aesara.sparse import matrix as sp_matrix
from aesara.sparse.type import SparseType from aesara.sparse.type import SparseType
from aesara.tensor import dmatrix
def test_clone(): def test_clone():
st = SparseType("csr", "float64") st = SparseType("csr", "float64")
assert st == st.clone() assert st == st.clone()
def test_Sparse_convert_variable():
x = dmatrix(name="x")
y = sp_matrix("csc", dtype="float64", name="y")
z = sp_matrix("csr", dtype="float64", name="z")
assert y.type.convert_variable(z) is None
# TODO FIXME: This is a questionable result, because `x.type` is associated
# with a dense `Type`, but, since `TensorType` is a base class of `Sparse`,
# we would need to added sparse/dense logic to `TensorType`, and we don't
# want to do that.
assert x.type.convert_variable(y) is y
# TODO FIXME: We should be able to do this.
with pytest.raises(NotImplementedError):
y.type.convert_variable(x)
...@@ -6,6 +6,7 @@ import aesara ...@@ -6,6 +6,7 @@ import aesara
import tests.unittest_tools as utt import tests.unittest_tools as utt
from aesara.graph.basic import Constant, equal_computations from aesara.graph.basic import Constant, equal_computations
from aesara.tensor import get_vector_length from aesara.tensor import get_vector_length
from aesara.tensor.basic import constant
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.math import dot from aesara.tensor.math import dot
from aesara.tensor.subtensor import AdvancedSubtensor, Subtensor from aesara.tensor.subtensor import AdvancedSubtensor, Subtensor
...@@ -21,7 +22,12 @@ from aesara.tensor.type import ( ...@@ -21,7 +22,12 @@ from aesara.tensor.type import (
tensor3, tensor3,
) )
from aesara.tensor.type_other import MakeSlice from aesara.tensor.type_other import MakeSlice
from aesara.tensor.var import TensorConstant, TensorVariable from aesara.tensor.var import (
DenseTensorConstant,
DenseTensorVariable,
TensorConstant,
TensorVariable,
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -247,3 +253,14 @@ def test_get_vector_length(): ...@@ -247,3 +253,14 @@ def test_get_vector_length():
x = TensorVariable(TensorType("int64", (None,))) x = TensorVariable(TensorType("int64", (None,)))
with pytest.raises(ValueError): with pytest.raises(ValueError):
get_vector_length(x) get_vector_length(x)
def test_dense_types():
x = matrix()
assert isinstance(x, DenseTensorVariable)
assert not isinstance(x, DenseTensorConstant)
x = constant(1)
assert not isinstance(x, DenseTensorVariable)
assert isinstance(x, DenseTensorConstant)
...@@ -335,13 +335,13 @@ class TestIfelse(utt.OptimizationTestMixin): ...@@ -335,13 +335,13 @@ class TestIfelse(utt.OptimizationTestMixin):
z = aesara.sparse.matrix("csr", dtype=self.dtype, name="z") z = aesara.sparse.matrix("csr", dtype=self.dtype, name="z")
cond = iscalar("cond") cond = iscalar("cond")
with pytest.raises(TypeError): with pytest.raises(NotImplementedError):
ifelse(cond, x, y) ifelse(cond, x, y)
with pytest.raises(TypeError): with pytest.raises(NotImplementedError):
ifelse(cond, y, x) ifelse(cond, y, x)
with pytest.raises(TypeError): with pytest.raises(NotImplementedError):
ifelse(cond, x, z) ifelse(cond, x, z)
with pytest.raises(TypeError): with pytest.raises(NotImplementedError):
ifelse(cond, z, x) ifelse(cond, z, x)
with pytest.raises(TypeError): with pytest.raises(TypeError):
ifelse(cond, y, z) ifelse(cond, y, z)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论