提交 8616d398 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename aesara.sparse.opt to aesara.sparse.rewriting

上级 f5af3cb4
......@@ -13,7 +13,7 @@ from aesara.sparse.type import SparseTensorType, _is_sparse
if enable_sparse:
from aesara.sparse import opt, sharedvar
from aesara.sparse import rewriting, sharedvar
from aesara.sparse.basic import *
from aesara.sparse.sharedvar import sparse_constructor as shared
......
差异被折叠。
差异被折叠。
......@@ -84,7 +84,12 @@ from aesara.sparse.basic import (
_is_sparse_variable,
_mtypes,
)
from aesara.sparse.opt import CSMGradC, StructuredDotCSC, UsmmCscDense
from aesara.sparse.rewriting import (
AddSD_ccode,
CSMGradC,
StructuredDotCSC,
UsmmCscDense,
)
from aesara.tensor.basic import MakeVector
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.math import sum as at_sum
......@@ -491,7 +496,7 @@ class TestSparseInferShape(utt.InferShapeTester):
sp.sparse.csr_matrix(random_lil((10, 40), config.floatX, 3)),
np.random.standard_normal((10, 40)).astype(config.floatX),
],
(AddSD, sparse.opt.AddSD_ccode),
(AddSD, AddSD_ccode),
)
def test_mul_ss(self):
......
import pytest
sp = pytest.importorskip("scipy", minversion="0.7.0")
import numpy as np
import pytest
import scipy as sp
import aesara
from aesara import sparse
from aesara.compile.mode import Mode, get_default_mode
from aesara.configdefaults import config
from aesara.sparse.rewriting import SamplingDotCSR, sd_csc
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.math import sum as at_sum
from aesara.tensor.type import ivector, matrix, vector
......@@ -38,7 +36,7 @@ def test_local_csm_properties_csm():
f(v.data, v.indices, v.indptr, v.shape)
@pytest.mark.skip(reason="Opt disabled as it don't support unsorted indices")
@pytest.mark.skip(reason="Rewrite disabled as it don't support unsorted indices")
@pytest.mark.skipif(
not aesara.config.cxx, reason="G++ not available, so we need to skip this test."
)
......@@ -143,7 +141,7 @@ def test_local_sampling_dot_csr():
# SamplingDotCSR's C implementation needs blas, so it should not
# be inserted
assert not any(
isinstance(node.op, sparse.opt.SamplingDotCSR)
isinstance(node.op, SamplingDotCSR)
for node in f.maker.fgraph.toposort()
)
......@@ -174,6 +172,6 @@ def test_sd_csc():
nrows = as_tensor_variable(np.int32(A.shape[0]))
b = as_tensor_variable(b)
res = aesara.sparse.opt.sd_csc(a_val, a_ind, a_ptr, nrows, b).eval()
res = sd_csc(a_val, a_ind, a_ptr, nrows, b).eval()
utt.assert_allclose(res, target)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论