提交 1977c2c0 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba sparse: Implement SparseDenseMultiply

上级 fbd7a597
from pytensor.link.numba.dispatch.sparse import basic, variable from pytensor.link.numba.dispatch.sparse import basic, math, variable
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key
from pytensor.sparse import SparseDenseMultiply, SparseDenseVectorMultiply
@register_funcify_default_op_cache_key(SparseDenseMultiply)
@register_funcify_default_op_cache_key(SparseDenseVectorMultiply)
def numba_funcify_SparseDenseMultiply(op, node, **kwargs):
x, y = node.inputs
[z] = node.outputs
out_dtype = z.type.dtype
format = z.type.format
same_dtype = x.type.dtype == out_dtype
if y.ndim == 0:
@numba_basic.numba_njit
def sparse_multiply_scalar(x, y):
if same_dtype:
z = x.copy()
else:
z = x.astype(out_dtype)
# Numba doesn't know how to handle in-place mutation / assignment of fields
# z.data *= y
z_data = z.data
z_data *= y
return z
return sparse_multiply_scalar
elif y.ndim == 1:
@numba_basic.numba_njit
def sparse_dense_multiply(x, y):
assert x.shape[1] == y.shape[0]
if same_dtype:
z = x.copy()
else:
z = x.astype(out_dtype)
M, N = x.shape
indices = x.indices
indptr = x.indptr
z_data = z.data
if format == "csc":
for j in range(0, N):
for i_idx in range(indptr[j], indptr[j + 1]):
z_data[i_idx] *= y[j]
return z
else:
for i in range(0, M):
for j_idx in range(indptr[i], indptr[i + 1]):
j = indices[j_idx]
z_data[j_idx] *= y[j]
return z
return sparse_dense_multiply
else: # y.ndim == 2
@numba_basic.numba_njit
def sparse_dense_multiply(x, y):
assert x.shape == y.shape
if same_dtype:
z = x.copy()
else:
z = x.astype(out_dtype)
M, N = x.shape
indices = x.indices
indptr = x.indptr
z_data = z.data
if format == "csc":
for j in range(0, N):
for i_idx in range(indptr[j], indptr[j + 1]):
i = indices[i_idx]
z_data[i_idx] *= y[i, j]
return z
else:
for i in range(0, M):
for j_idx in range(indptr[i], indptr[i + 1]):
j = indices[j_idx]
z_data[j_idx] *= y[i, j]
return z
return sparse_dense_multiply
...@@ -200,10 +200,6 @@ def test_sparse_constant(format, cache): ...@@ -200,10 +200,6 @@ def test_sparse_constant(format, cache):
y_test = np.array([np.pi, np.e, np.euler_gamma]) y_test = np.array([np.pi, np.e, np.euler_gamma])
with config.change_flags(numba__cache=cache): with config.change_flags(numba__cache=cache):
with pytest.warns(
UserWarning,
match=r"Numba will use object mode to run SparseDenseVectorMultiply's perform method",
):
compare_numba_and_py_sparse( compare_numba_and_py_sparse(
[y], [y],
[out], [out],
...@@ -254,10 +250,6 @@ def test_simple_graph(format): ...@@ -254,10 +250,6 @@ def test_simple_graph(format):
x_test = sp.sparse.random(3, 3, density=0.5, format=format, random_state=rng) x_test = sp.sparse.random(3, 3, density=0.5, format=format, random_state=rng)
y_test = rng.normal(size=(3,)) y_test = rng.normal(size=(3,))
with pytest.warns(
UserWarning,
match=r"Numba will use object mode to run SparseDenseVectorMultiply's perform method",
):
compare_numba_and_py_sparse( compare_numba_and_py_sparse(
[x, y], [x, y],
z, z,
......
import numpy as np
import pytest
import scipy
import pytensor.sparse as ps
import pytensor.tensor as pt
from tests.link.numba.sparse.test_basic import compare_numba_and_py_sparse
pytestmark = pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("format", ["csr", "csc"])
@pytest.mark.parametrize("y_ndim", [0, 1, 2])
def test_sparse_dense_multiply(y_ndim, format):
x = ps.matrix(format, name="x", shape=(3, 3))
y = pt.tensor("y", shape=(3,) * y_ndim)
z = x * y
rng = np.random.default_rng((155, y_ndim, format == "csr"))
x_test = scipy.sparse.random(3, 3, density=0.5, format=format, random_state=rng)
y_test = rng.normal(size=(3,) * y_ndim)
compare_numba_and_py_sparse(
[x, y],
z,
[x_test, y_test],
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论