提交 6c949b3a authored 作者: Tomas Capretto's avatar Tomas Capretto 提交者: Ricardo Vieira

Implement GetItemList and GetItemListGrad sparse Ops in Numba backend

上级 7ad266d6
...@@ -17,6 +17,8 @@ from pytensor.sparse import ( ...@@ -17,6 +17,8 @@ from pytensor.sparse import (
ColScaleCSC, ColScaleCSC,
CSMProperties, CSMProperties,
DenseFromSparse, DenseFromSparse,
GetItemList,
GetItemListGrad,
HStack, HStack,
RowScaleCSC, RowScaleCSC,
SparseFromDense, SparseFromDense,
...@@ -283,3 +285,318 @@ def numba_funcify_RowScaleCSC(op, node, **kwargs): ...@@ -283,3 +285,318 @@ def numba_funcify_RowScaleCSC(op, node, **kwargs):
) )
return row_scale_csc return row_scale_csc
@register_funcify_default_op_cache_key(GetItemList)
def numba_funcify_GetItemList(op, node, **kwargs):
output_format = node.outputs[0].type.format
@numba_basic.numba_njit
def get_item_list_csr(x, idxs):
# Reproduces SciPy when running:
# x_sparse[idxs]
x_csr = x.tocsr()
n_rows, n_cols = x_csr.shape
n_out_rows = idxs.shape[0]
x_data = x_csr.data
x_indices = x_csr.indices.view(np.uint32)
x_indptr = x_csr.indptr.view(np.uint32)
out_indptr = np.empty(n_out_rows + 1, dtype=np.int32)
out_indptr[0] = 0
norm_idx = np.empty(n_out_rows, dtype=np.int32)
# Normalize (negative) indices and compute output indptr in the same pass.
total_nnz = 0
for out_row_idx in range(n_out_rows):
row_idx = idxs[out_row_idx]
if row_idx < 0:
row_idx += n_rows
if row_idx < 0 or row_idx >= n_rows:
raise IndexError("row index out of bounds")
norm_row_idx = row_idx
norm_idx[out_row_idx] = norm_row_idx
total_nnz += x_indptr[norm_row_idx + 1] - x_indptr[norm_row_idx]
out_indptr[out_row_idx + 1] = total_nnz
# Once the number of non-zero elements is known, allocate data and indices vectors.
out_data = np.empty(total_nnz, dtype=x_data.dtype)
out_indices = np.empty(total_nnz, dtype=np.int32)
# For the selected rows, copy data and indices from source to destination.
# Duplicated entries will lead to duplicated rows.
for out_row_idx in range(n_out_rows):
row_idx = norm_idx[out_row_idx]
src_start = x_indptr[row_idx]
src_stop = x_indptr[row_idx + 1]
dst_start = out_indptr[out_row_idx]
# We could have used slicing, but numba is faster with explicit loops.
dst_idx = dst_start
for src_i in range(src_start, src_stop):
out_data[dst_idx] = x_data[src_i]
out_indices[dst_idx] = x_indices[src_i]
dst_idx += 1
return sp.sparse.csr_matrix(
(out_data, out_indices, out_indptr), shape=(n_out_rows, n_cols)
)
if output_format == "csr":
return get_item_list_csr
@numba_basic.numba_njit
def get_item_list_csc(x, idx):
return get_item_list_csr(x, idx).tocsc()
return get_item_list_csc
@register_funcify_default_op_cache_key(GetItemListGrad)
def numba_funcify_GetItemListGrad(op, node, **kwargs):
output_format = node.outputs[0].type.format
out_dtype = np.dtype(node.outputs[0].type.dtype)
@numba_basic.numba_njit
def get_item_list_grad_csr(x, idxs, gz):
# Reproduces SciPy when running:
# y = [csc|csr]_matrix(x.shape)
# for i in range(len(idxs)):
# y[idxs[i]] = gz[i]
n_rows, n_cols = x.shape
n_out_rows = idxs.shape[0]
gz_n_rows = gz.shape[0]
# Normalize (negative) indices and build row_to_pos mapping.
norm_idx = np.empty(n_out_rows, dtype=np.int32)
row_to_pos = np.full(n_rows, -1, dtype=np.int32)
touched_n_rows = 0
for src_row in range(n_out_rows):
row_idx = idxs[src_row]
if row_idx < 0:
row_idx += n_rows
if row_idx < 0 or row_idx >= n_rows:
raise IndexError("row index out of bounds")
if src_row >= gz_n_rows:
raise IndexError("gradient row index out of bounds")
norm_idx[src_row] = row_idx
if row_to_pos[row_idx] == -1:
row_to_pos[row_idx] = touched_n_rows
touched_n_rows += 1
# Process gz in CSR format.
gz_csr = gz.tocsr()
gz_data = gz_csr.data
gz_indices = gz_csr.indices.view(np.uint32)
gz_indptr = gz_csr.indptr.view(np.uint32)
# Row-wise buffers that reproduce SciPy row-assignment behavior:
# repeated assignments keep the union of touched columns and turn
# missing entries into explicit zeros.
row_data = np.zeros((touched_n_rows, n_cols), dtype=out_dtype)
row_mask = np.zeros((touched_n_rows, n_cols), dtype=np.bool_)
row_seen = np.zeros(touched_n_rows, dtype=np.bool_)
for src_row in range(n_out_rows):
row_idx = norm_idx[src_row]
row_pos = row_to_pos[row_idx]
if row_seen[row_pos]:
for col_idx in range(n_cols):
if row_mask[row_pos, col_idx]:
row_data[row_pos, col_idx] = 0
else:
row_seen[row_pos] = True
for i in range(gz_indptr[src_row], gz_indptr[src_row + 1]):
col_idx = gz_indices[i]
row_data[row_pos, col_idx] = gz_data[i]
row_mask[row_pos, col_idx] = True
# Compute out_indptr by counting True entries in row_mask row-by-row.
out_indptr = np.empty(n_rows + 1, dtype=np.int32)
out_indptr[0] = 0
total_nnz = 0
for row_idx in range(n_rows):
row_pos = row_to_pos[row_idx]
if row_pos >= 0 and row_seen[row_pos]:
row_nnz = 0
for col_idx in range(n_cols):
if row_mask[row_pos, col_idx]:
row_nnz += 1
total_nnz += row_nnz
out_indptr[row_idx + 1] = total_nnz
# Once the number of non-zero elements is known, allocate data and indices vectors.
out_data = np.empty(total_nnz, dtype=out_dtype)
out_indices = np.empty(total_nnz, dtype=np.int32)
# Populate output indices and data, by row, scanning columns in ascending order.
out_pos = 0
for row_idx in range(n_rows):
row_pos = row_to_pos[row_idx]
if row_pos < 0 or not row_seen[row_pos]:
continue
for col_idx in range(n_cols):
if row_mask[row_pos, col_idx]:
out_indices[out_pos] = col_idx
out_data[out_pos] = row_data[row_pos, col_idx]
out_pos += 1
return sp.sparse.csr_matrix(
(out_data, out_indices, out_indptr), shape=(n_rows, n_cols)
)
if output_format == "csr":
return get_item_list_grad_csr
@numba_basic.numba_njit
def get_item_list_grad_csc(x, idx, gz):
return get_item_list_grad_csr(x, idx, gz).tocsc()
return get_item_list_grad_csc
<<<<<<< HEAD
=======
@register_funcify_default_op_cache_key(GetItem2Lists)
def numba_funcify_GetItem2Lists(op, node, **kwargs):
out_dtype = np.dtype(node.outputs[0].type.dtype)
@numba_basic.numba_njit
def get_item_2lists(x, ind1, ind2):
x_csr = x.tocsr()
n_rows, n_cols = x_csr.shape
if ind1.shape != ind2.shape:
raise ValueError("shape mismatch in row/column indices")
out_size = ind1.shape[0]
out = np.zeros(out_size, dtype=out_dtype)
x_data = x_csr.data
x_indices = x_csr.indices.view(np.uint32)
x_indptr = x_csr.indptr.view(np.uint32)
for i in range(out_size):
row_idx = ind1[i]
if row_idx < 0:
row_idx += n_rows
if row_idx < 0 or row_idx >= n_rows:
raise IndexError("row index out of bounds")
col_idx = ind2[i]
if col_idx < 0:
col_idx += n_cols
if col_idx < 0 or col_idx >= n_cols:
raise IndexError("column index out of bounds")
col_idx_u32 = np.uint32(col_idx)
for data_idx in range(x_indptr[row_idx], x_indptr[row_idx + 1]):
if x_indices[data_idx] == col_idx_u32:
# Duplicate sparse entries must accumulate like scipy indexing.
out[i] += x_data[data_idx]
return out
return get_item_2lists
@register_funcify_default_op_cache_key(GetItem2ListsGrad)
def numba_funcify_GetItem2ListsGrad(op, node, **kwargs):
output_format = node.outputs[0].type.format
out_dtype = np.dtype(node.outputs[0].type.dtype)
@numba_basic.numba_njit
def get_item_2lists_grad_csr(x, ind1, ind2, gz):
n_rows, n_cols = x.shape
n_assignments = ind1.shape[0]
if ind2.shape[0] != n_assignments:
raise ValueError("shape mismatch in row/column indices")
if gz.shape[0] < n_assignments:
raise IndexError("gradient index out of bounds")
norm_row = np.empty(n_assignments, dtype=np.int32)
norm_col = np.empty(n_assignments, dtype=np.int32)
row_to_pos = np.full(n_rows, -1, dtype=np.int32)
touched_n_rows = 0
for i in range(n_assignments):
row_idx = ind1[i]
if row_idx < 0:
row_idx += n_rows
if row_idx < 0 or row_idx >= n_rows:
raise IndexError("row index out of bounds")
col_idx = ind2[i]
if col_idx < 0:
col_idx += n_cols
if col_idx < 0 or col_idx >= n_cols:
raise IndexError("column index out of bounds")
norm_row[i] = row_idx
norm_col[i] = col_idx
if row_to_pos[row_idx] == -1:
row_to_pos[row_idx] = touched_n_rows
touched_n_rows += 1
# Build row-wise buffers for touched rows. Repeated writes overwrite values.
row_data = np.zeros((touched_n_rows, n_cols), dtype=out_dtype)
row_mask = np.zeros((touched_n_rows, n_cols), dtype=np.bool_)
for i in range(n_assignments):
row_pos = row_to_pos[norm_row[i]]
col_idx = norm_col[i]
row_data[row_pos, col_idx] = gz[i]
row_mask[row_pos, col_idx] = True
out_indptr = np.empty(n_rows + 1, dtype=np.int32)
out_indptr[0] = 0
total_nnz = 0
for row_idx in range(n_rows):
row_pos = row_to_pos[row_idx]
if row_pos >= 0:
row_nnz = 0
for col_idx in range(n_cols):
if row_mask[row_pos, col_idx]:
row_nnz += 1
total_nnz += row_nnz
out_indptr[row_idx + 1] = total_nnz
out_data = np.empty(total_nnz, dtype=out_dtype)
out_indices = np.empty(total_nnz, dtype=np.int32)
out_pos = 0
for row_idx in range(n_rows):
row_pos = row_to_pos[row_idx]
if row_pos < 0:
continue
for col_idx in range(n_cols):
if row_mask[row_pos, col_idx]:
out_indices[out_pos] = col_idx
out_data[out_pos] = row_data[row_pos, col_idx]
out_pos += 1
return sp.sparse.csr_matrix(
(out_data, out_indices, out_indptr), shape=(n_rows, n_cols)
)
if output_format == "csr":
return get_item_2lists_grad_csr
@numba_basic.numba_njit
def get_item_2lists_grad_csc(x, ind1, ind2, gz):
return get_item_2lists_grad_csr(x, ind1, ind2, gz).tocsc()
return get_item_2lists_grad_csc
>>>>>>> fb1d09134 (Better comments for GetItemList and GetItemListGrad)
...@@ -454,3 +454,62 @@ def test_sparse_row_scale(format): ...@@ -454,3 +454,62 @@ def test_sparse_row_scale(format):
v_test = np.random.random(7).astype(config.floatX) v_test = np.random.random(7).astype(config.floatX)
compare_numba_and_py_sparse([x, v], z, [x_test, v_test]) compare_numba_and_py_sparse([x, v], z, [x_test, v_test])
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_get_item_list(format):
x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
idx = pt.ivector("idx")
z = ps.get_item_list(x, idx)
x_test = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)
idx_test = np.asarray([0, 2, 5, 2], dtype=np.int32)
compare_numba_and_py_sparse([x, idx], z, [x_test, idx_test])
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_get_item_list_wrong_index(format):
x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
idx = pt.ivector("idx")
z = ps.get_item_list(x, idx)
fn = function([x, idx], z, mode="NUMBA")
x_test = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)
idx_test = np.asarray([0, 6], dtype=np.int32)
with pytest.raises(IndexError):
fn(x_test, idx_test)
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_get_item_list_grad(format):
x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
idx = pt.ivector("idx")
gz = ps.matrix(format, name="gz", shape=(4, 5), dtype=config.floatX)
z = ps.get_item_list_grad(x, idx, gz)
x_test = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)
gz_test = sp.sparse.random(4, 5, density=0.4, format=format, dtype=config.floatX)
idx_test = np.asarray([0, 2, 5, 2], dtype=np.int32)
with pytest.warns(sp.sparse.SparseEfficiencyWarning):
# GetItemListGrad.perform does sparse row assignment into an initially empty sparse
# matrix, which changes sparsity structure incrementally and triggers the warning.
compare_numba_and_py_sparse([x, idx, gz], z, [x_test, idx_test, gz_test])
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_get_item_list_grad_wrong_index(format):
x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
idx = pt.ivector("idx")
gz = ps.matrix(format, name="gz", shape=(2, 5), dtype=config.floatX)
z = ps.get_item_list_grad(x, idx, gz)
fn = function([x, idx, gz], z, mode="NUMBA")
x_test = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)
gz_test = sp.sparse.random(2, 5, density=0.4, format=format, dtype=config.floatX)
idx_test = np.asarray([0, 6], dtype=np.int32)
with pytest.raises(IndexError):
fn(x_test, idx_test, gz_test)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论