Unverified commit 207b0c6e authored by Tomás Capretto, committed by GitHub

Implement sparse dot product in numba backend (#1854)

* Add comment regarding meaning of '*' operator with sparse matrices * Implement sparse dot product in numba * Overload T attribute in numba sparse matrices * Implement toarray * Add tocsr
Parent 51b3885f
......@@ -10,7 +10,7 @@ from pytensor.link.numba.dispatch.basic import (
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.link.numba.dispatch.sparse.variable import CSMatrixType
from pytensor.sparse import CSM, Cast, CSMProperties
from pytensor.sparse import CSM, Cast, CSMProperties, DenseFromSparse, Transpose
@overload(numba_deepcopy)
......@@ -66,3 +66,21 @@ def numba_funcify_Cast(op, node, **kwargs):
return x.astype(out_dtype)
return cast
@register_funcify_default_op_cache_key(Transpose)
def numba_funcify_Transpose(op, node, **kwargs):
    """Numba implementation of sparse ``Transpose``.

    Delegates to the ``.T`` attribute overloaded on the numba CSR/CSC
    matrix types.
    """

    @numba_basic.numba_njit
    def transpose(x):
        return x.T

    return transpose
@register_funcify_default_op_cache_key(DenseFromSparse)
def numba_funcify_DenseFromSparse(op, node, **kwargs):
    """Numba implementation of ``DenseFromSparse``.

    Delegates to the ``toarray`` method overloaded on the numba CSR/CSC
    matrix types.
    """

    @numba_basic.numba_njit
    def to_array(x):
        return x.toarray()

    return to_array
from hashlib import sha256
import numpy as np
import scipy.sparse as sp
import pytensor.sparse.basic as psb
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key
from pytensor.sparse import SparseDenseMultiply, SparseDenseVectorMultiply
from pytensor.link.numba.dispatch.basic import (
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
)
from pytensor.sparse import (
Dot,
SparseDenseMultiply,
SparseDenseVectorMultiply,
StructuredDot,
)
@register_funcify_default_op_cache_key(SparseDenseMultiply)
......@@ -88,3 +102,271 @@ def numba_funcify_SparseDenseMultiply(op, node, **kwargs):
return z
return sparse_dense_multiply
@register_funcify_and_cache_key(Dot)
@register_funcify_and_cache_key(StructuredDot)
def numba_funcify_SparseDot(op, node, **kwargs):
    """Generate a numba implementation for sparse ``Dot``/``StructuredDot``.

    Inputs can be of types: (sparse, dense), (dense, sparse), (sparse, sparse).
    ``Dot`` always returns a dense result.
    ``StructuredDot`` returns a sparse object when all entries are sparse,
    otherwise dense.
    """
    x, y = node.inputs
    [z] = node.outputs
    out_dtype = z.type.dtype

    x_is_sparse = psb._is_sparse_variable(x)
    y_is_sparse = psb._is_sparse_variable(y)
    z_is_sparse = psb._is_sparse_variable(z)
    x_format = x.type.format if x_is_sparse else None
    y_format = y.type.format if y_is_sparse else None

    # The compiled function is specialized on every static property that
    # selects a code path below. `x.type.ndim` participates in dispatch
    # (`x_is_1d` chooses the spmv path over spmdm in the dense-sparse case),
    # so it must be part of the cache key to avoid cache collisions between
    # vector and matrix `x` inputs.
    cache_key = sha256(
        str(
            (
                type(op),
                x_format,
                y_format,
                z_is_sparse,
                x.type.ndim,
                y.type.ndim,
                y.type.broadcastable,
            )
        ).encode()
    ).hexdigest()

    if x_is_sparse and y_is_sparse:
        # General spmspm algorithm in CSR format (two-pass SMMP, as in SciPy's
        # csr_matmat): pass 1 counts the output nnz per row, pass 2 computes
        # values using a singly linked list of active columns threaded through
        # `mask` (head sentinel -2, empty slot -1).
        @numba_basic.numba_njit
        def _spmspm(n_row, n_col, x_ptr, x_ind, x_data, y_ptr, y_ind, y_data):
            # Pass 1: count the number of nonzeros in the output.
            x_ind = x_ind.view(np.uint32)
            y_ind = y_ind.view(np.uint32)
            x_ptr = x_ptr.view(np.uint32)
            y_ptr = y_ptr.view(np.uint32)

            output_nnz = 0
            mask = np.full(n_col, -1, dtype=np.int32)
            for i in range(n_row):
                row_nnz = 0
                for jj in range(x_ptr[i], x_ptr[i + 1]):
                    j = x_ind[jj]
                    for kk in range(y_ptr[j], y_ptr[j + 1]):
                        k = y_ind[kk]
                        if mask[k] != i:
                            mask[k] = i
                            row_nnz += 1
                output_nnz += row_nnz

            # Pass 2: fill indices/data row by row.
            z_ptr = np.empty(n_row + 1, dtype=np.uint32)
            z_ind = np.empty(output_nnz, dtype=np.uint32)
            z_data = np.empty(output_nnz, dtype=out_dtype)

            # Refill original mask for reuse
            mask.fill(-1)
            sums = np.zeros(n_col, dtype=out_dtype)
            nnz = 0
            z_ptr[0] = 0
            for i in range(n_row):
                head = -2
                length = 0
                for jj in range(x_ptr[i], x_ptr[i + 1]):
                    j = x_ind[jj]
                    v = x_data[jj]
                    for kk in range(y_ptr[j], y_ptr[j + 1]):
                        k = y_ind[kk]
                        sums[k] += v * y_data[kk]
                        if mask[k] == -1:
                            # Push column k on the linked list of this row.
                            mask[k] = head
                            head = k
                            length += 1
                # Walk the linked list, emitting nonzero sums and resetting
                # the scratch arrays for the next row.
                for _ in range(length):
                    if sums[head] != 0:
                        z_ind[nnz] = head
                        z_data[nnz] = sums[head]
                        nnz += 1
                    temp = head
                    head = mask[head]
                    mask[temp] = -1
                    sums[temp] = 0
                z_ptr[i + 1] = nnz
            return z_ptr.view(np.int32), z_ind.view(np.int32), z_data

        @numba_basic.numba_njit
        def spmspm(x, y):
            # The kernel works on CSR; convert CSC inputs first.
            if x_format != "csr":
                x = x.tocsr()
            if y_format != "csr":
                y = y.tocsr()
            x_ptr, x_ind, x_data = x.indptr, x.indices, x.data
            y_ptr, y_ind, y_data = y.indptr, y.indices, y.data
            n_row, n_col = x.shape[0], y.shape[1]
            z_ptr, z_ind, z_data = _spmspm(
                n_row, n_col, x_ptr, x_ind, x_data, y_ptr, y_ind, y_data
            )
            output = sp.csr_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col))
            # Dot returns a dense result even in spMspM
            if not z_is_sparse:
                return output.toarray()
            # StructuredDot returns in the format of 'x'
            if x_format == "csc":
                return output.tocsc()
            return output

        return spmspm, cache_key

    # Only one of 'x' or 'y' is sparse, not both.
    # Before using a general dot(sparse-matrix, dense-matrix) algorithm,
    # we check if we can rely on the less intensive (sparse-matrix, dense-vector)
    # algorithm (spmv).
    y_is_1d_like = y.type.ndim == 1 or (y.type.ndim == 2 and y.type.shape[1] == 1)
    x_is_1d = x.type.ndim == 1
    if (x_is_sparse and y_is_1d_like) or (y_is_sparse and x_is_1d):
        # We can use spmv

        @numba_basic.numba_njit
        def _spmdv_csr(x_ptr, x_ind, x_data, x_shape, y):
            # CSR @ vector: accumulate each row's dot product independently.
            n_row = x_shape[0]
            x_ptr = x_ptr.view(np.uint32)
            x_ind = x_ind.view(np.uint32)
            output = np.zeros(n_row, dtype=out_dtype)
            for row_idx in range(n_row):
                acc = 0.0
                for k in range(x_ptr[row_idx], x_ptr[row_idx + 1]):
                    acc += x_data[k] * y[x_ind[k]]
                output[row_idx] = acc
            return output

        @numba_basic.numba_njit
        def _spmdv_csc(x_ptr, x_ind, x_data, x_shape, y):
            # CSC @ vector: scatter each column scaled by y[col] into output.
            n_row, n_col = x_shape
            x_ptr = x_ptr.view(np.uint32)
            x_ind = x_ind.view(np.uint32)
            output = np.zeros(n_row, dtype=out_dtype)
            for col_idx in range(n_col):
                yj = y[col_idx]
                for k in range(x_ptr[col_idx], x_ptr[col_idx + 1]):
                    output[x_ind[k]] += x_data[k] * yj
            return output

        if x_is_sparse:
            if x_format == "csr":
                _spmdv = _spmdv_csr
            else:
                _spmdv = _spmdv_csc

            if y.type.ndim == 1:

                @numba_basic.numba_njit
                def spmdv(x, y):
                    assert x.shape[1] == y.shape[0]
                    return _spmdv(x.indptr, x.indices, x.data, x.shape, y)

            else:

                @numba_basic.numba_njit
                def spmdv(x, y):
                    # Output must be 2d.
                    assert x.shape[1] == y.shape[0]
                    return _spmdv(x.indptr, x.indices, x.data, x.shape, y[:, 0])[
                        :, None
                    ]

            return spmdv, cache_key
        else:  # y_is_sparse
            # Rely on: z = dot(x, y) -> z^T = dot(x, y)^T -> z^T = dot(y^T, x^T)
            if y_format == "csr":
                _spmdv = _spmdv_csc
            else:  # csc
                _spmdv = _spmdv_csr

            @numba_basic.numba_njit
            def spmdv(x, y):
                # SciPy treats (p, ) * (p, k) as (1, p) @ (p, k),
                # but returns the result as of shape (k, ).
                assert x.shape[0] == y.shape[0]
                yT = y.T  # (k, p)
                return _spmdv(yT.indptr, yT.indices, yT.data, yT.shape, x)

            return spmdv, cache_key

    # Only one of 'x' or 'y' is sparse, and we can't use spmdv.
    # We know we have to rely on the general (sparse-matrix, dense-matrix)
    # dot product (spmdm).
    @numba_basic.numba_njit
    def spmdm_csr(x, y):
        # For each stored x[i, j], add x[i, j] * y[j, :] into row i of z.
        assert x.shape[1] == y.shape[0]
        n = x.shape[0]
        k = y.shape[1]
        z = np.zeros((n, k), dtype=out_dtype)
        x_ind = x.indices.view(np.uint32)
        x_ptr = x.indptr.view(np.uint32)
        x_data = x.data
        for row_idx in range(n):
            for idx in range(x_ptr[row_idx], x_ptr[row_idx + 1]):
                col_idx = x_ind[idx]
                value = x_data[idx]
                z[row_idx] += value * y[col_idx]
        return z

    @numba_basic.numba_njit
    def spmdm_csc(x, y):
        # Same as spmdm_csr but iterating columns of the CSC input.
        assert x.shape[1] == y.shape[0]
        k = y.shape[1]
        n = x.shape[0]
        p = x.shape[1]
        z = np.zeros((n, k), dtype=out_dtype)
        x_ind = x.indices.view(np.uint32)
        x_ptr = x.indptr.view(np.uint32)
        x_data = x.data
        for col_idx in range(p):
            for idx in range(x_ptr[col_idx], x_ptr[col_idx + 1]):
                row_idx = x_ind[idx]
                value = x_data[idx]
                z[row_idx] += value * y[col_idx]
        return z

    if x_is_sparse:
        if x_format == "csr":
            return spmdm_csr, cache_key
        else:
            return spmdm_csc, cache_key

    if y_is_sparse:
        # We don't implement a dense-sparse dot product.
        # Instead, we use properties of transpose:
        # z = dot(x, y) -> z^T = dot(x, y)^T -> z^T = dot(y^T, x^T)
        # which allows us to reuse sparse-dense dot.
        if y_format == "csr":
            # y.T will be CSC
            @numba_basic.numba_njit
            def dmspm(x, y):
                return spmdm_csc(y.T, x.T).T

        else:
            # y.T will be CSR
            @numba_basic.numba_njit
            def dmspm(x, y):
                return spmdm_csr(y.T, x.T).T

        return dmspm, cache_key
