提交 7ad266d6 作者: Tomas Capretto 提交者: Ricardo Vieira

Implement ColScaleCSC and RowScaleCSC sparse Ops in Numba backend

上级 4765cdad
...@@ -14,9 +14,11 @@ from pytensor.link.numba.dispatch.sparse.variable import CSMatrixType ...@@ -14,9 +14,11 @@ from pytensor.link.numba.dispatch.sparse.variable import CSMatrixType
from pytensor.sparse import ( from pytensor.sparse import (
CSM, CSM,
Cast, Cast,
ColScaleCSC,
CSMProperties, CSMProperties,
DenseFromSparse, DenseFromSparse,
HStack, HStack,
RowScaleCSC,
SparseFromDense, SparseFromDense,
Transpose, Transpose,
VStack, VStack,
...@@ -238,3 +240,46 @@ def numba_funcify_VStack(op, node, **kwargs): ...@@ -238,3 +240,46 @@ def numba_funcify_VStack(op, node, **kwargs):
return vstack_csr(*blocks).tocsc() return vstack_csr(*blocks).tocsc()
return vstack_csc return vstack_csc
@register_funcify_default_op_cache_key(ColScaleCSC)
def numba_funcify_ColScaleCSC(op, node, **kwargs):
    """Return the Numba kernel used for the ``ColScaleCSC`` Op.

    The returned function takes a sparse matrix ``x`` and a vector ``v`` and
    returns a copy of ``x`` in which every stored value belonging to column
    ``j`` (as delimited by ``indptr[j]:indptr[j + 1]``) is multiplied by
    ``v[j]``. ``op``, ``node`` and ``kwargs`` are not used by the kernel.
    """

    @numba_basic.numba_njit
    def col_scale_csc(x, v):
        num_columns = x.shape[1]
        # One scale factor per column is required.
        assert v.shape == (num_columns,)

        # Work on a copy so the input matrix is left untouched.
        scaled = x.copy()
        values = scaled.data
        # NOTE(review): the reinterpretation as uint32 assumes 32-bit index
        # arrays; presumably done to give Numba unsigned loop bounds — confirm.
        col_ptr = scaled.indptr.view(np.uint32)

        for j in range(num_columns):
            factor = v[j]
            start = col_ptr[j]
            stop = col_ptr[j + 1]
            # Explicit index loop kept on purpose: per the original author,
            # Numba is usually faster with explicit loops than with slicing.
            for k in range(start, stop):
                values[k] *= factor

        return scaled

    return col_scale_csc
@register_funcify_default_op_cache_key(RowScaleCSC)
def numba_funcify_RowScaleCSC(op, node, **kwargs):
    """Return the Numba kernel used for the ``RowScaleCSC`` Op.

    The returned function takes a sparse matrix ``x`` and a vector ``v`` and
    builds a new ``csc_matrix`` with the same shape and copies of the same
    ``indices``/``indptr``, where each stored value is multiplied by the
    entry of ``v`` selected by that value's row index. ``op``, ``node`` and
    ``kwargs`` are not used by the kernel.
    """

    @numba_basic.numba_njit
    def row_scale_csc(x, v):
        num_rows, num_columns = x.shape
        # One scale factor per row is required.
        assert v.shape == (num_rows,)

        # NOTE(review): the uint32 views assume 32-bit index arrays;
        # presumably used to give Numba unsigned indices — confirm.
        row_of = x.indices.view(np.uint32)
        col_ptr = x.indptr.view(np.uint32)

        # Scale a private copy of the stored values; ``x`` is not modified.
        scaled = x.data.copy()
        for j in range(num_columns):
            start = col_ptr[j]
            stop = col_ptr[j + 1]
            for k in range(start, stop):
                row = row_of[k]
                scaled[k] *= v[row]

        # The output gets its own copies of the index arrays rather than
        # sharing them with the input.
        out_indices = x.indices.copy()
        out_indptr = x.indptr.copy()
        return sp.sparse.csc_matrix(
            (scaled, out_indices, out_indptr), shape=x.shape
        )

    return row_scale_csc
...@@ -432,3 +432,25 @@ def test_sparse_vstack_mismatched_cols_raises(): ...@@ -432,3 +432,25 @@ def test_sparse_vstack_mismatched_cols_raises():
with pytest.raises(ValueError, match="Mismatching dimensions along axis 1"): with pytest.raises(ValueError, match="Mismatching dimensions along axis 1"):
fn(x_test, y_test) fn(x_test, y_test)
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_col_scale(format):
    """Run ``ps.col_scale`` through ``compare_numba_and_py_sparse``.

    Exercised for both CSR and CSC inputs on a random 8x10 sparse matrix and
    a random length-10 scale vector.
    """
    n_rows, n_cols = 8, 10

    # Symbolic graph: scale the columns of ``x`` by ``v``.
    x = ps.matrix(format, name="x", shape=(n_rows, n_cols), dtype=config.floatX)
    v = pt.vector(name="v", shape=(n_cols,), dtype=config.floatX)
    z = ps.col_scale(x, v)

    # Concrete inputs (matrix is drawn first, then the vector).
    x_test = sp.sparse.random(
        n_rows, n_cols, density=0.4, format=format, dtype=config.floatX
    )
    v_test = np.random.random(n_cols).astype(config.floatX)

    compare_numba_and_py_sparse([x, v], z, [x_test, v_test])
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_row_scale(format):
    """Run ``ps.row_scale`` through ``compare_numba_and_py_sparse``.

    Exercised for both CSR and CSC inputs on a random 7x10 sparse matrix and
    a random length-7 scale vector.
    """
    n_rows, n_cols = 7, 10

    # Symbolic graph: scale the rows of ``x`` by ``v``.
    x = ps.matrix(format, name="x", shape=(n_rows, n_cols), dtype=config.floatX)
    v = pt.vector(name="v", shape=(n_rows,), dtype=config.floatX)
    z = ps.row_scale(x, v)

    # Concrete inputs (matrix is drawn first, then the vector).
    x_test = sp.sparse.random(
        n_rows, n_cols, density=0.4, format=format, dtype=config.floatX
    )
    v_test = np.random.random(n_rows).astype(config.floatX)

    compare_numba_and_py_sparse([x, v], z, [x_test, v_test])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论