提交 db215037 authored 作者: Tomas Capretto's avatar Tomas Capretto 提交者: Ricardo Vieira

Implement Diag sparse Op in Numba backend

上级 a332c8b2
......@@ -17,6 +17,7 @@ from pytensor.sparse import (
ColScaleCSC,
CSMProperties,
DenseFromSparse,
Diag,
GetItem2d,
GetItem2Lists,
GetItem2ListsGrad,
......@@ -825,3 +826,51 @@ def numba_funcify_Neg(op, node, **kwargs):
return z
return neg
@register_funcify_default_op_cache_key(Diag)
def numba_funcify_Diag(op, node, **kwargs):
input_format = node.inputs[0].type.format
out_dtype = node.outputs[0].type.dtype
if input_format == "csr":
@numba_basic.numba_njit
def diag_csr(x):
n_rows, n_cols = x.shape
if n_rows != n_cols:
raise ValueError("Diag only apply on square matrix")
indptr = x.indptr.view(np.uint32)
indices = x.indices.view(np.uint32)
out = np.zeros(n_rows, dtype=out_dtype)
for row_idx in range(n_rows):
for data_idx in range(indptr[row_idx], indptr[row_idx + 1]):
# Duplicate entries on the diagonal must accumulate.
if indices[data_idx] == row_idx:
out[row_idx] += x.data[data_idx]
return out
return diag_csr
@numba_basic.numba_njit
def diag_csc(x):
n_rows, n_cols = x.shape
if n_rows != n_cols:
raise ValueError("Diag only apply on square matrix")
indptr = x.indptr.view(np.uint32)
indices = x.indices.view(np.uint32)
out = np.zeros(n_cols, dtype=out_dtype)
for col_idx in range(n_cols):
for data_idx in range(indptr[col_idx], indptr[col_idx + 1]):
# Duplicate entries on the diagonal must accumulate.
if indices[data_idx] == col_idx:
out[col_idx] += x.data[data_idx]
return out
return diag_csc
......@@ -664,3 +664,25 @@ def test_sparse_neg(format):
x_test = sp.sparse.random(7, 6, density=0.4, format=format, dtype=config.floatX)
compare_numba_and_py_sparse([x], z, [x_test])
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_diag(format):
x = ps.matrix(format, name="x", shape=(8, 8), dtype=config.floatX)
z = ps.diag(x)
x_test = sp.sparse.random(8, 8, density=0.4, format=format, dtype=config.floatX)
compare_numba_and_py_sparse([x], z, [x_test])
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_diag_not_square_raises(format):
x = ps.matrix(format, name="x", shape=(8, 6), dtype=config.floatX)
z = ps.diag(x)
fn = function([x], z, mode="NUMBA")
x_test = sp.sparse.random(8, 6, density=0.4, format=format, dtype=config.floatX)
with pytest.raises(ValueError, match="Diag only apply on square matrix"):
fn(x_test)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论