提交 f0603aac authored 作者: Tomas Capretto's avatar Tomas Capretto 提交者: Ricardo Vieira

Implement sparse HStack and VStack in numba backend

上级 c55a97bb
import numpy as np
import scipy as sp
from numba import literal_unroll
from numba.extending import overload
from pytensor import config
......@@ -15,8 +16,10 @@ from pytensor.sparse import (
Cast,
CSMProperties,
DenseFromSparse,
HStack,
SparseFromDense,
Transpose,
VStack,
)
......@@ -109,3 +112,129 @@ def numba_funcify_SparseFromDense(op, node, **kwargs):
return sp.sparse.csc_matrix(matrix)
return dense_to_csc
@register_funcify_default_op_cache_key(HStack)
def numba_funcify_HStack(op, node, **kwargs):
output_format = op.format
out_dtype = np.dtype(op.dtype)
@numba_basic.numba_njit
def hstack_csc(*blocks):
n_rows = blocks[0].shape[0]
total_n_cols = 0
total_nnz = 0
blocks_csc = []
for block in literal_unroll(blocks):
if block.shape[0] != n_rows:
raise ValueError("Mismatching dimensions along axis 0")
# `hstack` operates on CSC inputs, so we convert each block to CSC.
# This allocates memory for CSR inputs, but not for inputs already in CSC format.
block_csc = block.tocsc()
blocks_csc.append(block_csc)
# Count number of columns and non-zeros for the output matrix.
total_nnz += block_csc.indptr[block_csc.shape[1]]
total_n_cols += block_csc.shape[1]
data = np.empty(total_nnz, dtype=out_dtype)
indices = np.empty(total_nnz, dtype=np.int32)
indptr = np.empty(total_n_cols + 1, dtype=np.int32)
indptr[0] = 0
# Append each CSC block into the preallocated output by
# tracking global offsets for columns (`col_offset`) and nonzeros (`nnz_offset`).
col_offset = 0
nnz_offset = 0
for block in blocks_csc:
block_n_cols = block.shape[1]
block_nnz = block.indptr[block_n_cols]
data[nnz_offset : nnz_offset + block_nnz] = block.data
indices[nnz_offset : nnz_offset + block_nnz] = block.indices
for col_idx in range(block_n_cols):
indptr[col_offset + col_idx + 1] = (
nnz_offset + block.indptr[col_idx + 1]
)
nnz_offset += block_nnz
col_offset += block_n_cols
return sp.sparse.csc_matrix(
(data, indices, indptr), shape=(n_rows, total_n_cols)
)
if output_format == "csc":
return hstack_csc
@numba_basic.numba_njit
def hstack_csr(*blocks):
return hstack_csc(*blocks).tocsr()
return hstack_csr
@register_funcify_default_op_cache_key(VStack)
def numba_funcify_VStack(op, node, **kwargs):
output_format = op.format
out_dtype = np.dtype(op.dtype)
@numba_basic.numba_njit
def vstack_csr(*blocks):
n_cols = blocks[0].shape[1]
total_n_rows = 0
total_nnz = 0
blocks_csr = []
for block in literal_unroll(blocks):
if block.shape[1] != n_cols:
raise ValueError("Mismatching dimensions along axis 1")
# `vstack` operates on CSR inputs, so we convert each block to CSR.
# This allocates memory for CSC inputs, but not for inputs already in CSR format.
block_csr = block.tocsr()
blocks_csr.append(block_csr)
# Count number of rows and non-zeros for the output matrix.
total_nnz += block_csr.indptr[block_csr.shape[0]]
total_n_rows += block_csr.shape[0]
data = np.empty(total_nnz, dtype=out_dtype)
indices = np.empty(total_nnz, dtype=np.int32)
indptr = np.empty(total_n_rows + 1, dtype=np.int32)
indptr[0] = 0
# Append each CSR block into the preallocated output by
# tracking global offsets for rows (`row_offset`) and nonzeros (`nnz_offset`).
row_offset = 0
nnz_offset = 0
for block in blocks_csr:
block_n_rows = block.shape[0]
block_nnz = block.indptr[block_n_rows]
data[nnz_offset : nnz_offset + block_nnz] = block.data
indices[nnz_offset : nnz_offset + block_nnz] = block.indices
for row_idx in range(block_n_rows):
indptr[row_offset + row_idx + 1] = (
nnz_offset + block.indptr[row_idx + 1]
)
nnz_offset += block_nnz
row_offset += block_n_rows
return sp.sparse.csr_matrix(
(data, indices, indptr), shape=(total_n_rows, n_cols)
)
if output_format == "csr":
return vstack_csr
@numba_basic.numba_njit
def vstack_csc(*blocks):
return vstack_csr(*blocks).tocsc()
return vstack_csc
......@@ -3,7 +3,6 @@ from sys import getrefcount
import numpy as np
import pytest
import scipy
import scipy as sp
import pytensor.sparse as ps
......@@ -18,7 +17,7 @@ numba = pytest.importorskip("numba")
# Make sure the Numba customizations are loaded
import pytensor.link.numba.dispatch.sparse # noqa: F401
from pytensor import config
from pytensor import config, function
from pytensor.sparse import SparseTensorType
from tests.link.numba.test_basic import compare_numba_and_py
......@@ -75,9 +74,9 @@ def test_sparse_boxing():
def test_sparse_creation_refcount():
@numba.njit
def create_csr_matrix(data, indices, ind_ptr):
return scipy.sparse.csr_matrix((data, indices, ind_ptr), shape=(5, 5))
return sp.sparse.csr_matrix((data, indices, ind_ptr), shape=(5, 5))
x = scipy.sparse.random(5, 5, density=0.5, format="csr")
x = sp.sparse.random(5, 5, density=0.5, format="csr")
x_data = x.data
x_indptr = x.indptr
......@@ -102,7 +101,7 @@ def test_sparse_passthrough_refcount():
def identity(a):
return a
x = scipy.sparse.random(5, 5, density=0.5, format="csr")
x = sp.sparse.random(5, 5, density=0.5, format="csr")
x_data = x.data
assert getrefcount(x_data) == 3
......@@ -287,7 +286,7 @@ def test_sparse_conversion():
def to_csc(matrix):
return matrix.tocsc()
x_csr = scipy.sparse.random(5, 5, density=0.5, format="csr")
x_csr = sp.sparse.random(5, 5, density=0.5, format="csr")
x_csc = x_csr.tocsc()
x_dense = x_csr.todense()
......@@ -313,3 +312,123 @@ def test_sparse_from_dense(format):
y = ps.csc_from_dense(x)
compare_numba_and_py_sparse([x], y, [x_test])
@pytest.mark.parametrize("output_format", ("csr", "csc"))
@pytest.mark.parametrize(
"input_formats",
(
("csr", "csr", "csr"),
("csc", "csc", "csc"),
("csr", "csc", "csr"),
("csc", "csr", "csc"),
("csc", "csc", "csr"),
),
)
def test_sparse_hstack(output_format, input_formats):
x1 = ps.matrix(
name="x1", shape=(7, 2), format=input_formats[0], dtype=config.floatX
)
x2 = ps.matrix(
name="x2", shape=(7, 1), format=input_formats[1], dtype=config.floatX
)
x3 = ps.matrix(
name="x3", shape=(7, 5), format=input_formats[2], dtype=config.floatX
)
z = ps.hstack([x1, x2, x3], format=output_format, dtype=config.floatX)
x1_test = sp.sparse.random(
7,
2,
density=0.5,
format=input_formats[0],
dtype=config.floatX,
)
x2_test = sp.sparse.random(
7,
1,
density=0.3,
format=input_formats[1],
dtype=config.floatX,
)
x3_test = sp.sparse.random(
7,
5,
density=0.4,
format=input_formats[2],
dtype=config.floatX,
)
compare_numba_and_py_sparse([x1, x2, x3], z, [x1_test, x2_test, x3_test])
@pytest.mark.parametrize("output_format", ("csr", "csc"))
@pytest.mark.parametrize(
"input_formats",
(
("csr", "csr", "csr"),
("csc", "csc", "csc"),
("csr", "csc", "csr"),
("csc", "csr", "csc"),
("csc", "csc", "csr"),
),
)
def test_sparse_vstack(output_format, input_formats):
x1 = ps.matrix(
name="x1", shape=(2, 11), format=input_formats[0], dtype=config.floatX
)
x2 = ps.matrix(
name="x2", shape=(1, 11), format=input_formats[1], dtype=config.floatX
)
x3 = ps.matrix(
name="x3", shape=(5, 11), format=input_formats[2], dtype=config.floatX
)
z = ps.vstack([x1, x2, x3], format=output_format, dtype=config.floatX)
x1_test = sp.sparse.random(
2,
11,
density=0.4,
format=input_formats[0],
dtype=config.floatX,
)
x2_test = sp.sparse.random(
1,
11,
density=0.5,
format=input_formats[1],
dtype=config.floatX,
)
x3_test = sp.sparse.random(
5,
11,
density=0.2,
format=input_formats[2],
dtype=config.floatX,
)
compare_numba_and_py_sparse([x1, x2, x3], z, [x1_test, x2_test, x3_test])
def test_sparse_hstack_mismatched_rows_raises():
x = ps.matrix(name="x", shape=(3, 5), format="csr", dtype=config.floatX)
y = ps.matrix(name="y", shape=(4, 7), format="csr", dtype=config.floatX)
z = ps.hstack([x, y], format="csr", dtype=config.floatX)
fn = function([x, y], z, mode="NUMBA")
x_test = sp.sparse.random(3, 5, density=0.4, format="csr", dtype=config.floatX)
y_test = sp.sparse.random(4, 7, density=0.4, format="csr", dtype=config.floatX)
with pytest.raises(ValueError, match="Mismatching dimensions along axis 0"):
fn(x_test, y_test)
def test_sparse_vstack_mismatched_cols_raises():
x = ps.matrix(name="x", shape=(10, 3), format="csr", dtype=config.floatX)
y = ps.matrix(name="y", shape=(13, 4), format="csr", dtype=config.floatX)
z = ps.vstack([x, y], format="csr", dtype=config.floatX)
fn = function([x, y], z, mode="NUMBA")
x_test = sp.sparse.random(10, 3, density=0.4, format="csr", dtype=config.floatX)
y_test = sp.sparse.random(13, 4, density=0.4, format="csr", dtype=config.floatX)
with pytest.raises(ValueError, match="Mismatching dimensions along axis 1"):
fn(x_test, y_test)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论