提交 87a6ced8 作者: Tomas Capretto 提交者: Ricardo Vieira

Implement GetItem2d sparse Op in Numba backend

上级 8c7518af
......@@ -17,6 +17,7 @@ from pytensor.sparse import (
ColScaleCSC,
CSMProperties,
DenseFromSparse,
GetItem2d,
GetItem2Lists,
GetItem2ListsGrad,
GetItemList,
......@@ -621,3 +622,122 @@ def numba_funcify_GetItem2ListsGrad(op, node, **kwargs):
return get_item_2lists_grad_csr(x, ind1, ind2, gz).tocsc()
return get_item_2lists_grad_csc
@register_funcify_default_op_cache_key(GetItem2d)
def numba_funcify_GetItem2d(op, node, **kwargs):
    """Return a Numba kernel that slices a CSR/CSC matrix along both axes.

    The returned function takes the sparse matrix followed by the
    ``(start, stop, step)`` triple of the row slice and then of the column
    slice, and mirrors SciPy's ``x[start1:stop1:step1, start2:stop2:step2]``.
    The output keeps the format of the first input of ``node``: a CSR kernel
    is returned when that format is ``"csr"``, a CSC kernel otherwise.
    """
    input_format = node.inputs[0].type.format

    @numba_basic.numba_njit
    def normalize_index(idx):
        # ``slice`` needs plain scalars (or None), yet a bound may arrive
        # wrapped in a 0d array, so unwrap it here.
        return None if idx is None else np.asarray(idx).item()

    @numba_basic.numba_njit
    def slice_indices(size, start, stop, step):
        # Resolve the (possibly None) bounds against ``size`` and list the
        # positions the slice picks, in the order it picks them.
        bounds = slice(
            normalize_index(start), normalize_index(stop), normalize_index(step)
        )
        first, last, stride = bounds.indices(size)
        return np.arange(first, last, stride, dtype=np.int32)

    @numba_basic.numba_njit
    def take_compressed(indptr, indices, data, major_sel, minor_sel, n_minor):
        # Shared CSR/CSC kernel. "major" is the compressed axis (rows for
        # CSR, columns for CSC) and "minor" is the axis stored in ``indices``.
        # Returns the (data, indices, indptr) triple of the sliced matrix.
        n_major_out = len(major_sel)

        # minor_lookup[src] is the output position of source minor index
        # ``src``, or -1 when the slice drops it.
        minor_lookup = np.full(n_minor, -1, dtype=np.int32)
        for pos in range(len(minor_sel)):
            minor_lookup[minor_sel[pos]] = pos

        # Pass 1: count the surviving entries of every selected major line
        # so that the output buffers can be allocated exactly once.
        new_indptr = np.empty(n_major_out + 1, dtype=np.int32)
        new_indptr[0] = 0
        kept = 0
        for i in range(n_major_out):
            src_major = major_sel[i]
            kept_here = 0
            for k in range(indptr[src_major], indptr[src_major + 1]):
                if minor_lookup[indices[k]] != -1:
                    kept_here += 1
            kept += kept_here
            new_indptr[i + 1] = kept

        # Pass 2: copy the surviving entries, relabelling their minor index.
        new_data = np.empty(kept, dtype=data.dtype)
        new_indices = np.empty(kept, dtype=np.int32)
        for i in range(n_major_out):
            src_major = major_sel[i]
            cursor = new_indptr[i]
            for k in range(indptr[src_major], indptr[src_major + 1]):
                target = minor_lookup[indices[k]]
                if target != -1:
                    new_data[cursor] = data[k]
                    new_indices[cursor] = target
                    cursor += 1

        return new_data, new_indices, new_indptr

    if input_format == "csr":

        @numba_basic.numba_njit
        def get_item_2d_csr(x, start1, stop1, step1, start2, stop2, step2):
            # Reproduces SciPy when running:
            # x[start1:stop1:step1, start2:stop2:step2]
            n_rows, n_cols = x.shape
            rows = slice_indices(n_rows, start1, stop1, step1)
            cols = slice_indices(n_cols, start2, stop2, step2)

            # Rows are the compressed axis of a CSR matrix.
            out_data, out_indices, out_indptr = take_compressed(
                x.indptr.view(np.uint32),
                x.indices.view(np.uint32),
                x.data,
                rows,
                cols,
                n_cols,
            )
            return sp.sparse.csr_matrix(
                (out_data, out_indices, out_indptr), shape=(len(rows), len(cols))
            )

        return get_item_2d_csr

    @numba_basic.numba_njit
    def get_item_2d_csc(x, start1, stop1, step1, start2, stop2, step2):
        # Reproduces SciPy when running:
        # x[start1:stop1:step1, start2:stop2:step2]
        n_rows, n_cols = x.shape
        rows = slice_indices(n_rows, start1, stop1, step1)
        cols = slice_indices(n_cols, start2, stop2, step2)

        # Columns are the compressed axis of a CSC matrix.
        out_data, out_indices, out_indptr = take_compressed(
            x.indptr.view(np.uint32),
            x.indices.view(np.uint32),
            x.data,
            cols,
            rows,
            n_rows,
        )
        return sp.sparse.csc_matrix(
            (out_data, out_indices, out_indptr), shape=(len(rows), len(cols))
        )

    return get_item_2d_csc
......@@ -529,6 +529,34 @@ def test_sparse_get_item_2lists(format):
compare_numba_and_py_sparse([x, ind1, ind2], z, [x_test, ind1_test, ind2_test])
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_get_item_2d(format):
    """Compare the Numba and Python results of 2d slicing on a sparse matrix.

    Covers slices whose bounds are symbolic integer scalars as well as
    slices with constant bounds, including negative steps.
    """
    x = ps.matrix(format, name="x", shape=(100, 97), dtype=config.floatX)
    a, b, c, d, e, f = (pt.iscalar(name) for name in ("a", "b", "c", "d", "e", "f"))

    # Slices whose bounds are only known at run time.
    symbolic_slices = [
        x[a:b:e, c:d:f],
        x[a:b:e],
        x[:a, :b],
        x[:, a:],
    ]
    # Slices with constant bounds; the second one walks both axes backwards.
    constant_slices = (
        x[1:10:2, 10:20:3],
        x[10:1:-2, 15:2:-3],
    )

    x_test = sp.sparse.random(100, 97, density=0.4, format=format, dtype=config.floatX)

    compare_numba_and_py_sparse(
        [x, a, b, c, d, e, f],
        symbolic_slices,
        [x_test, 1, 5, 10, 15, 2, 3],
    )
    for sliced in constant_slices:
        compare_numba_and_py_sparse([x], sliced, [x_test])
@pytest.mark.parametrize("format", ("csr", "csc"))
@pytest.mark.parametrize(
("ind1_test", "ind2_test"),
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论