Commit fbd7a597 authored by Ricardo Vieira, committed by Ricardo Vieira

Numba sparse: Implement basic Ops

Parent 916da3f0
from pytensor.link.numba.dispatch.sparse import variable
from pytensor.link.numba.dispatch.sparse import basic, variable
import numpy as np
import scipy as sp
from numba.extending import overload
from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
generate_fallback_impl,
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.link.numba.dispatch.sparse.variable import CSMatrixType
from pytensor.sparse import CSM, Cast, CSMProperties
@overload(numba_deepcopy)
def numba_deepcopy_sparse(x):
    """Numba overload of ``numba_deepcopy`` for sparse CSMatrixType values.

    Returns a copy-based implementation when ``x`` is a ``CSMatrixType``;
    for any other type it returns ``None`` so the overload does not apply.
    """
    if not isinstance(x, CSMatrixType):
        return None

    def sparse_deepcopy(x):
        # A fresh copy of the sparse container counts as a deepcopy here.
        return x.copy()

    return sparse_deepcopy
@register_funcify_default_op_cache_key(CSMProperties)
def numba_funcify_CSMProperties(op, node, **kwargs):
    """Numba implementation of ``CSMProperties``.

    Decomposes a compressed sparse matrix into its
    ``(data, indices, indptr, shape)`` components.
    """

    @numba_basic.numba_njit
    def csm_properties(x):
        # Reconsider this int32/int64. Scipy/base PyTensor use int32 for
        # indices/indptr. But this seems to be legacy mistake and devs would
        # choose int64 nowadays, and may move there.
        shape_arr = np.asarray(x.shape, dtype="int32")
        return (x.data, x.indices, x.indptr, shape_arr)

    return csm_properties
@register_funcify_default_op_cache_key(CSM)
def numba_funcify_CSM(op, node, **kwargs):
    """Numba implementation of ``CSM``.

    Rebuilds a scipy CSR/CSC matrix from ``(data, indices, indptr, shape)``
    components, using the format baked into the Op.
    """
    # Captured at funcify time; `fmt` avoids shadowing the builtin `format`.
    fmt = op.format

    @numba_basic.numba_njit
    def csm_constructor(data, indices, indptr, shape):
        dims = (shape[0], shape[1])
        if fmt == "csr":
            return sp.sparse.csr_matrix((data, indices, indptr), shape=dims)
        else:
            return sp.sparse.csc_matrix((data, indices, indptr), shape=dims)

    return csm_constructor
@register_funcify_default_op_cache_key(Cast)
def numba_funcify_Cast(op, node, **kwargs):
    """Numba implementation of sparse ``Cast``.

    Emits a jitted ``astype`` when the cast is safe; otherwise falls back
    to the Op's object-mode ``perform`` implementation.
    """
    inp_dtype = node.inputs[0].type.dtype
    out_dtype = np.dtype(op.out_type)

    if np.can_cast(inp_dtype, out_dtype):

        @numba_basic.numba_njit
        def cast(x):
            return x.astype(out_dtype)

        return cast

    # Unsafe cast: use the object-mode fallback instead of a jitted astype.
    if config.compiler_verbose:
        print(  # noqa: T201
            f"Sparse Cast fallback to obj mode due to unsafe casting from {inp_dtype} to {out_dtype}"
        )
    return generate_fallback_impl(op, node, **kwargs)
......@@ -9,6 +9,7 @@ from pytensor.link.numba.dispatch.basic import (
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.link.numba.dispatch.sparse.variable import CSCMatrixType, CSRMatrixType
from pytensor.tensor.type_other import SliceType
from pytensor.typed_list import (
Append,
......@@ -64,6 +65,18 @@ def list_all_equal(x, y):
def all_equal(x, y):
return x == y
if (isinstance(x, CSRMatrixType) and isinstance(y, CSRMatrixType)) or (
isinstance(x, CSCMatrixType) and isinstance(y, CSCMatrixType)
):
def all_equal(x, y):
return (
x.shape == y.shape
and (x.data == y.data).all()
and (x.indptr == y.indptr).all()
and (x.indices == y.indices).all()
)
return all_equal
......
......@@ -255,10 +255,18 @@ def test_simple_graph(format):
y_test = rng.normal(size=(3,))
with pytest.warns(
UserWarning, match=r"Numba will use object mode to run .* perform method"
UserWarning,
match=r"Numba will use object mode to run SparseDenseVectorMultiply's perform method",
):
compare_numba_and_py_sparse(
[x, y],
z,
[x_test, y_test],
)
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_deepcopy(format):
    """Numba and Python modes agree on an identity (deepcopy) sparse graph."""
    x = ps.matrix(shape=(3, 3), format=format)
    x_val = sp.sparse.random(3, 3, density=0.5, format=format)
    # Returning the input as the output exercises the deepcopy path.
    compare_numba_and_py_sparse([x], [x], [x_val])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论