提交 6a17990a authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Numba sparse: Allow no-op tocsc and tocsr

上级 23205754
......@@ -289,8 +289,15 @@ def overload_sparse_astype(matrix, dtype):
return astype
@overload_method(CSCMatrixType, "tocsr")
@overload_method(CSMatrixType, "tocsr")
def overload_tocsr(matrix):
if isinstance(matrix, CSRMatrixType):
def to_csr(matrix):
return matrix
else: # CSCMatrix
def to_csr(matrix):
n_row, n_col = matrix.shape
csc_ptr = matrix.indptr.view(np.uint32)
......@@ -337,8 +344,15 @@ def overload_tocsr(matrix):
return to_csr
@overload_method(CSRMatrixType, "tocsc")
@overload_method(CSMatrixType, "tocsc")
def overload_tocsc(matrix):
if isinstance(matrix, CSCMatrixType):
def to_csc(matrix):
return matrix
else: # CSRMatrix
def to_csc(matrix):
n_row, n_col = matrix.shape
csr_ptr = matrix.indptr.view(np.uint32)
......
......@@ -276,3 +276,26 @@ def test_sparse_dense_from_sparse(format):
x_test = sp.sparse.random(5, 3, density=0.5, format=format)
y = ps.dense_from_sparse(x)
compare_numba_and_py_sparse([x], y, [x_test])
def test_sparse_conversion():
@numba.njit
def to_csr(matrix):
return matrix.tocsr()
@numba.njit
def to_csc(matrix):
return matrix.tocsc()
x_csr = scipy.sparse.random(5, 5, density=0.5, format="csr")
x_csc = x_csr.tocsc()
x_dense = x_csr.todense()
for x_inp in (x_csr, x_csc):
for output_format in ("csr", "csc"):
if output_format == "csr":
res = to_csr(x_inp)
else:
res = to_csc(x_inp)
assert res.format == output_format
np.testing.assert_array_equal(res.todense(), x_dense)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论