提交 a36dc8ed 作者: Tomas Capretto 提交者: Ricardo Vieira

Implement SpSum in numba backend

上级 54db5ad0
......@@ -14,12 +14,24 @@ from pytensor.sparse import (
Dot,
SparseDenseMultiply,
SparseDenseVectorMultiply,
SpSum,
StructuredDot,
StructuredDotGradCSC,
StructuredDotGradCSR,
)
@register_funcify_default_op_cache_key(SpSum)
def numba_funcify_SpSum(op, node, **kwargs):
    """Build the Numba implementation of ``SpSum``.

    The reduction axis is read once from ``op.axis`` and captured by the
    returned jitted function, which forwards it to the ``sum`` method of its
    single input.
    """
    reduce_axis = op.axis

    @numba_basic.numba_njit
    def perform(x):
        # ``reduce_axis`` is a free variable captured from the enclosing scope.
        total = x.sum(reduce_axis)
        return total

    return perform
@register_funcify_default_op_cache_key(SparseDenseMultiply)
@register_funcify_default_op_cache_key(SparseDenseVectorMultiply)
def numba_funcify_SparseDenseMultiply(op, node, **kwargs):
......
......@@ -432,3 +432,60 @@ def overload_toarray(matrix):
return to_array
case _:
return
@overload_method(CSMatrixType, "sum")
def overload_sum(matrix, axis):
    """Provide ``matrix.sum(axis)`` for compiled CSR and CSC matrices.

    ``axis`` is expected to be ``None``, ``0`` or ``1``:

    * ``None`` -- sum of every stored value, returned as a 0-d array.
    * ``0``    -- one total per column (length ``matrix.shape[1]``).
    * anything else (i.e. ``1``) -- one total per row (length
      ``matrix.shape[0]``).

    The per-axis results are allocated with the dtype of ``matrix.data``.
    ``None`` is returned when ``matrix`` is neither a CSR nor a CSC type, so
    no implementation is supplied for it by this overload.
    """
    # Full reduction does not depend on the storage format.
    if axis is types.none:

        def sum_scalar(matrix, axis):
            return np.asarray(np.sum(matrix.data))

        return sum_scalar

    if isinstance(matrix, CSRMatrixType):

        def sum_csr(matrix, axis):
            # NOTE(review): reinterpreting the index arrays as uint32
            # presumably relies on them being 32-bit and non-negative --
            # confirm against the CSMatrixType definition.
            row_ptr = matrix.indptr.view(np.uint32)
            col_idx = matrix.indices.view(np.uint32)
            values = matrix.data
            num_rows = matrix.shape[0]
            num_cols = matrix.shape[1]

            if axis == 0:
                # Scatter every stored value into the column it belongs to.
                col_totals = np.zeros(num_cols, dtype=values.dtype)
                for k in range(len(values)):
                    col_totals[col_idx[k]] += values[k]
                return col_totals

            # Each row's stored values are the slice
            # values[row_ptr[r]:row_ptr[r + 1]].
            row_totals = np.zeros(num_rows, dtype=values.dtype)
            for r in range(num_rows):
                row_totals[r] = np.sum(values[row_ptr[r] : row_ptr[r + 1]])
            return row_totals

        return sum_csr

    if isinstance(matrix, CSCMatrixType):

        def sum_csc(matrix, axis):
            # NOTE(review): same uint32 reinterpretation as the CSR path --
            # confirm the index dtype against the CSMatrixType definition.
            col_ptr = matrix.indptr.view(np.uint32)
            row_idx = matrix.indices.view(np.uint32)
            values = matrix.data
            num_rows = matrix.shape[0]
            num_cols = matrix.shape[1]

            if axis == 0:
                # Each column's stored values are the slice
                # values[col_ptr[c]:col_ptr[c + 1]].
                col_totals = np.zeros(num_cols, dtype=values.dtype)
                for c in range(num_cols):
                    col_totals[c] = np.sum(values[col_ptr[c] : col_ptr[c + 1]])
                return col_totals

            # Scatter every stored value into the row it belongs to.
            row_totals = np.zeros(num_rows, dtype=values.dtype)
            for k in range(len(values)):
                row_totals[row_idx[k]] += values[k]
            return row_totals

        return sum_csc

    return None
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论