提交 220442a2 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Implement a copy method for Numba sparse types

上级 eb6fc66e
......@@ -2,13 +2,16 @@ import numpy as np
import scipy as sp
import scipy.sparse
from numba.core import cgutils, types
from numba.core.imputils import impl_ret_borrowed
from numba.extending import (
NativeValue,
box,
intrinsic,
make_attribute_wrapper,
models,
overload,
overload_attribute,
overload_method,
register_model,
typeof_impl,
unbox,
......@@ -166,3 +169,38 @@ def overload_sparse_ndim(inst):
return 2
return ndim
@intrinsic
def _sparse_copy(typingctx, inst, data, indices, indptr, shape):
def _construct(context, builder, sig, args):
typ = sig.return_type
struct = cgutils.create_struct_proxy(typ)(context, builder)
_, data, indices, indptr, shape = args
struct.data = data
struct.indices = indices
struct.indptr = indptr
struct.shape = shape
return impl_ret_borrowed(
context,
builder,
sig.return_type,
struct._getvalue(),
)
sig = inst(inst, inst.data, inst.indices, inst.indptr, inst.shape)
return sig, _construct
@overload_method(CSMatrixType, "copy")
def overload_sparse_copy(inst):
if not isinstance(inst, CSMatrixType):
return
def copy(inst):
return _sparse_copy(
inst, inst.data.copy(), inst.indices.copy(), inst.indptr.copy(), inst.shape
)
return copy
......@@ -71,6 +71,19 @@ def test_sparse_ndim():
assert res == 2
def test_sparse_copy():
@numba.njit
def test_fn(x):
y = x.copy()
return (
y is not x and np.all(x.data == y.data) and np.all(x.indices == y.indices)
)
x_val = sp.sparse.csr_matrix(np.eye(100))
assert test_fn(x_val)
def test_sparse_objmode():
x = SparseTensorType("csc", dtype=config.floatX)()
y = SparseTensorType("csc", dtype=config.floatX)()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论