提交 57acc845 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add stubs for deprecated aesara.graph rewriting modules

上级 75506689
import warnings
warnings.warn(
"The module `aesara.graph.kanren` is deprecated; use `aesara.graph.rewriting.kanren` instead.",
DeprecationWarning,
stacklevel=2,
)
from aesara.graph.rewriting.kanren import * # noqa: F401 E402 F403
import warnings
warnings.warn(
"The module `aesara.graph.opt` is deprecated; use `aesara.graph.rewriting.basic` instead.",
DeprecationWarning,
stacklevel=2,
)
from aesara.graph.rewriting.basic import * # noqa: F401 E402 F403
from aesara.graph.rewriting.basic import DEPRECATED_NAMES # noqa: F401 E402 F403
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
global DEPRECATED_NAMES
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
import warnings
warnings.warn(
"The module `aesara.graph.opt_utils` is deprecated; use `aesara.graph.rewriting.utils` instead.",
DeprecationWarning,
stacklevel=2,
)
from aesara.graph.rewriting.utils import * # noqa: F401 E402 F403
from aesara.graph.rewriting.utils import DEPRECATED_NAMES # noqa: F401 E402 F403
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
global DEPRECATED_NAMES
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
import warnings
warnings.warn(
"The module `aesara.graph.optdb` is deprecated; use `aesara.graph.rewriting.db` instead.",
DeprecationWarning,
stacklevel=2,
)
from aesara.graph.rewriting.db import * # noqa: F401 E402 F403
from aesara.graph.rewriting.db import DEPRECATED_NAMES # noqa: F401 E402 F403
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
global DEPRECATED_NAMES
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
import warnings
warnings.warn(
"The module `aesara.graph.unify` is deprecated; use `aesara.graph.rewriting.unify` instead.",
DeprecationWarning,
stacklevel=2,
)
from aesara.graph.rewriting.unify import * # noqa: F401 E402 F403
import sys
import pytest import pytest
from aesara.configdefaults import config from aesara.configdefaults import config
...@@ -830,3 +832,17 @@ def test_OpToRewriterTracker(): ...@@ -830,3 +832,17 @@ def test_OpToRewriterTracker():
local_rewriter_2, local_rewriter_2,
local_rewriter_1, local_rewriter_1,
] ]
def test_deprecations():
"""Make sure we can import deprecated classes from current and deprecated modules."""
with pytest.deprecated_call():
from aesara.graph.rewriting.basic import GlobalOptimizer
with pytest.deprecated_call():
from aesara.graph.opt import GlobalOptimizer, LocalOptimizer # noqa: F401 F811
del sys.modules["aesara.graph.opt"]
with pytest.deprecated_call():
from aesara.graph.opt import GraphRewriter # noqa: F401
import sys
import pytest import pytest
from aesara.graph.rewriting.basic import GraphRewriter, SequentialGraphRewriter from aesara.graph.rewriting.basic import GraphRewriter, SequentialGraphRewriter
...@@ -84,3 +86,17 @@ class TestDB: ...@@ -84,3 +86,17 @@ class TestDB:
def test_ProxyDB(self): def test_ProxyDB(self):
with pytest.raises(TypeError, match=r"`db` must be.*"): with pytest.raises(TypeError, match=r"`db` must be.*"):
ProxyDB(object()) ProxyDB(object())
def test_deprecations():
"""Make sure we can import deprecated classes from current and deprecated modules."""
with pytest.deprecated_call():
from aesara.graph.rewriting.db import OptimizationDatabase # noqa: F401 F811
with pytest.deprecated_call():
from aesara.graph.optdb import OptimizationDatabase # noqa: F401 F811
del sys.modules["aesara.graph.optdb"]
with pytest.deprecated_call():
from aesara.graph.optdb import RewriteDatabase # noqa: F401
...@@ -165,3 +165,9 @@ def test_KanrenRelationSub_dot(): ...@@ -165,3 +165,9 @@ def test_KanrenRelationSub_dot():
assert expr_opt.owner.inputs[1].owner.op == at.add assert expr_opt.owner.inputs[1].owner.op == at.add
assert isinstance(expr_opt.owner.inputs[1].owner.inputs[0].owner.op, Dot) assert isinstance(expr_opt.owner.inputs[1].owner.inputs[0].owner.op, Dot)
assert isinstance(expr_opt.owner.inputs[1].owner.inputs[1].owner.op, Dot) assert isinstance(expr_opt.owner.inputs[1].owner.inputs[1].owner.op, Dot)
def test_deprecations():
"""Make sure we can import deprecated classes from current and deprecated modules."""
with pytest.deprecated_call():
from aesara.graph.kanren import KanrenRelationSub # noqa: F401 F811
...@@ -350,3 +350,9 @@ def test_convert_strs_to_vars(): ...@@ -350,3 +350,9 @@ def test_convert_strs_to_vars():
res = convert_strs_to_vars((val,)) res = convert_strs_to_vars((val,))
assert isinstance(res[0], Constant) assert isinstance(res[0], Constant)
assert np.array_equal(res[0].data, val) assert np.array_equal(res[0].data, val)
def test_deprecations():
"""Make sure we can import deprecated classes from current and deprecated modules."""
with pytest.deprecated_call():
from aesara.graph.unify import eval_if_etuple # noqa: F401 F811
import sys
import pytest
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.basic import graph_rewriter from aesara.graph.rewriting.basic import graph_rewriter
from aesara.graph.rewriting.utils import is_same_graph, rewrite_graph from aesara.graph.rewriting.utils import is_same_graph, rewrite_graph
...@@ -156,3 +160,17 @@ def test_rewrite_graph(): ...@@ -156,3 +160,17 @@ def test_rewrite_graph():
) )
assert x_rewritten.outputs[0] is y assert x_rewritten.outputs[0] is y
def test_deprecations():
"""Make sure we can import deprecated classes from current and deprecated modules."""
with pytest.deprecated_call():
from aesara.graph.rewriting.utils import optimize_graph # noqa: F401 F811
with pytest.deprecated_call():
from aesara.graph.opt_utils import optimize_graph # noqa: F401 F811
del sys.modules["aesara.graph.opt_utils"]
with pytest.deprecated_call():
from aesara.graph.opt_utils import rewrite_graph # noqa: F401
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论