......@@ -225,6 +225,28 @@ def overload_sparse_ndim(matrix):
return lambda matrix: 2
@overload_attribute(CSMatrixType, "T")
def overload_sparse_T(matrix):
    """Transpose a CSR/CSC matrix by reinterpreting its components.

    The CSR representation of A is exactly the CSC representation of A.T
    (and vice versa), so transposition only swaps the shape and the format.
    """
    if isinstance(matrix, CSRMatrixType):
        builder = csc_matrix_from_components
    elif isinstance(matrix, CSCMatrixType):
        builder = csr_matrix_from_components
    else:
        return None

    def transpose(matrix):
        n_row, n_col = matrix.shape
        # Components are copied so the transpose does not alias the input.
        return builder(
            matrix.data.copy(),
            matrix.indices.copy(),
            matrix.indptr.copy(),
            (n_col, n_row),
        )

    return transpose
@overload_method(CSMatrixType, "copy")
def overload_sparse_copy(matrix):
match matrix:
......@@ -265,3 +287,134 @@ def overload_sparse_astype(matrix, dtype):
)
return astype
@overload_method(CSCMatrixType, "tocsr")
def overload_tocsr(matrix):
    # Convert a CSC matrix to CSR (same algorithm as SciPy's csc_tocsr):
    # 1) count entries per row, 2) exclusive prefix-sum into row pointers,
    # 3) scatter column indices/data using the pointers as write cursors,
    # 4) shift the pointer array back by one position to undo the cursor
    #    advancement.
    def to_csr(matrix):
        n_row, n_col = matrix.shape
        csc_ptr = matrix.indptr.view(np.uint32)
        csc_ind = matrix.indices.view(np.uint32)
        csc_data = matrix.data
        nnz = csc_ptr[n_col]

        csr_ptr = np.empty(n_row + 1, dtype=np.uint32)
        csr_ind = np.empty(nnz, dtype=np.uint32)
        csr_data = np.empty(nnz, dtype=matrix.data.dtype)

        # Step 1: histogram of nonzeros per row.
        csr_ptr[:n_row] = 0
        for n in range(nnz):
            csr_ptr[csc_ind[n]] += 1

        # Step 2: exclusive prefix sum turns counts into row start offsets.
        cumsum = 0
        for row in range(n_row):
            temp = csr_ptr[row]
            csr_ptr[row] = cumsum
            cumsum += temp
        csr_ptr[n_row] = nnz

        # Step 3: scatter; csr_ptr[row] is advanced as each slot is filled.
        for col_idx in range(n_col):
            for jj in range(csc_ptr[col_idx], csc_ptr[col_idx + 1]):
                row_idx = csc_ind[jj]
                dest = csr_ptr[row_idx]

                csr_ind[dest] = col_idx
                csr_data[dest] = csc_data[jj]

                csr_ptr[row_idx] += 1

        # Step 4: each pointer now holds the *end* of its row; shift right
        # by one to restore start offsets.
        last = 0
        for row_idx in range(n_row + 1):
            temp = csr_ptr[row_idx]
            csr_ptr[row_idx] = last
            last = temp

        return csr_matrix_from_components(
            csr_data, csr_ind.view(np.int32), csr_ptr.view(np.int32), matrix.shape
        )

    return to_csr
