提交 e31655fb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Add np.shape and ndim overloads for sparse Numba types

上级 4317d0df
import numpy as np
import scipy as sp
import scipy.sparse
from numba.core import cgutils, types
......@@ -6,6 +7,8 @@ from numba.extending import (
box,
make_attribute_wrapper,
models,
overload,
overload_attribute,
register_model,
typeof_impl,
unbox,
......@@ -139,3 +142,21 @@ def box_matrix(typ, val, c):
c.pyapi.decref(shape_obj)
return obj
@overload(np.shape)
def overload_sparse_shape(x):
if isinstance(x, CSMatrixType):
return lambda x: x.shape
@overload_attribute(CSMatrixType, "ndim")
def overload_sparse_ndim(inst):
if not isinstance(inst, CSMatrixType):
return
def ndim(inst):
return 2
return ndim
......@@ -38,3 +38,27 @@ def test_sparse_boxing():
assert np.array_equal(res_y_val.indices, y_val.indices)
assert np.array_equal(res_y_val.indptr, y_val.indptr)
assert res_y_val.shape == y_val.shape
def test_sparse_shape():
@numba.njit
def test_fn(x):
return np.shape(x)
x_val = sp.sparse.csr_matrix(np.eye(100))
res = test_fn(x_val)
assert res == (100, 100)
def test_sparse_ndim():
@numba.njit
def test_fn(x):
return x.ndim
x_val = sp.sparse.csr_matrix(np.eye(100))
res = test_fn(x_val)
assert res == 2
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论