提交 557307a6 authored 作者: Tomas Capretto's avatar Tomas Capretto 提交者: Ricardo Vieira

Numba sparse: handle dot product between matrices more granularly to reduce the…

Numba sparse: handle dot product between matrices more granularly to reduce the number of format conversions
上级 18fbf45c
...@@ -121,7 +121,7 @@ def numba_funcify_SparseDot(op, node, **kwargs): ...@@ -121,7 +121,7 @@ def numba_funcify_SparseDot(op, node, **kwargs):
x_format = x.type.format if x_is_sparse else None x_format = x.type.format if x_is_sparse else None
y_format = y.type.format if y_is_sparse else None y_format = y.type.format if y_is_sparse else None
cache_version = 1 cache_version = 2
cache_key = sha256( cache_key = sha256(
str( str(
( (
...@@ -139,12 +139,14 @@ def numba_funcify_SparseDot(op, node, **kwargs): ...@@ -139,12 +139,14 @@ def numba_funcify_SparseDot(op, node, **kwargs):
if x_is_sparse and y_is_sparse: if x_is_sparse and y_is_sparse:
# General spmspm algorithm in CSR format # General spmspm algorithm in CSR format
@numba_basic.numba_njit @numba_basic.numba_njit
def _spmspm(n_row, n_col, x_ptr, x_ind, x_data, y_ptr, y_ind, y_data): def _spmspm_csr(x, y, n_row, n_col):
# Pass 1 # Pass 1
x_ind = x_ind.view(np.uint32) x_ind = x.indices.view(np.uint32)
y_ind = y_ind.view(np.uint32) y_ind = y.indices.view(np.uint32)
x_ptr = x_ptr.view(np.uint32) x_ptr = x.indptr.view(np.uint32)
y_ptr = y_ptr.view(np.uint32) y_ptr = y.indptr.view(np.uint32)
x_data = x.data
y_data = y.data
output_nnz = 0 output_nnz = 0
mask = np.full(n_col, -1, dtype=np.int32) mask = np.full(n_col, -1, dtype=np.int32)
...@@ -203,42 +205,63 @@ def numba_funcify_SparseDot(op, node, **kwargs): ...@@ -203,42 +205,63 @@ def numba_funcify_SparseDot(op, node, **kwargs):
return z_ptr.view(np.int32), z_ind.view(np.int32), z_data return z_ptr.view(np.int32), z_ind.view(np.int32), z_data
formats = (x_format, y_format)
if formats == ("csc", "csc"):
# In all cases, the output is dense when the op is Dot.
@numba_basic.numba_njit @numba_basic.numba_njit
def spmspm(x, y): def spmspm_csc_csc(x, y):
if x_format == "csc" and y_format == "csc": # Swap inputs
# Compute the transpose dot, to avoid costly conversion tocsr()
x, y = y.T, x.T
elif x_format == "csc":
x = x.tocsr()
elif y_format == "csc":
y = y.tocsr()
x_ptr, x_ind, x_data = x.indptr, x.indices, x.data
y_ptr, y_ind, y_data = y.indptr, y.indices, y.data
n_row, n_col = x.shape[0], y.shape[1] n_row, n_col = x.shape[0], y.shape[1]
z_ptr, z_ind, z_data = _spmspm_csr(x=y, y=x, n_row=n_col, n_col=n_row)
output = sp.csc_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col))
if not z_is_sparse:
return output.toarray()
return output
return spmspm_csc_csc, cache_key
elif formats == ("csc", "csr"):
z_ptr, z_ind, z_data = _spmspm( @numba_basic.numba_njit
n_row, n_col, x_ptr, x_ind, x_data, y_ptr, y_ind, y_data def spmspm_csc_csr(x, y):
# Convert csr to csc and swap
n_row, n_col = x.shape[0], y.shape[1]
z_ptr, z_ind, z_data = _spmspm_csr(
x=y.tocsc(), y=x, n_row=n_col, n_col=n_row
) )
output = sp.csc_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col))
if not z_is_sparse:
return output.toarray()
return output
return spmspm_csc_csr, cache_key
elif formats == ("csr", "csc"):
@numba_basic.numba_njit
def spmspm_csr_csc(x, y):
# Convert csc to csr, no swap
n_row, n_col = x.shape[0], y.shape[1]
z_ptr, z_ind, z_data = _spmspm_csr(
x=x, y=y.tocsr(), n_row=n_row, n_col=n_col
)
output = sp.csr_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col)) output = sp.csr_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col))
if not z_is_sparse:
return output.toarray()
return output
if x_format == "csc" and y_format == "csc": return spmspm_csr_csc, cache_key
# We computed the transposed dot in csr, if we transpose the result we get csc else:
output = output.T
# Dot returns a dense result even in spMspM @numba_basic.numba_njit
def spmspm_csr_csr(x, y):
# No conversion, no swap
n_row, n_col = x.shape[0], y.shape[1]
z_ptr, z_ind, z_data = _spmspm_csr(x=x, y=y, n_row=n_row, n_col=n_col)
output = sp.csr_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col))
if not z_is_sparse: if not z_is_sparse:
return output.toarray() return output.toarray()
# StructuredDot returns in the format of 'x'
elif x_format == "csc" and y_format == "csr":
# This is the only case we can't escape a `tocsc()` call
return output.tocsc()
else:
# Output already in the desired format
return output return output
return spmspm, cache_key return spmspm_csr_csr, cache_key
# Only one of 'x' or 'y' is sparse, not both. # Only one of 'x' or 'y' is sparse, not both.
# Before using a general dot(sparse-matrix, dense-matrix) algorithm, # Before using a general dot(sparse-matrix, dense-matrix) algorithm,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论