提交 8c7518af 作者: Tomas Capretto 提交者: Ricardo Vieira

Implement GetItem2Lists and GetItem2ListsGrad sparse Ops in Numba backend

上级 6c949b3a
...@@ -17,6 +17,8 @@ from pytensor.sparse import ( ...@@ -17,6 +17,8 @@ from pytensor.sparse import (
ColScaleCSC, ColScaleCSC,
CSMProperties, CSMProperties,
DenseFromSparse, DenseFromSparse,
GetItem2Lists,
GetItem2ListsGrad,
GetItemList, GetItemList,
GetItemListGrad, GetItemListGrad,
HStack, HStack,
...@@ -357,7 +359,7 @@ def numba_funcify_GetItemList(op, node, **kwargs): ...@@ -357,7 +359,7 @@ def numba_funcify_GetItemList(op, node, **kwargs):
@register_funcify_default_op_cache_key(GetItemListGrad) @register_funcify_default_op_cache_key(GetItemListGrad)
def numba_funcify_GetItemListGrad(op, node, **kwargs): def numba_funcify_GetItemListGrad(op, node, **kwargs):
output_format = node.outputs[0].type.format output_format = node.outputs[0].type.format
out_dtype = np.dtype(node.outputs[0].type.dtype) out_dtype = node.outputs[0].type.dtype
@numba_basic.numba_njit @numba_basic.numba_njit
def get_item_list_grad_csr(x, idxs, gz): def get_item_list_grad_csr(x, idxs, gz):
...@@ -461,47 +463,50 @@ def numba_funcify_GetItemListGrad(op, node, **kwargs): ...@@ -461,47 +463,50 @@ def numba_funcify_GetItemListGrad(op, node, **kwargs):
return get_item_list_grad_csr(x, idx, gz).tocsc() return get_item_list_grad_csr(x, idx, gz).tocsc()
return get_item_list_grad_csc return get_item_list_grad_csc
<<<<<<< HEAD
=======
@register_funcify_default_op_cache_key(GetItem2Lists) @register_funcify_default_op_cache_key(GetItem2Lists)
def numba_funcify_GetItem2Lists(op, node, **kwargs): def numba_funcify_GetItem2Lists(op, node, **kwargs):
out_dtype = np.dtype(node.outputs[0].type.dtype) out_dtype = node.outputs[0].type.dtype
@numba_basic.numba_njit @numba_basic.numba_njit
def get_item_2lists(x, ind1, ind2): def get_item_2lists(x, ind1, ind2):
x_csr = x.tocsr() # Reproduces SciPy and NumPy when running:
n_rows, n_cols = x_csr.shape # np.asarray(x[ind1, ind2]).flatten()
if ind1.shape != ind2.shape: if ind1.shape != ind2.shape:
raise ValueError("shape mismatch in row/column indices") raise ValueError("shape mismatch in row/column indices")
# Output vector contains as many elements as the length of the index lists.
out_size = ind1.shape[0] out_size = ind1.shape[0]
out = np.zeros(out_size, dtype=out_dtype) out = np.zeros(out_size, dtype=out_dtype)
x_data = x_csr.data x_csr = x.tocsr()
x_indices = x_csr.indices.view(np.uint32) x_indices = x_csr.indices.view(np.uint32)
x_indptr = x_csr.indptr.view(np.uint32) x_indptr = x_csr.indptr.view(np.uint32)
n_rows, n_cols = x_csr.shape
for i in range(out_size): for i in range(out_size):
# Normalize row index
row_idx = ind1[i] row_idx = ind1[i]
if row_idx < 0: if row_idx < 0:
row_idx += n_rows row_idx += n_rows
if row_idx < 0 or row_idx >= n_rows: if row_idx < 0 or row_idx >= n_rows:
raise IndexError("row index out of bounds") raise IndexError("row index out of bounds")
# Normalize column index
col_idx = ind2[i] col_idx = ind2[i]
if col_idx < 0: if col_idx < 0:
col_idx += n_cols col_idx += n_cols
if col_idx < 0 or col_idx >= n_cols: if col_idx < 0 or col_idx >= n_cols:
raise IndexError("column index out of bounds") raise IndexError("column index out of bounds")
col_idx_u32 = np.uint32(col_idx) row_idx = np.uint32(row_idx)
col_idx = np.uint32(col_idx)
for data_idx in range(x_indptr[row_idx], x_indptr[row_idx + 1]): for data_idx in range(x_indptr[row_idx], x_indptr[row_idx + 1]):
if x_indices[data_idx] == col_idx_u32: if x_indices[data_idx] == col_idx:
# Duplicate sparse entries must accumulate like scipy indexing. # Duplicate sparse entries must accumulate like scipy indexing.
out[i] += x_data[data_idx] out[i] += x_csr.data[data_idx]
return out return out
...@@ -511,31 +516,42 @@ def numba_funcify_GetItem2Lists(op, node, **kwargs): ...@@ -511,31 +516,42 @@ def numba_funcify_GetItem2Lists(op, node, **kwargs):
@register_funcify_default_op_cache_key(GetItem2ListsGrad) @register_funcify_default_op_cache_key(GetItem2ListsGrad)
def numba_funcify_GetItem2ListsGrad(op, node, **kwargs): def numba_funcify_GetItem2ListsGrad(op, node, **kwargs):
output_format = node.outputs[0].type.format output_format = node.outputs[0].type.format
out_dtype = np.dtype(node.outputs[0].type.dtype) out_dtype = node.outputs[0].type.dtype
@numba_basic.numba_njit @numba_basic.numba_njit
def get_item_2lists_grad_csr(x, ind1, ind2, gz): def get_item_2lists_grad_csr(x, ind1, ind2, gz):
n_rows, n_cols = x.shape # Reproduces SciPy when running:
n_assignments = ind1.shape[0] # y = [csc|csr]_matrix(x.shape)
# for i in range(len(ind1)):
# y[(ind1[i], ind2[i])] = gz[i]
#
# Note that gz is a dense vector.
if ind2.shape[0] != n_assignments: if ind1.shape != ind2.shape:
raise ValueError("shape mismatch in row/column indices") raise ValueError("shape mismatch in row/column indices")
n_assignments = ind1.shape[0]
if gz.shape[0] < n_assignments: if gz.shape[0] < n_assignments:
raise IndexError("gradient index out of bounds") raise IndexError("gradient index out of bounds")
norm_row = np.empty(n_assignments, dtype=np.int32) # Vectors with normalized (non-negative) row and column indices
norm_col = np.empty(n_assignments, dtype=np.int32) norm_row = np.empty(n_assignments, dtype=np.uint32)
norm_col = np.empty(n_assignments, dtype=np.uint32)
n_rows, n_cols = x.shape
# Maps original rows to values in [0, ..., touched_n_rows - 1]
row_to_pos = np.full(n_rows, -1, dtype=np.int32) row_to_pos = np.full(n_rows, -1, dtype=np.int32)
touched_n_rows = 0 touched_n_rows = 0
for i in range(n_assignments): for i in range(n_assignments):
# Normalize row idx
row_idx = ind1[i] row_idx = ind1[i]
if row_idx < 0: if row_idx < 0:
row_idx += n_rows row_idx += n_rows
if row_idx < 0 or row_idx >= n_rows: if row_idx < 0 or row_idx >= n_rows:
raise IndexError("row index out of bounds") raise IndexError("row index out of bounds")
# Normalize column idx
col_idx = ind2[i] col_idx = ind2[i]
if col_idx < 0: if col_idx < 0:
col_idx += n_cols col_idx += n_cols
...@@ -552,13 +568,19 @@ def numba_funcify_GetItem2ListsGrad(op, node, **kwargs): ...@@ -552,13 +568,19 @@ def numba_funcify_GetItem2ListsGrad(op, node, **kwargs):
# Build row-wise buffers for touched rows. Repeated writes overwrite values. # Build row-wise buffers for touched rows. Repeated writes overwrite values.
row_data = np.zeros((touched_n_rows, n_cols), dtype=out_dtype) row_data = np.zeros((touched_n_rows, n_cols), dtype=out_dtype)
row_mask = np.zeros((touched_n_rows, n_cols), dtype=np.bool_) row_mask = np.zeros((touched_n_rows, n_cols), dtype=np.bool_)
row_nnz = np.zeros(touched_n_rows, dtype=np.int32)
for i in range(n_assignments): for i in range(n_assignments):
row_pos = row_to_pos[norm_row[i]] row_pos = row_to_pos[norm_row[i]]
col_idx = norm_col[i] col_idx = norm_col[i]
if not row_mask[row_pos, col_idx]:
row_nnz[row_pos] += 1
row_mask[row_pos, col_idx] = True
row_data[row_pos, col_idx] = gz[i] row_data[row_pos, col_idx] = gz[i]
row_mask[row_pos, col_idx] = True
# Build output indptr.
# For touched rows add row_nnz[row_pos] to total_nnz.
# For untouched rows, carry forward the previous total_nnz count.
out_indptr = np.empty(n_rows + 1, dtype=np.int32) out_indptr = np.empty(n_rows + 1, dtype=np.int32)
out_indptr[0] = 0 out_indptr[0] = 0
...@@ -566,26 +588,26 @@ def numba_funcify_GetItem2ListsGrad(op, node, **kwargs): ...@@ -566,26 +588,26 @@ def numba_funcify_GetItem2ListsGrad(op, node, **kwargs):
for row_idx in range(n_rows): for row_idx in range(n_rows):
row_pos = row_to_pos[row_idx] row_pos = row_to_pos[row_idx]
if row_pos >= 0: if row_pos >= 0:
row_nnz = 0 total_nnz += row_nnz[row_pos]
for col_idx in range(n_cols):
if row_mask[row_pos, col_idx]:
row_nnz += 1
total_nnz += row_nnz
out_indptr[row_idx + 1] = total_nnz out_indptr[row_idx + 1] = total_nnz
# Build output data and indices, which need the total number of non-zero elements.
out_data = np.empty(total_nnz, dtype=out_dtype) out_data = np.empty(total_nnz, dtype=out_dtype)
out_indices = np.empty(total_nnz, dtype=np.int32) out_indices = np.empty(total_nnz, dtype=np.int32)
out_pos = 0
# Populate indices and data by storing col_idx and value (row_data[row_pos, col_idx])
# for touched rows/columns.
for row_idx in range(n_rows): for row_idx in range(n_rows):
row_pos = row_to_pos[row_idx] row_pos = row_to_pos[row_idx]
if row_pos < 0: if row_pos < 0:
continue continue
dst = out_indptr[row_idx]
for col_idx in range(n_cols): for col_idx in range(n_cols):
if row_mask[row_pos, col_idx]: if row_mask[row_pos, col_idx]:
out_indices[out_pos] = col_idx out_indices[dst] = col_idx
out_data[out_pos] = row_data[row_pos, col_idx] out_data[dst] = row_data[row_pos, col_idx]
out_pos += 1 dst += 1
return sp.sparse.csr_matrix( return sp.sparse.csr_matrix(
(out_data, out_indices, out_indptr), shape=(n_rows, n_cols) (out_data, out_indices, out_indptr), shape=(n_rows, n_cols)
...@@ -599,4 +621,3 @@ def numba_funcify_GetItem2ListsGrad(op, node, **kwargs): ...@@ -599,4 +621,3 @@ def numba_funcify_GetItem2ListsGrad(op, node, **kwargs):
return get_item_2lists_grad_csr(x, ind1, ind2, gz).tocsc() return get_item_2lists_grad_csr(x, ind1, ind2, gz).tocsc()
return get_item_2lists_grad_csc return get_item_2lists_grad_csc
>>>>>>> fb1d09134 (Better comments for GetItemList and GetItemListGrad)
...@@ -926,6 +926,8 @@ class GetItem2Lists(Op): ...@@ -926,6 +926,8 @@ class GetItem2Lists(Op):
assert x.format in ("csr", "csc") assert x.format in ("csr", "csc")
ind1 = ptb.as_tensor_variable(ind1) ind1 = ptb.as_tensor_variable(ind1)
ind2 = ptb.as_tensor_variable(ind2) ind2 = ptb.as_tensor_variable(ind2)
assert ind1.ndim == 1
assert ind2.ndim == 1
assert ind1.dtype in integer_dtypes assert ind1.dtype in integer_dtypes
assert ind2.dtype in integer_dtypes assert ind2.dtype in integer_dtypes
......
...@@ -513,3 +513,82 @@ def test_sparse_get_item_list_grad_wrong_index(format): ...@@ -513,3 +513,82 @@ def test_sparse_get_item_list_grad_wrong_index(format):
with pytest.raises(IndexError): with pytest.raises(IndexError):
fn(x_test, idx_test, gz_test) fn(x_test, idx_test, gz_test)
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_get_item_2lists(format):
    """Compare the Numba and Python results of ``get_item_2lists``.

    A 6x5 sparse matrix in the parametrized ``format`` is indexed with two
    int32 index vectors of length four (row indices and column indices).
    """
    # Symbolic graph: sparse matrix plus the two integer index vectors.
    x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
    row_var = pt.ivector("ind1")
    col_var = pt.ivector("ind2")
    out = ps.get_item_2lists(x, row_var, col_var)

    # Concrete inputs; every (row, col) pair lies inside the 6x5 matrix.
    x_val = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)
    row_val = np.array([0, 0, 3, 5], dtype=np.int32)
    col_val = np.array([0, 4, 2, 1], dtype=np.int32)

    graph_inputs = [x, row_var, col_var]
    test_values = [x_val, row_val, col_val]
    compare_numba_and_py_sparse(graph_inputs, out, test_values)
@pytest.mark.parametrize("format", ("csr", "csc"))
@pytest.mark.parametrize(
    ("ind1_test", "ind2_test"),
    [
        (np.asarray([0, 6], dtype=np.int32), np.asarray([0, 3], dtype=np.int32)),
        (np.asarray([0, 3], dtype=np.int32), np.asarray([0, 5], dtype=np.int32)),
    ],
)
def test_sparse_get_item_2lists_wrong_index(format, ind1_test, ind2_test):
    """Expect ``IndexError`` from the Numba-compiled ``get_item_2lists``.

    The matrix is 6x5. The first parameter set uses row index 6 and the
    second uses column index 5, each one past the last valid position.
    """
    x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
    row_var = pt.ivector("ind1")
    col_var = pt.ivector("ind2")
    out = ps.get_item_2lists(x, row_var, col_var)

    # Compile in NUMBA mode so the Numba implementation is the one exercised.
    numba_fn = function([x, row_var, col_var], out, mode="NUMBA")

    x_val = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)

    with pytest.raises(IndexError):
        numba_fn(x_val, ind1_test, ind2_test)
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_get_item_2lists_grad(format):
    """Compare the Numba and Python results of ``get_item_2lists_grad``.

    The index pair (2, 0) appears twice in the test indices (positions 1
    and 3), so the same target cell is assigned more than once.
    """
    x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
    row_var = pt.ivector("ind1")
    col_var = pt.ivector("ind2")
    # Dense gradient vector, one entry per (row, col) index pair.
    gz = pt.vector(name="gz", shape=(4,), dtype=config.floatX)
    out = ps.get_item_2lists_grad(x, row_var, col_var, gz)

    x_val = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)
    row_val = np.array([0, 2, 5, 2], dtype=np.int32)
    col_val = np.array([1, 0, 4, 0], dtype=np.int32)
    gz_val = np.array([0.5, -1.25, 2.0, 4.5], dtype=config.floatX)

    graph_inputs = [x, row_var, col_var, gz]
    test_values = [x_val, row_val, col_val, gz_val]

    # GetItem2ListsGrad.perform does sparse item assignment into an initially
    # empty sparse matrix, which changes sparsity structure incrementally;
    # SciPy reports that through SparseEfficiencyWarning.
    with pytest.warns(sp.sparse.SparseEfficiencyWarning):
        compare_numba_and_py_sparse(graph_inputs, out, test_values)
@pytest.mark.parametrize("format", ("csr", "csc"))
@pytest.mark.parametrize(
    ("ind1_test", "ind2_test"),
    [
        (np.asarray([0, 6], dtype=np.int32), np.asarray([0, 3], dtype=np.int32)),
        (np.asarray([0, 3], dtype=np.int32), np.asarray([0, 5], dtype=np.int32)),
    ],
)
def test_sparse_get_item_2lists_grad_wrong_index(format, ind1_test, ind2_test):
    """Expect ``IndexError`` from the Numba-compiled ``get_item_2lists_grad``.

    The matrix is 6x5. The first parameter set uses row index 6 and the
    second uses column index 5, each one past the last valid position.
    """
    x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
    row_var = pt.ivector("ind1")
    col_var = pt.ivector("ind2")
    gz = pt.vector(name="gz", shape=(2,), dtype=config.floatX)
    out = ps.get_item_2lists_grad(x, row_var, col_var, gz)

    # Compile in NUMBA mode so the Numba implementation is the one exercised.
    numba_fn = function([x, row_var, col_var, gz], out, mode="NUMBA")

    x_val = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)
    gz_val = np.array([1.0, -2.0], dtype=config.floatX)

    with pytest.raises(IndexError):
        numba_fn(x_val, ind1_test, ind2_test, gz_val)
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论