提交 d3ff35dc authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Allow conversion of dense to sparse types in SparseTensorType.convert_variable

上级 a5f8b693
...@@ -4,7 +4,7 @@ import scipy.sparse ...@@ -4,7 +4,7 @@ import scipy.sparse
import aesara import aesara
from aesara import scalar as aes from aesara import scalar as aes
from aesara.graph.type import HasDataType from aesara.graph.type import HasDataType
from aesara.tensor.type import TensorType from aesara.tensor.type import DenseTensorType, TensorType
def _is_sparse(x): def _is_sparse(x):
...@@ -154,11 +154,25 @@ class SparseTensorType(TensorType, HasDataType): ...@@ -154,11 +154,25 @@ class SparseTensorType(TensorType, HasDataType):
def convert_variable(self, var): def convert_variable(self, var):
res = super().convert_variable(var) res = super().convert_variable(var)
if res and not isinstance(res.type, type(self)): if res is None:
# TODO: Convert to this sparse format return res
raise NotImplementedError()
if not isinstance(res.type, type(self)):
if isinstance(res.type, DenseTensorType):
if self.format == "csr":
from aesara.sparse.basic import csr_from_dense
return csr_from_dense(res)
else:
from aesara.sparse.basic import csc_from_dense
return csc_from_dense(res)
return None
if res.format != self.format:
# TODO: Convert sparse `var`s with different formats to this format? # TODO: Convert sparse `var`s with different formats to this format?
return None
return res return res
......
import pytest
from aesara.sparse import matrix as sp_matrix from aesara.sparse import matrix as sp_matrix
from aesara.sparse.type import SparseTensorType from aesara.sparse.type import SparseTensorType
from aesara.tensor import dmatrix from aesara.tensor import dmatrix
...@@ -16,13 +14,16 @@ def test_Sparse_convert_variable(): ...@@ -16,13 +14,16 @@ def test_Sparse_convert_variable():
z = sp_matrix("csr", dtype="float64", name="z") z = sp_matrix("csr", dtype="float64", name="z")
assert y.type.convert_variable(z) is None assert y.type.convert_variable(z) is None
assert z.type.convert_variable(y) is None
res = y.type.convert_variable(x)
assert res.type == y.type
res = z.type.convert_variable(x)
assert res.type == z.type
# TODO FIXME: This is a questionable result, because `x.type` is associated # TODO FIXME: This is a questionable result, because `x.type` is associated
# with a dense `Type`, but, since `TensorType` is a base class of `Sparse`, # with a dense `Type`, but, since `TensorType` is a base class of `Sparse`,
# we would need to added sparse/dense logic to `TensorType`, and we don't # we would need to added sparse/dense logic to `TensorType`, and we don't
# want to do that. # want to do that.
assert x.type.convert_variable(y) is y assert x.type.convert_variable(y) is y
# TODO FIXME: We should be able to do this.
with pytest.raises(NotImplementedError):
y.type.convert_variable(x)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论