提交 8528590d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove SciPy check from aesara.sparse

上级 5043ac8e
from warnings import warn from aesara.sparse import rewriting, sharedvar
from aesara.sparse.basic import *
from aesara.sparse.sharedvar import sparse_constructor as shared
try:
import scipy
enable_sparse = True
except ImportError:
enable_sparse = False
warn("SciPy can't be imported. Sparse matrix support is disabled.")
from aesara.sparse.type import SparseTensorType, _is_sparse from aesara.sparse.type import SparseTensorType, _is_sparse
if enable_sparse: def sparse_grad(var):
from aesara.sparse import rewriting, sharedvar """This function return a new variable whose gradient will be
from aesara.sparse.basic import * stored in a sparse format instead of dense.
from aesara.sparse.sharedvar import sparse_constructor as shared
def sparse_grad(var):
"""This function return a new variable whose gradient will be
stored in a sparse format instead of dense.
Currently only variable created by AdvancedSubtensor1 is supported. Currently only variable created by AdvancedSubtensor1 is supported.
i.e. a_tensor_var[an_int_vector]. i.e. a_tensor_var[an_int_vector].
.. versionadded:: 0.6rc4 .. versionadded:: 0.6rc4
""" """
from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1 from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1
if var.owner is None or not isinstance( if var.owner is None or not isinstance(
var.owner.op, (AdvancedSubtensor, AdvancedSubtensor1) var.owner.op, (AdvancedSubtensor, AdvancedSubtensor1)
): ):
raise TypeError( raise TypeError(
"Sparse gradient is only implemented for AdvancedSubtensor and AdvancedSubtensor1" "Sparse gradient is only implemented for AdvancedSubtensor and AdvancedSubtensor1"
) )
x = var.owner.inputs[0] x = var.owner.inputs[0]
indices = var.owner.inputs[1:] indices = var.owner.inputs[1:]
if len(indices) > 1: if len(indices) > 1:
raise TypeError( raise TypeError(
"Sparse gradient is only implemented for single advanced indexing" "Sparse gradient is only implemented for single advanced indexing"
) )
ret = AdvancedSubtensor1(sparse_grad=True)(x, indices[0]) ret = AdvancedSubtensor1(sparse_grad=True)(x, indices[0])
return ret return ret
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论