提交 ab7a4bdd authored 作者: Tomas Capretto's avatar Tomas Capretto 提交者: Ricardo Vieira

Implement GetItemScalar sparse Op in Numba backend

上级 87a6ced8
......@@ -22,6 +22,7 @@ from pytensor.sparse import (
GetItem2ListsGrad,
GetItemList,
GetItemListGrad,
GetItemScalar,
HStack,
RowScaleCSC,
SparseFromDense,
......@@ -741,3 +742,73 @@ def numba_funcify_GetItem2d(op, node, **kwargs):
)
return get_item_2d_csc
@register_funcify_default_op_cache_key(GetItemScalar)
def numba_funcify_GetItemScalar(op, node, **kwargs):
    """Numba implementation of ``GetItemScalar`` (``x[i, j]`` on a sparse matrix).

    Returns a jitted kernel ``(x, ind1, ind2) -> 0-d array`` that extracts a
    single element from a CSR or CSC matrix, accepting negative indices and
    raising ``IndexError`` for out-of-bounds ones.
    """
    input_format = node.inputs[0].type.format
    out_dtype = np.dtype(node.outputs[0].type.dtype)

    # CSR stores rows as the major (indptr-indexed) axis; CSC stores columns.
    # Closing over this flag lets one kernel serve both formats instead of
    # duplicating the lookup logic per format: Numba treats closed-over
    # variables as compile-time constants, so the dead branch is eliminated.
    major_is_row = input_format == "csr"

    @numba_basic.numba_njit
    def get_item_scalar(x, ind1, ind2):
        n_rows, n_cols = x.shape

        # Normalize possibly-negative indices and bounds-check them, mirroring
        # scipy's indexing semantics.
        row_idx = np.asarray(ind1).item()
        if row_idx < 0:
            row_idx += n_rows
        if row_idx < 0 or row_idx >= n_rows:
            raise IndexError("row index out of bounds")

        col_idx = np.asarray(ind2).item()
        if col_idx < 0:
            col_idx += n_cols
        if col_idx < 0 or col_idx >= n_cols:
            raise IndexError("column index out of bounds")

        # Map (row, col) onto the (major, minor) axes of the storage format.
        if major_is_row:
            major_idx = np.uint32(row_idx)
            minor_idx = np.uint32(col_idx)
        else:
            major_idx = np.uint32(col_idx)
            minor_idx = np.uint32(row_idx)

        # NOTE(review): viewing as uint32 assumes the index arrays are int32
        # (scipy's default for matrices of this size) — confirm if int64
        # indices can ever reach this kernel.
        indptr = x.indptr.view(np.uint32)
        indices = x.indices.view(np.uint32)

        out = 0
        # Scan the nonzeros of the selected major slice.
        for data_idx in range(indptr[major_idx], indptr[major_idx + 1]):
            # Duplicate sparse entries must accumulate like scipy indexing.
            if indices[data_idx] == minor_idx:
                out += x.data[data_idx]
        return np.asarray(out, dtype=out_dtype)

    return get_item_scalar
......@@ -620,3 +620,37 @@ def test_sparse_get_item_2lists_grad_wrong_index(format, ind1_test, ind2_test):
with pytest.raises(IndexError):
fn(x_test, ind1_test, ind2_test, gz_test)
@pytest.mark.parametrize("format", ("csr", "csc"))
@pytest.mark.parametrize(("row_idx", "col_idx"), [(3, 2), (-1, -2)])
def test_sparse_get_item_scalar(format, row_idx, col_idx):
    """Scalar indexing of a sparse matrix with symbolic and literal indices."""
    x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
    row, col = pt.iscalar("row"), pt.iscalar("col")

    x_test = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)

    # Symbolic indices: the parametrized (row_idx, col_idx) pair is fed at runtime.
    compare_numba_and_py_sparse([x, row, col], x[row, col], [x_test, row_idx, col_idx])
    # Literal indices (positive and negative) are baked into the graph.
    compare_numba_and_py_sparse([x], x[3, 2], [x_test])
    compare_numba_and_py_sparse([x], x[-1, -2], [x_test])
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_get_item_scalar_wrong_index(format):
    """Out-of-bounds scalar indices must raise ``IndexError`` at runtime."""
    x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
    row = pt.iscalar("row")
    col = pt.iscalar("col")
    fn = function([x, row, col], x[row, col], mode="NUMBA")

    x_test = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)

    # Row index 6 is one past the end of the 6-row matrix.
    with pytest.raises(IndexError, match="row index out of bounds"):
        fn(x_test, 6, 0)
    # Column index 5 is one past the end of the 5-column matrix.
    with pytest.raises(IndexError, match="column index out of bounds"):
        fn(x_test, 0, 5)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论