提交 6a17990a 作者: ricardoV94 提交者: Ricardo Vieira

Numba sparse: Allow no-op tocsc and tocsr

上级 23205754
...@@ -289,98 +289,112 @@ def overload_sparse_astype(matrix, dtype): ...@@ -289,98 +289,112 @@ def overload_sparse_astype(matrix, dtype):
return astype return astype
@overload_method(CSMatrixType, "tocsr")
def overload_tocsr(matrix):
    """Provide ``matrix.tocsr()`` for compressed sparse matrices in jitted code.

    The implementation is chosen from the Numba type of ``matrix``:

    * ``CSRMatrixType``: the matrix is already CSR, so the input itself is
      returned unchanged (no copy is made).
    * otherwise (CSC): a new CSR matrix is built from the CSC components.

    Returns the function Numba should compile for the call.
    """
    if isinstance(matrix, CSRMatrixType):

        def to_csr(matrix):
            # No-op conversion: the input is already in CSR format.
            return matrix

    else:  # CSCMatrix

        def to_csr(matrix):
            n_row, n_col = matrix.shape
            # Index arrays are reinterpreted as unsigned 32-bit for the loops
            # below and viewed back as int32 when building the result.
            csc_ptr = matrix.indptr.view(np.uint32)
            csc_ind = matrix.indices.view(np.uint32)
            csc_data = matrix.data
            nnz = csc_ptr[n_col]

            csr_ptr = np.empty(n_row + 1, dtype=np.uint32)
            csr_ind = np.empty(nnz, dtype=np.uint32)
            csr_data = np.empty(nnz, dtype=matrix.data.dtype)

            # Pass 1: count the stored entries in every row.
            csr_ptr[:n_row] = 0
            for n in range(nnz):
                csr_ptr[csc_ind[n]] += 1

            # Pass 2: exclusive cumulative sum turns the counts into the
            # starting offset of each row.
            cumsum = 0
            for row in range(n_row):
                temp = csr_ptr[row]
                csr_ptr[row] = cumsum
                cumsum += temp
            csr_ptr[n_row] = nnz

            # Pass 3: scatter every CSC entry into the next free slot of its
            # row; csr_ptr[row_idx] is advanced as slots are consumed.
            for col_idx in range(n_col):
                for jj in range(csc_ptr[col_idx], csc_ptr[col_idx + 1]):
                    row_idx = csc_ind[jj]
                    dest = csr_ptr[row_idx]

                    csr_ind[dest] = col_idx
                    csr_data[dest] = csc_data[jj]

                    csr_ptr[row_idx] += 1

            # Pass 4: after the scatter each csr_ptr[row] holds the END of
            # that row; shift everything right by one to restore the starts.
            last = 0
            for row_idx in range(n_row + 1):
                temp = csr_ptr[row_idx]
                csr_ptr[row_idx] = last
                last = temp

            return csr_matrix_from_components(
                csr_data, csr_ind.view(np.int32), csr_ptr.view(np.int32), matrix.shape
            )

    return to_csr
@overload_method(CSMatrixType, "tocsc")
def overload_tocsc(matrix):
    """Provide ``matrix.tocsc()`` for compressed sparse matrices in jitted code.

    The implementation is chosen from the Numba type of ``matrix``:

    * ``CSCMatrixType``: the matrix is already CSC, so the input itself is
      returned unchanged (no copy is made).
    * otherwise (CSR): a new CSC matrix is built from the CSR components.

    Returns the function Numba should compile for the call.
    """
    if isinstance(matrix, CSCMatrixType):

        def to_csc(matrix):
            # No-op conversion: the input is already in CSC format.
            return matrix

    else:  # CSRMatrix

        def to_csc(matrix):
            n_row, n_col = matrix.shape
            # Index arrays are reinterpreted as unsigned 32-bit for the loops
            # below and viewed back as int32 when building the result.
            csr_ptr = matrix.indptr.view(np.uint32)
            csr_ind = matrix.indices.view(np.uint32)
            csr_data = matrix.data
            nnz = csr_ptr[n_row]

            csc_ptr = np.empty(n_col + 1, dtype=np.uint32)
            csc_ind = np.empty(nnz, dtype=np.uint32)
            csc_data = np.empty(nnz, dtype=matrix.data.dtype)

            # Pass 1: count the stored entries in every column.
            csc_ptr[:n_col] = 0
            for n in range(nnz):
                csc_ptr[csr_ind[n]] += 1

            # Pass 2: exclusive cumulative sum turns the counts into the
            # starting offset of each column.
            cumsum = 0
            for col in range(n_col):
                temp = csc_ptr[col]
                csc_ptr[col] = cumsum
                cumsum += temp
            csc_ptr[n_col] = nnz

            # Pass 3: scatter every CSR entry into the next free slot of its
            # column; csc_ptr[col] is advanced as slots are consumed.
            for row in range(n_row):
                for jj in range(csr_ptr[row], csr_ptr[row + 1]):
                    col = csr_ind[jj]
                    dest = csc_ptr[col]

                    csc_ind[dest] = row
                    csc_data[dest] = csr_data[jj]

                    csc_ptr[col] += 1

            # Pass 4: after the scatter each csc_ptr[col] holds the END of
            # that column; shift everything right by one to restore the starts.
            last = 0
            for col in range(n_col + 1):
                temp = csc_ptr[col]
                csc_ptr[col] = last
                last = temp

            return csc_matrix_from_components(
                csc_data, csc_ind.view(np.int32), csc_ptr.view(np.int32), matrix.shape
            )

    return to_csc
......
...@@ -276,3 +276,26 @@ def test_sparse_dense_from_sparse(format): ...@@ -276,3 +276,26 @@ def test_sparse_dense_from_sparse(format):
x_test = sp.sparse.random(5, 3, density=0.5, format=format) x_test = sp.sparse.random(5, 3, density=0.5, format=format)
y = ps.dense_from_sparse(x) y = ps.dense_from_sparse(x)
compare_numba_and_py_sparse([x], y, [x_test]) compare_numba_and_py_sparse([x], y, [x_test])
def test_sparse_conversion():
    """Check ``tocsr`` / ``tocsc`` called from jitted code.

    Both a CSR and a CSC input are converted to both formats, so the
    same-format (no-op) path and the real conversion path are each exercised.
    Every result must report the requested format and hold the same values as
    the original matrix.
    """

    @numba.njit
    def to_csr(matrix):
        return matrix.tocsr()

    @numba.njit
    def to_csc(matrix):
        return matrix.tocsc()

    # ``sp`` is the scipy alias already used by the neighbouring tests.
    x_csr = sp.sparse.random(5, 5, density=0.5, format="csr")
    x_csc = x_csr.tocsc()
    x_dense = x_csr.todense()

    converters = {"csr": to_csr, "csc": to_csc}
    for x_inp in (x_csr, x_csc):
        for output_format, convert in converters.items():
            res = convert(x_inp)
            assert res.format == output_format
            np.testing.assert_array_equal(res.todense(), x_dense)
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论