提交 18fbf45c authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Numba sparse: Avoid explicit conversion in csc @ csc

上级 6a17990a
...@@ -121,6 +121,7 @@ def numba_funcify_SparseDot(op, node, **kwargs): ...@@ -121,6 +121,7 @@ def numba_funcify_SparseDot(op, node, **kwargs):
x_format = x.type.format if x_is_sparse else None x_format = x.type.format if x_is_sparse else None
y_format = y.type.format if y_is_sparse else None y_format = y.type.format if y_is_sparse else None
cache_version = 1
cache_key = sha256( cache_key = sha256(
str( str(
( (
...@@ -130,6 +131,7 @@ def numba_funcify_SparseDot(op, node, **kwargs): ...@@ -130,6 +131,7 @@ def numba_funcify_SparseDot(op, node, **kwargs):
z_is_sparse, z_is_sparse,
y.type.ndim, y.type.ndim,
y.type.broadcastable, y.type.broadcastable,
cache_version,
) )
).encode() ).encode()
).hexdigest() ).hexdigest()
...@@ -203,10 +205,12 @@ def numba_funcify_SparseDot(op, node, **kwargs): ...@@ -203,10 +205,12 @@ def numba_funcify_SparseDot(op, node, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def spmspm(x, y): def spmspm(x, y):
if x_format != "csr": if x_format == "csc" and y_format == "csc":
# Compute the transpose dot, to avoid costly conversion tocsr()
x, y = y.T, x.T
elif x_format == "csc":
x = x.tocsr() x = x.tocsr()
elif y_format == "csc":
if y_format != "csr":
y = y.tocsr() y = y.tocsr()
x_ptr, x_ind, x_data = x.indptr, x.indices, x.data x_ptr, x_ind, x_data = x.indptr, x.indices, x.data
...@@ -219,15 +223,20 @@ def numba_funcify_SparseDot(op, node, **kwargs): ...@@ -219,15 +223,20 @@ def numba_funcify_SparseDot(op, node, **kwargs):
output = sp.csr_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col)) output = sp.csr_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col))
if x_format == "csc" and y_format == "csc":
# We computed the transposed dot in csr, if we transpose the result we get csc
output = output.T
# Dot returns a dense result even in spMspM # Dot returns a dense result even in spMspM
if not z_is_sparse: if not z_is_sparse:
return output.toarray() return output.toarray()
# StructuredDot returns in the format of 'x' # StructuredDot returns in the format of 'x'
if x_format == "csc": elif x_format == "csc" and y_format == "csr":
# This is the only case we can't escape a `tocsc()` call
return output.tocsc() return output.tocsc()
else:
return output # Output already in the desired format
return output
return spmspm, cache_key return spmspm, cache_key
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论