@overload_method(CSRMatrixType, "tocsc")
def overload_tocsc(matrix):
    # Convert a CSR matrix to CSC (mirror of overload_tocsr; same algorithm
    # as SciPy's csr_tocsc): count per column, exclusive prefix-sum, scatter,
    # then shift the pointer array back by one position.
    def to_csc(matrix):
        n_row, n_col = matrix.shape
        csr_ptr = matrix.indptr.view(np.uint32)
        csr_ind = matrix.indices.view(np.uint32)
        csr_data = matrix.data
        nnz = csr_ptr[n_row]

        csc_ptr = np.empty(n_col + 1, dtype=np.uint32)
        csc_ind = np.empty(nnz, dtype=np.uint32)
        csc_data = np.empty(nnz, dtype=matrix.data.dtype)

        # Histogram of nonzeros per column.
        csc_ptr[:n_col] = 0
        for n in range(nnz):
            csc_ptr[csr_ind[n]] += 1

        # Exclusive prefix sum turns counts into column start offsets.
        cumsum = 0
        for col in range(n_col):
            temp = csc_ptr[col]
            csc_ptr[col] = cumsum
            cumsum += temp
        csc_ptr[n_col] = nnz

        # Scatter; csc_ptr[col] is advanced as each slot is filled.
        for row in range(n_row):
            for jj in range(csr_ptr[row], csr_ptr[row + 1]):
                col = csr_ind[jj]
                dest = csc_ptr[col]

                csc_ind[dest] = row
                csc_data[dest] = csr_data[jj]

                csc_ptr[col] += 1

        # Pointers now hold column *ends*; shift right by one to restore
        # start offsets.
        last = 0
        for col in range(n_col + 1):
            temp = csc_ptr[col]
            csc_ptr[col] = last
            last = temp

        return csc_matrix_from_components(
            csc_data, csc_ind.view(np.int32), csc_ptr.view(np.int32), matrix.shape
        )

    return to_csc
