提交 ba51e7d2 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Fix PyTensor-to-Numba type resolution for sparse variables

上级 e31655fb
...@@ -25,6 +25,7 @@ from pytensor.graph.basic import Apply, NoParams ...@@ -25,6 +25,7 @@ from pytensor.graph.basic import Apply, NoParams
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.ifelse import IfElse from pytensor.ifelse import IfElse
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
from pytensor.link.utils import ( from pytensor.link.utils import (
compile_function_src, compile_function_src,
fgraph_to_python, fgraph_to_python,
...@@ -32,6 +33,7 @@ from pytensor.link.utils import ( ...@@ -32,6 +33,7 @@ from pytensor.link.utils import (
) )
from pytensor.scalar.basic import ScalarType from pytensor.scalar.basic import ScalarType
from pytensor.scalar.math import Softplus from pytensor.scalar.math import Softplus
from pytensor.sparse import SparseTensorType
from pytensor.tensor.blas import BatchedDot from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
...@@ -105,6 +107,15 @@ def get_numba_type( ...@@ -105,6 +107,15 @@ def get_numba_type(
dtype = np.dtype(pytensor_type.dtype) dtype = np.dtype(pytensor_type.dtype)
numba_dtype = numba.from_dtype(dtype) numba_dtype = numba.from_dtype(dtype)
return numba_dtype return numba_dtype
elif isinstance(pytensor_type, SparseTensorType):
dtype = pytensor_type.numpy_dtype
numba_dtype = numba.from_dtype(dtype)
if pytensor_type.format == "csr":
return CSRMatrixType(numba_dtype)
if pytensor_type.format == "csc":
return CSCMatrixType(numba_dtype)
raise NotImplementedError()
else: else:
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}") raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
......
...@@ -19,7 +19,10 @@ class CSMatrixType(types.Type): ...@@ -19,7 +19,10 @@ class CSMatrixType(types.Type):
"""A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`.""" """A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`."""
name: str name: str
instance_class: type
@staticmethod
def instance_class(data, indices, indptr, shape):
raise NotImplementedError()
def __init__(self, dtype): def __init__(self, dtype):
self.dtype = dtype self.dtype = dtype
...@@ -29,6 +32,10 @@ class CSMatrixType(types.Type): ...@@ -29,6 +32,10 @@ class CSMatrixType(types.Type):
self.shape = types.UniTuple(types.int64, 2) self.shape = types.UniTuple(types.int64, 2)
super().__init__(self.name) super().__init__(self.name)
@property
def key(self):
return (self.name, self.dtype)
make_attribute_wrapper(CSMatrixType, "data", "data") make_attribute_wrapper(CSMatrixType, "data", "data")
make_attribute_wrapper(CSMatrixType, "indices", "indices") make_attribute_wrapper(CSMatrixType, "indices", "indices")
...@@ -152,7 +159,6 @@ def overload_sparse_shape(x): ...@@ -152,7 +159,6 @@ def overload_sparse_shape(x):
@overload_attribute(CSMatrixType, "ndim") @overload_attribute(CSMatrixType, "ndim")
def overload_sparse_ndim(inst): def overload_sparse_ndim(inst):
if not isinstance(inst, CSMatrixType): if not isinstance(inst, CSMatrixType):
return return
......
import numba import numba
import numpy as np import numpy as np
import pytest
import scipy as sp import scipy as sp
# Load Numba customizations # Make sure the Numba customizations are loaded
import pytensor.link.numba.dispatch.sparse # noqa: F401 import pytensor.link.numba.dispatch.sparse # noqa: F401
from pytensor import config
from pytensor.sparse import Dot, SparseTensorType
from tests.link.numba.test_basic import compare_numba_and_py
pytestmark = pytest.mark.filterwarnings("error")
def test_sparse_unboxing(): def test_sparse_unboxing():
...@@ -62,3 +69,19 @@ def test_sparse_ndim(): ...@@ -62,3 +69,19 @@ def test_sparse_ndim():
res = test_fn(x_val) res = test_fn(x_val)
assert res == 2 assert res == 2
def test_sparse_objmode():
x = SparseTensorType("csc", dtype=config.floatX)()
y = SparseTensorType("csc", dtype=config.floatX)()
out = Dot()(x, y)
x_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX)
y_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX)
with pytest.warns(
UserWarning,
match="Numba will use object mode to run SparseDot's perform method",
):
compare_numba_and_py(((x, y), (out,)), [x_val, y_val])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论