提交 e31655fb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Add np.shape and ndim overloads for sparse Numba types

上级 4317d0df
import numpy as np
import scipy as sp import scipy as sp
import scipy.sparse import scipy.sparse
from numba.core import cgutils, types from numba.core import cgutils, types
...@@ -6,6 +7,8 @@ from numba.extending import ( ...@@ -6,6 +7,8 @@ from numba.extending import (
box, box,
make_attribute_wrapper, make_attribute_wrapper,
models, models,
overload,
overload_attribute,
register_model, register_model,
typeof_impl, typeof_impl,
unbox, unbox,
...@@ -139,3 +142,21 @@ def box_matrix(typ, val, c): ...@@ -139,3 +142,21 @@ def box_matrix(typ, val, c):
c.pyapi.decref(shape_obj) c.pyapi.decref(shape_obj)
return obj return obj
@overload(np.shape)
def overload_sparse_shape(x):
if isinstance(x, CSMatrixType):
return lambda x: x.shape
@overload_attribute(CSMatrixType, "ndim")
def overload_sparse_ndim(inst):
if not isinstance(inst, CSMatrixType):
return
def ndim(inst):
return 2
return ndim
...@@ -38,3 +38,27 @@ def test_sparse_boxing(): ...@@ -38,3 +38,27 @@ def test_sparse_boxing():
assert np.array_equal(res_y_val.indices, y_val.indices) assert np.array_equal(res_y_val.indices, y_val.indices)
assert np.array_equal(res_y_val.indptr, y_val.indptr) assert np.array_equal(res_y_val.indptr, y_val.indptr)
assert res_y_val.shape == y_val.shape assert res_y_val.shape == y_val.shape
def test_sparse_shape():
@numba.njit
def test_fn(x):
return np.shape(x)
x_val = sp.sparse.csr_matrix(np.eye(100))
res = test_fn(x_val)
assert res == (100, 100)
def test_sparse_ndim():
@numba.njit
def test_fn(x):
return x.ndim
x_val = sp.sparse.csr_matrix(np.eye(100))
res = test_fn(x_val)
assert res == 2
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论