提交 18fbf45c authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Numba sparse: Avoid explicit conversion in csc @ csc

上级 6a17990a
......@@ -121,6 +121,7 @@ def numba_funcify_SparseDot(op, node, **kwargs):
x_format = x.type.format if x_is_sparse else None
y_format = y.type.format if y_is_sparse else None
cache_version = 1
cache_key = sha256(
str(
(
......@@ -130,6 +131,7 @@ def numba_funcify_SparseDot(op, node, **kwargs):
z_is_sparse,
y.type.ndim,
y.type.broadcastable,
cache_version,
)
).encode()
).hexdigest()
......@@ -203,10 +205,12 @@ def numba_funcify_SparseDot(op, node, **kwargs):
@numba_basic.numba_njit
def spmspm(x, y):
if x_format != "csr":
if x_format == "csc" and y_format == "csc":
# Compute the transpose dot, to avoid costly conversion tocsr()
x, y = y.T, x.T
elif x_format == "csc":
x = x.tocsr()
if y_format != "csr":
elif y_format == "csc":
y = y.tocsr()
x_ptr, x_ind, x_data = x.indptr, x.indices, x.data
......@@ -219,15 +223,20 @@ def numba_funcify_SparseDot(op, node, **kwargs):
output = sp.csr_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col))
if x_format == "csc" and y_format == "csc":
# We computed the transposed dot in csr, if we transpose the result we get csc
output = output.T
# Dot returns a dense result even in spMspM
if not z_is_sparse:
return output.toarray()
# StructuredDot returns in the format of 'x'
if x_format == "csc":
elif x_format == "csc" and y_format == "csr":
# This is the only case we can't escape a `tocsc()` call
return output.tocsc()
return output
else:
# Output already in the desired format
return output
return spmspm, cache_key
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论