提交 54db5ad0 authored 作者: Tomas Capretto's avatar Tomas Capretto 提交者: Ricardo Vieira

Implement SparseFromDense in numba backend

上级 bad5f528
import numpy as np
import scipy as sp
from numba.core import types
from numba.extending import overload
from pytensor import config
......@@ -9,8 +10,19 @@ from pytensor.link.numba.dispatch.basic import (
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.link.numba.dispatch.sparse.variable import CSMatrixType
from pytensor.sparse import CSM, Cast, CSMProperties, DenseFromSparse, Transpose
from pytensor.link.numba.dispatch.sparse.variable import (
CSMatrixType,
csc_matrix_from_components,
csr_matrix_from_components,
)
from pytensor.sparse import (
CSM,
Cast,
CSMProperties,
DenseFromSparse,
SparseFromDense,
Transpose,
)
@overload(numba_deepcopy)
......@@ -84,3 +96,23 @@ def numba_funcify_DenseFromSparse(op, node, **kwargs):
return x.toarray()
return to_array
@register_funcify_default_op_cache_key(SparseFromDense)
def numba_funcify_SparseFromDense(op, node, **kwargs):
sparse_format = op.format
if sparse_format == "csr":
@numba_basic.numba_njit
def dense_to_csr(matrix):
return sp.sparse.csr_matrix(matrix)
return dense_to_csr
else:
@numba_basic.numba_njit
def dense_to_csc(matrix):
return sp.sparse.csc_matrix(matrix)
return dense_to_csc
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论