提交 54db5ad0 authored 作者: Tomas Capretto's avatar Tomas Capretto 提交者: Ricardo Vieira

Implement SparseFromDense in numba backend

上级 bad5f528
import numpy as np import numpy as np
import scipy as sp import scipy as sp
from numba.core import types
from numba.extending import overload from numba.extending import overload
from pytensor import config from pytensor import config
...@@ -9,8 +10,19 @@ from pytensor.link.numba.dispatch.basic import ( ...@@ -9,8 +10,19 @@ from pytensor.link.numba.dispatch.basic import (
register_funcify_default_op_cache_key, register_funcify_default_op_cache_key,
) )
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.link.numba.dispatch.sparse.variable import CSMatrixType from pytensor.link.numba.dispatch.sparse.variable import (
from pytensor.sparse import CSM, Cast, CSMProperties, DenseFromSparse, Transpose CSMatrixType,
csc_matrix_from_components,
csr_matrix_from_components,
)
from pytensor.sparse import (
CSM,
Cast,
CSMProperties,
DenseFromSparse,
SparseFromDense,
Transpose,
)
@overload(numba_deepcopy) @overload(numba_deepcopy)
...@@ -84,3 +96,23 @@ def numba_funcify_DenseFromSparse(op, node, **kwargs): ...@@ -84,3 +96,23 @@ def numba_funcify_DenseFromSparse(op, node, **kwargs):
return x.toarray() return x.toarray()
return to_array return to_array
@register_funcify_default_op_cache_key(SparseFromDense)
def numba_funcify_SparseFromDense(op, node, **kwargs):
sparse_format = op.format
if sparse_format == "csr":
@numba_basic.numba_njit
def dense_to_csr(matrix):
return sp.sparse.csr_matrix(matrix)
return dense_to_csr
else:
@numba_basic.numba_njit
def dense_to_csc(matrix):
return sp.sparse.csc_matrix(matrix)
return dense_to_csc
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论