提交 557307a6 authored 作者: Tomas Capretto's avatar Tomas Capretto 提交者: Ricardo Vieira

Numba sparse: handle dot product between matrices more granularly to reduce the…

Numba sparse: handle dot product between matrices more granularly to reduce the number of format conversions
上级 18fbf45c
...@@ -121,7 +121,7 @@ def numba_funcify_SparseDot(op, node, **kwargs): ...@@ -121,7 +121,7 @@ def numba_funcify_SparseDot(op, node, **kwargs):
x_format = x.type.format if x_is_sparse else None x_format = x.type.format if x_is_sparse else None
y_format = y.type.format if y_is_sparse else None y_format = y.type.format if y_is_sparse else None
cache_version = 1 cache_version = 2
cache_key = sha256( cache_key = sha256(
str( str(
( (
...@@ -139,12 +139,14 @@ def numba_funcify_SparseDot(op, node, **kwargs): ...@@ -139,12 +139,14 @@ def numba_funcify_SparseDot(op, node, **kwargs):
if x_is_sparse and y_is_sparse: if x_is_sparse and y_is_sparse:
# General spmspm algorithm in CSR format # General spmspm algorithm in CSR format
@numba_basic.numba_njit @numba_basic.numba_njit
def _spmspm(n_row, n_col, x_ptr, x_ind, x_data, y_ptr, y_ind, y_data): def _spmspm_csr(x, y, n_row, n_col):
# Pass 1 # Pass 1
x_ind = x_ind.view(np.uint32) x_ind = x.indices.view(np.uint32)
y_ind = y_ind.view(np.uint32) y_ind = y.indices.view(np.uint32)
x_ptr = x_ptr.view(np.uint32) x_ptr = x.indptr.view(np.uint32)
y_ptr = y_ptr.view(np.uint32) y_ptr = y.indptr.view(np.uint32)
x_data = x.data
y_data = y.data
output_nnz = 0 output_nnz = 0
mask = np.full(n_col, -1, dtype=np.int32) mask = np.full(n_col, -1, dtype=np.int32)
...@@ -203,42 +205,63 @@ def numba_funcify_SparseDot(op, node, **kwargs): ...@@ -203,42 +205,63 @@ def numba_funcify_SparseDot(op, node, **kwargs):
return z_ptr.view(np.int32), z_ind.view(np.int32), z_data return z_ptr.view(np.int32), z_ind.view(np.int32), z_data
@numba_basic.numba_njit formats = (x_format, y_format)
def spmspm(x, y): if formats == ("csc", "csc"):
if x_format == "csc" and y_format == "csc": # In all cases, the output is dense when the op is Dot.
# Compute the transpose dot, to avoid costly conversion tocsr() @numba_basic.numba_njit
x, y = y.T, x.T def spmspm_csc_csc(x, y):
elif x_format == "csc": # Swap inputs
x = x.tocsr() n_row, n_col = x.shape[0], y.shape[1]
elif y_format == "csc": z_ptr, z_ind, z_data = _spmspm_csr(x=y, y=x, n_row=n_col, n_col=n_row)
y = y.tocsr() output = sp.csc_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col))
if not z_is_sparse:
x_ptr, x_ind, x_data = x.indptr, x.indices, x.data return output.toarray()
y_ptr, y_ind, y_data = y.indptr, y.indices, y.data return output
n_row, n_col = x.shape[0], y.shape[1]
return spmspm_csc_csc, cache_key
z_ptr, z_ind, z_data = _spmspm( elif formats == ("csc", "csr"):
n_row, n_col, x_ptr, x_ind, x_data, y_ptr, y_ind, y_data
) @numba_basic.numba_njit
def spmspm_csc_csr(x, y):
# Convert csr to csc and swap
n_row, n_col = x.shape[0], y.shape[1]
z_ptr, z_ind, z_data = _spmspm_csr(
x=y.tocsc(), y=x, n_row=n_col, n_col=n_row
)
output = sp.csc_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col))
if not z_is_sparse:
return output.toarray()
return output
output = sp.csr_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col)) return spmspm_csc_csr, cache_key
elif formats == ("csr", "csc"):
@numba_basic.numba_njit
def spmspm_csr_csc(x, y):
# Convert csc to csr, no swap
n_row, n_col = x.shape[0], y.shape[1]
z_ptr, z_ind, z_data = _spmspm_csr(
x=x, y=y.tocsr(), n_row=n_row, n_col=n_col
)
output = sp.csr_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col))
if not z_is_sparse:
return output.toarray()
return output
if x_format == "csc" and y_format == "csc": return spmspm_csr_csc, cache_key
# We computed the transposed dot in csr, if we transpose the result we get csc else:
output = output.T
# Dot returns a dense result even in spMspM @numba_basic.numba_njit
if not z_is_sparse: def spmspm_csr_csr(x, y):
return output.toarray() # No conversion, no swap
# StructuredDot returns in the format of 'x' n_row, n_col = x.shape[0], y.shape[1]
elif x_format == "csc" and y_format == "csr": z_ptr, z_ind, z_data = _spmspm_csr(x=x, y=y, n_row=n_row, n_col=n_col)
# This is the only case we can't escape a `tocsc()` call output = sp.csr_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col))
return output.tocsc() if not z_is_sparse:
else: return output.toarray()
# Output already in the desired format
return output return output
return spmspm, cache_key return spmspm_csr_csr, cache_key
# Only one of 'x' or 'y' is sparse, not both. # Only one of 'x' or 'y' is sparse, not both.
# Before using a general dot(sparse-matrix, dense-matrix) algorithm, # Before using a general dot(sparse-matrix, dense-matrix) algorithm,
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论