提交 96253d5d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make as_symbolic work with sparse matrices

上级 a1739f6c
......@@ -14,6 +14,7 @@ import scipy.sparse
from numpy.lib.stride_tricks import as_strided
import aesara
from aesara import _as_symbolic, as_symbolic
from aesara import scalar as aes
from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefined
......@@ -128,6 +129,11 @@ def _is_dense(x):
return isinstance(x, np.ndarray)
@_as_symbolic.register(scipy.sparse.base.spmatrix)
def as_symbolic_sparse(x, **kwargs):
return as_sparse_variable(x, **kwargs)
def as_sparse_variable(x, name=None, ndim=None, **kwargs):
"""
Wrapper around SparseVariable constructor to construct
......@@ -174,26 +180,7 @@ def as_sparse_variable(x, name=None, ndim=None, **kwargs):
as_sparse = as_sparse_variable
def as_sparse_or_tensor_variable(x, name=None):
"""
Same as `as_sparse_variable` but if we can't make a
sparse variable, we try to make a tensor variable.
Parameters
----------
x
A sparse matrix.
Returns
-------
SparseVariable or TensorVariable version of `x`
"""
try:
return as_sparse_variable(x, name)
except (ValueError, TypeError):
return at.as_tensor_variable(x, name)
as_sparse_or_tensor_variable = as_symbolic
def constant(x, name=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论