提交 ba51e7d2 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Fix PyTensor-to-Numba type resolution for sparse variables

上级 e31655fb
......@@ -25,6 +25,7 @@ from pytensor.graph.basic import Apply, NoParams
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type
from pytensor.ifelse import IfElse
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
from pytensor.link.utils import (
compile_function_src,
fgraph_to_python,
......@@ -32,6 +33,7 @@ from pytensor.link.utils import (
)
from pytensor.scalar.basic import ScalarType
from pytensor.scalar.math import Softplus
from pytensor.sparse import SparseTensorType
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
......@@ -105,6 +107,15 @@ def get_numba_type(
dtype = np.dtype(pytensor_type.dtype)
numba_dtype = numba.from_dtype(dtype)
return numba_dtype
elif isinstance(pytensor_type, SparseTensorType):
dtype = pytensor_type.numpy_dtype
numba_dtype = numba.from_dtype(dtype)
if pytensor_type.format == "csr":
return CSRMatrixType(numba_dtype)
if pytensor_type.format == "csc":
return CSCMatrixType(numba_dtype)
raise NotImplementedError()
else:
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
......
......@@ -19,7 +19,10 @@ class CSMatrixType(types.Type):
"""A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`."""
name: str
instance_class: type
@staticmethod
def instance_class(data, indices, indptr, shape):
raise NotImplementedError()
def __init__(self, dtype):
self.dtype = dtype
......@@ -29,6 +32,10 @@ class CSMatrixType(types.Type):
self.shape = types.UniTuple(types.int64, 2)
super().__init__(self.name)
@property
def key(self):
return (self.name, self.dtype)
make_attribute_wrapper(CSMatrixType, "data", "data")
make_attribute_wrapper(CSMatrixType, "indices", "indices")
......@@ -152,7 +159,6 @@ def overload_sparse_shape(x):
@overload_attribute(CSMatrixType, "ndim")
def overload_sparse_ndim(inst):
if not isinstance(inst, CSMatrixType):
return
......
import numba
import numpy as np
import pytest
import scipy as sp
# Load Numba customizations
# Make sure the Numba customizations are loaded
import pytensor.link.numba.dispatch.sparse # noqa: F401
from pytensor import config
from pytensor.sparse import Dot, SparseTensorType
from tests.link.numba.test_basic import compare_numba_and_py
pytestmark = pytest.mark.filterwarnings("error")
def test_sparse_unboxing():
......@@ -62,3 +69,19 @@ def test_sparse_ndim():
res = test_fn(x_val)
assert res == 2
def test_sparse_objmode():
x = SparseTensorType("csc", dtype=config.floatX)()
y = SparseTensorType("csc", dtype=config.floatX)()
out = Dot()(x, y)
x_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX)
y_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX)
with pytest.warns(
UserWarning,
match="Numba will use object mode to run SparseDot's perform method",
):
compare_numba_and_py(((x, y), (out,)), [x_val, y_val])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论