提交 d3ff35dc authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Allow conversion of dense to sparse types in SparseTensorType.convert_variable

上级 a5f8b693
......@@ -4,7 +4,7 @@ import scipy.sparse
import aesara
from aesara import scalar as aes
from aesara.graph.type import HasDataType
from aesara.tensor.type import TensorType
from aesara.tensor.type import DenseTensorType, TensorType
def _is_sparse(x):
......@@ -154,11 +154,25 @@ class SparseTensorType(TensorType, HasDataType):
def convert_variable(self, var):
res = super().convert_variable(var)
if res and not isinstance(res.type, type(self)):
# TODO: Convert to this sparse format
raise NotImplementedError()
if res is None:
return res
if not isinstance(res.type, type(self)):
if isinstance(res.type, DenseTensorType):
if self.format == "csr":
from aesara.sparse.basic import csr_from_dense
return csr_from_dense(res)
else:
from aesara.sparse.basic import csc_from_dense
return csc_from_dense(res)
return None
if res.format != self.format:
# TODO: Convert sparse `var`s with different formats to this format?
return None
return res
......
import pytest
from aesara.sparse import matrix as sp_matrix
from aesara.sparse.type import SparseTensorType
from aesara.tensor import dmatrix
......@@ -16,13 +14,16 @@ def test_Sparse_convert_variable():
z = sp_matrix("csr", dtype="float64", name="z")
assert y.type.convert_variable(z) is None
assert z.type.convert_variable(y) is None
res = y.type.convert_variable(x)
assert res.type == y.type
res = z.type.convert_variable(x)
assert res.type == z.type
# TODO FIXME: This is a questionable result, because `x.type` is associated
# with a dense `Type`, but, since `TensorType` is a base class of `Sparse`,
# we would need to added sparse/dense logic to `TensorType`, and we don't
# want to do that.
assert x.type.convert_variable(y) is y
# TODO FIXME: We should be able to do this.
with pytest.raises(NotImplementedError):
y.type.convert_variable(x)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论