@overload_method(CSMatrixType, "toarray")
def overload_toarray(matrix):
    """Densify a CSR/CSC matrix into a freshly allocated 2d array."""
    if isinstance(matrix, CSRMatrixType):

        def to_array(matrix):
            ptr = matrix.indptr.view(np.uint32)
            cols = matrix.indices.view(np.uint32)
            dense = np.zeros(matrix.shape, dtype=matrix.data.dtype)
            # Walk each row's slice of the data array.
            for i in range(matrix.shape[0]):
                for k in range(ptr[i], ptr[i + 1]):
                    dense[i, cols[k]] = matrix.data[k]
            return dense

        return to_array

    if isinstance(matrix, CSCMatrixType):

        def to_array(matrix):
            ptr = matrix.indptr.view(np.uint32)
            rows = matrix.indices.view(np.uint32)
            dense = np.zeros(matrix.shape, dtype=matrix.data.dtype)
            # Walk each column's slice of the data array.
            for j in range(matrix.shape[1]):
                for k in range(ptr[j], ptr[j + 1]):
                    dense[rows[k], j] = matrix.data[k]
            return dense

        return to_array

    return None
......@@ -1365,6 +1365,8 @@ class StructuredDot(Op):
"shape mismatch in StructuredDot.perform", (a.shape, b.shape)
)
# Multiplication of objects of `*_matrix` type means dot product
# The result can be sparse or dense, depending on the inputs.
variable = a * b
if isinstance(node.outputs[0].type, SparseTensorType):
assert psb._is_sparse(variable)
......@@ -1902,6 +1904,7 @@ class Dot(Op):
if not x_is_sparse and not y_is_sparse:
raise TypeError(x)
# Multiplication of objects of `*_matrix` type means dot product
rval = x * y
if x_is_sparse and y_is_sparse:
......
......@@ -30,6 +30,12 @@ def sparse_assert_fn(a, b):
a_is_sparse = sp.sparse.issparse(a)
assert a_is_sparse == sp.sparse.issparse(b)
if a_is_sparse:
# Attributes can be compared only if both matrices have sorted indices
if not a.has_sorted_indices:
a = a.sorted_indices()
if not b.has_sorted_indices:
b = b.sorted_indices()
assert a.format == b.format
assert a.dtype == b.dtype
assert a.shape == b.shape
......@@ -262,3 +268,11 @@ def test_sparse_deepcopy(format):
x = ps.matrix(shape=(3, 3), format=format)
x_test = sp.sparse.random(3, 3, density=0.5, format=format)
compare_numba_and_py_sparse([x], [x], [x_test])
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_dense_from_sparse(format):
    """DenseFromSparse densifies both CSR and CSC inputs."""
    x = ps.matrix(shape=(5, 3), format=format)
    y = ps.dense_from_sparse(x)
    x_test = sp.sparse.random(5, 3, density=0.5, format=format)
    compare_numba_and_py_sparse([x], y, [x_test])
