提交 3a330c5a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add stubs for deprecated aesara.tensor rewriting modules

上级 ced64bfc
import warnings
warnings.warn(
"The module `aesara.tensor.basic_opt` is deprecated; use `aesara.tensor.rewriting.basic` instead.",
DeprecationWarning,
stacklevel=2,
)
from aesara.tensor.rewriting.basic import * # noqa: F401 E402 F403
import warnings
warnings.warn(
"The module `aesara.tensor.math_opt` is deprecated; use `aesara.tensor.rewriting.math` instead.",
DeprecationWarning,
stacklevel=2,
)
from aesara.tensor.rewriting.math import * # noqa: F401 E402 F403
import warnings
warnings.warn(
"The module `aesara.tensor.opt_uncanonicalize` is deprecated; use `aesara.tensor.rewriting.uncanonicalize` instead.",
DeprecationWarning,
stacklevel=2,
)
from aesara.tensor.rewriting.uncanonicalize import * # noqa: F401 E402 F403
import warnings
warnings.warn(
"The module `aesara.tensor.subtensor_opt` is deprecated; use `aesara.tensor.rewriting.subtensor` instead.",
DeprecationWarning,
stacklevel=2,
)
from aesara.tensor.rewriting.subtensor import * # noqa: F401 E402 F403
......@@ -3841,3 +3841,9 @@ class TestLocalElemwiseAlloc:
x_val = np.random.random((1, 5)).astype(self.dtype)
exp_res = np.broadcast_to(x_val, (5, 5))[..., None] + y_val
assert np.array_equal(func(y_val, x_val), exp_res)
def test_deprecations():
"""Make sure we can import from deprecated modules."""
with pytest.deprecated_call():
from aesara.tensor.basic_opt import register_useless # noqa: F401 F811
......@@ -4633,3 +4633,9 @@ def test_local_useless_conj():
mode_with_rewrite = default_mode.including("canonicalization", "local_useless_conj")
f = function([x], s, mode=mode_with_rewrite)
assert any(node.op == _conj for node in f.maker.fgraph.apply_nodes)
def test_deprecations():
"""Make sure we can import from deprecated modules."""
with pytest.deprecated_call():
from aesara.tensor.math_opt import AlgebraicCanonizer # noqa: F401 F811
......@@ -2154,3 +2154,9 @@ def test_local_join_subtensors(axis, slices_fn, expected_nodes):
f_val = np.concatenate([x_val[slice_val] for slice_val in slices_val], axis=axis)
np.testing.assert_array_equal(f(x_val, stop_val), f_val)
def test_deprecations():
"""Make sure we can import from deprecated modules."""
with pytest.deprecated_call():
from aesara.tensor.subtensor_opt import get_advsubtensor_axis # noqa: F401 F811
import numpy as np
import pytest
import aesara
import aesara.tensor as at
......@@ -218,3 +219,11 @@ def test_local_dimshuffle_subtensor():
assert x[:, :, 0:3, ::-1].dimshuffle(0, 2, 3).eval(
{x: np.ones((5, 1, 6, 7))}
).shape == (5, 3, 7)
def test_deprecations():
"""Make sure we can import from deprecated modules."""
with pytest.deprecated_call():
from aesara.tensor.opt_uncanonicalize import ( # noqa: F401 F811
local_reshape_dimshuffle,
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论