提交 8528590d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove SciPy check from aesara.sparse

上级 5043ac8e
from warnings import warn
try:
import scipy
enable_sparse = True
except ImportError:
enable_sparse = False
warn("SciPy can't be imported. Sparse matrix support is disabled.")
from aesara.sparse import rewriting, sharedvar
from aesara.sparse.basic import *
from aesara.sparse.sharedvar import sparse_constructor as shared
from aesara.sparse.type import SparseTensorType, _is_sparse
if enable_sparse:
from aesara.sparse import rewriting, sharedvar
from aesara.sparse.basic import *
from aesara.sparse.sharedvar import sparse_constructor as shared
def sparse_grad(var):
"""This function return a new variable whose gradient will be
stored in a sparse format instead of dense.
def sparse_grad(var):
"""This function return a new variable whose gradient will be
stored in a sparse format instead of dense.
Currently only variable created by AdvancedSubtensor1 is supported.
i.e. a_tensor_var[an_int_vector].
Currently only variable created by AdvancedSubtensor1 is supported.
i.e. a_tensor_var[an_int_vector].
.. versionadded:: 0.6rc4
"""
from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1
.. versionadded:: 0.6rc4
"""
from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1
if var.owner is None or not isinstance(
var.owner.op, (AdvancedSubtensor, AdvancedSubtensor1)
):
raise TypeError(
"Sparse gradient is only implemented for AdvancedSubtensor and AdvancedSubtensor1"
)
if var.owner is None or not isinstance(
var.owner.op, (AdvancedSubtensor, AdvancedSubtensor1)
):
raise TypeError(
"Sparse gradient is only implemented for AdvancedSubtensor and AdvancedSubtensor1"
)
x = var.owner.inputs[0]
indices = var.owner.inputs[1:]
x = var.owner.inputs[0]
indices = var.owner.inputs[1:]
if len(indices) > 1:
raise TypeError(
"Sparse gradient is only implemented for single advanced indexing"
)
if len(indices) > 1:
raise TypeError(
"Sparse gradient is only implemented for single advanced indexing"
)
ret = AdvancedSubtensor1(sparse_grad=True)(x, indices[0])
return ret
ret = AdvancedSubtensor1(sparse_grad=True)(x, indices[0])
return ret
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论