......@@ -9,6 +9,8 @@ from tests.link.numba.sparse.test_basic import compare_numba_and_py_sparse
pytestmark = pytest.mark.filterwarnings("error")
DOT_SHAPES = [((20, 11), (11, 4)), ((10, 3), (3, 1)), ((1, 10), (10, 5))]
@pytest.mark.parametrize("format", ["csr", "csc"])
@pytest.mark.parametrize("y_ndim", [0, 1, 2])
......@@ -26,3 +28,70 @@ def test_sparse_dense_multiply(y_ndim, format):
z,
[x_test, y_test],
)
@pytest.mark.parametrize("op", [ps.dot, ps.structured_dot])
@pytest.mark.parametrize("sp_format", ["csr", "csc"])
@pytest.mark.parametrize("x_shape, y_shape", DOT_SHAPES)
def test_dot_sparse_dense(op, sp_format, x_shape, y_shape):
    """Sparse @ dense matrix product, for both Dot and StructuredDot."""
    x = ps.matrix(format=sp_format, name="x", shape=x_shape)
    y = pt.matrix("y", shape=y_shape)
    z = op(x, y)

    # Deterministic seed derived from the parametrization.
    seed = sum(map(ord, sp_format)) + sum(x_shape) + sum(y_shape)
    rng = np.random.default_rng(seed)
    x_test = scipy.sparse.random(
        *x_shape, density=0.5, format=sp_format, random_state=rng
    )
    y_test = rng.normal(size=y_shape)
    compare_numba_and_py_sparse([x, y], z, [x_test, y_test])
