提交 6c949b3a authored 作者: Tomas Capretto's avatar Tomas Capretto 提交者: Ricardo Vieira

Implement GetItemList and GetItemListGrad sparse Ops in Numba backend

上级 7ad266d6
......@@ -454,3 +454,62 @@ def test_sparse_row_scale(format):
v_test = np.random.random(7).astype(config.floatX)
compare_numba_and_py_sparse([x, v], z, [x_test, v_test])
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_get_item_list(format):
x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
idx = pt.ivector("idx")
z = ps.get_item_list(x, idx)
x_test = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)
idx_test = np.asarray([0, 2, 5, 2], dtype=np.int32)
compare_numba_and_py_sparse([x, idx], z, [x_test, idx_test])
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_get_item_list_wrong_index(format):
x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
idx = pt.ivector("idx")
z = ps.get_item_list(x, idx)
fn = function([x, idx], z, mode="NUMBA")
x_test = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)
idx_test = np.asarray([0, 6], dtype=np.int32)
with pytest.raises(IndexError):
fn(x_test, idx_test)
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_get_item_list_grad(format):
x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
idx = pt.ivector("idx")
gz = ps.matrix(format, name="gz", shape=(4, 5), dtype=config.floatX)
z = ps.get_item_list_grad(x, idx, gz)
x_test = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)
gz_test = sp.sparse.random(4, 5, density=0.4, format=format, dtype=config.floatX)
idx_test = np.asarray([0, 2, 5, 2], dtype=np.int32)
with pytest.warns(sp.sparse.SparseEfficiencyWarning):
# GetItemListGrad.perform does sparse row assignment into an initially empty sparse
# matrix, which changes sparsity structure incrementally and triggers the warning.
compare_numba_and_py_sparse([x, idx, gz], z, [x_test, idx_test, gz_test])
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_get_item_list_grad_wrong_index(format):
x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
idx = pt.ivector("idx")
gz = ps.matrix(format, name="gz", shape=(2, 5), dtype=config.floatX)
z = ps.get_item_list_grad(x, idx, gz)
fn = function([x, idx, gz], z, mode="NUMBA")
x_test = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)
gz_test = sp.sparse.random(2, 5, density=0.4, format=format, dtype=config.floatX)
idx_test = np.asarray([0, 6], dtype=np.int32)
with pytest.raises(IndexError):
fn(x_test, idx_test, gz_test)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论