@pytest.mark.parametrize("op", [ps.dot, ps.structured_dot])
@pytest.mark.parametrize("sp_format", ["csr", "csc"])
@pytest.mark.parametrize("x_shape, y_shape", DOT_SHAPES)
def test_dot_dense_sparse(op, sp_format, x_shape, y_shape):
    """Dense @ sparse matrix product, for both Dot and StructuredDot."""
    x = pt.matrix(name="x", shape=x_shape)
    y = ps.matrix(format=sp_format, name="y", shape=y_shape)
    z = op(x, y)

    # Deterministic seed derived from the parametrization.
    seed = sum(map(ord, sp_format)) + sum(x_shape) + sum(y_shape)
    rng = np.random.default_rng(seed)
    # NOTE: the dense value is drawn before the sparse one, matching the
    # original rng consumption order.
    x_test = rng.normal(size=x_shape)
    y_test = scipy.sparse.random(
        *y_shape, density=0.5, format=sp_format, random_state=rng
    )
    compare_numba_and_py_sparse([x, y], z, [x_test, y_test])
@pytest.mark.parametrize("op", [ps.dot, ps.structured_dot])
@pytest.mark.parametrize("x_format", ["csr", "csc"])
@pytest.mark.parametrize("y_format", ["csr", "csc"])
@pytest.mark.parametrize("x_shape, y_shape", DOT_SHAPES)
def test_sparse_dot_sparse_sparse(op, x_format, y_format, x_shape, y_shape):
    """Sparse @ sparse product across all CSR/CSC format combinations."""
    x = ps.matrix(x_format, name="x", shape=x_shape)
    y = ps.matrix(y_format, name="y", shape=y_shape)
    z = op(x, y)

    # Deterministic seed derived from the two formats.
    seed = sum(map(ord, x_format)) + sum(map(ord, y_format))
    rng = np.random.default_rng(seed)
    x_test = scipy.sparse.random(
        *x_shape, density=0.5, format=x_format, random_state=rng
    )
    y_test = scipy.sparse.random(
        *y_shape, density=0.5, format=y_format, random_state=rng
    )
    compare_numba_and_py_sparse([x, y], z, [x_test, y_test])
@pytest.mark.parametrize("sp_format", ["csr", "csc"])
def test_sparse_spmv(sp_format):
    """Sparse matrix @ dense vector (spmv) code path."""
    sp_mat = ps.matrix(format=sp_format, name="x", shape=(20, 6))
    vec = pt.vector("y", shape=(6,))
    out = ps.dot(sp_mat, vec)

    # Seed is 1 for "csr" and 0 for "csc" (bool seed, matching the original).
    rng = np.random.default_rng(sp_format == "csr")
    sp_val = scipy.sparse.random(20, 6, density=0.5, format=sp_format, random_state=rng)
    vec_val = rng.normal(size=(6,))
    compare_numba_and_py_sparse([sp_mat, vec], out, [sp_val, vec_val])
Markdown is supported
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment