提交 c736927b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up deprecations

This commit introduces the use of module-level `__getattr__` overrides to emit deprecation warnings for renamed objects. It also adds some missing `pytest.deprecated_call` checks.
上级 e0d91807
......@@ -62,12 +62,6 @@ for p in sys.path:
raise RuntimeError("You have the aesara directory in your Python path.")
from aesara.configdefaults import config
from aesara.utils import deprecated
change_flags = deprecated("Use aesara.config.change_flags instead!")(
config.change_flags
)
# This is the api version for ops that generate C code. External ops
......@@ -178,3 +172,27 @@ from aesara.scan.views import foldl, foldr, map, reduce
# imports were executed, we can warn about remaining flags provided by the user
# through AESARA_FLAGS.
config.warn_unused_flags()
DEPRECATED_NAMES = [
(
"change_flags",
"`aesara.change_flags` is deprecated: use `aesara.config.change_flags` instead.",
config.change_flags,
),
]
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
......@@ -14,7 +14,7 @@ from functools import wraps
from io import StringIO
from typing import Callable, Dict, Optional, Sequence, Union
from aesara.utils import deprecated, hash_from_code
from aesara.utils import hash_from_code
_logger = logging.getLogger("aesara.configparser")
......@@ -582,8 +582,7 @@ class _ConfigProxy:
if attr == "_actual":
return _ConfigProxy._actual
warnings.warn(
"Accessing config through `aesara.configparser.config` is deprecated. "
"Use `aesara.config` instead.",
"`aesara.configparser.config` is deprecated; use `aesara.config` instead.",
DeprecationWarning,
stacklevel=2,
)
......@@ -593,8 +592,7 @@ class _ConfigProxy:
if attr == "_actual":
return setattr(_ConfigProxy._actual, attr, value)
warnings.warn(
"Accessing config through `aesara.configparser.config` is deprecated. "
"Use `aesara.config` instead.",
"`aesara.configparser.config` is deprecated; use `aesara.config` instead.",
DeprecationWarning,
stacklevel=2,
)
......@@ -609,12 +607,37 @@ _config = _create_default_config()
# These imports/accesses should be replaced with `aesara.config`, so this wraps
# it with warnings:
config = _ConfigProxy(_config)
# We can't alias the methods of the `config` variable above without already
# triggering the warning. Instead, we wrap the methods of the actual instance
# with warnings:
change_flags = deprecated("Use aesara.config.change_flags instead!")(
_config.change_flags
)
_config_print = deprecated("Use aesara.config.config_print instead!")(
_config.config_print
)
DEPRECATED_NAMES = [
(
"change_flags",
"`change_flags` is deprecated; use `aesara.config.change_flags` instead.",
_config.change_flags,
),
(
"_change_flags",
"`_change_flags` is deprecated; use `aesara.config.change_flags` instead.",
_config.change_flags,
),
(
"_config_print",
"`_config_print` is deprecated; use `aesara.config.config_print` instead.",
_config.config_print,
),
]
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
......@@ -2129,10 +2129,9 @@ consider_constant_ = ConsiderConstant()
def consider_constant(x):
"""
DEPRECATED: use zero_grad() or disconnected_grad() instead.
"""Consider an expression constant when computing gradients.
Consider an expression constant when computing gradients.
DEPRECATED: use `zero_grad` or `disconnected_grad` instead.
The expression itself is unaffected, but when its gradient is
computed, or the gradient of another expression that this
......@@ -2149,14 +2148,14 @@ def consider_constant(x):
"""
warnings.warn(
(
"consider_constant() is deprecated, use zero_grad() or "
"disconnected_grad() instead."
"`ConsiderConstant` is deprecated; use `zero_grad` or "
"`disconnected_grad` instead."
),
category=DeprecationWarning,
stacklevel=3,
)
return consider_constant_(x)
return ConsiderConstant()(x)
class ZeroGrad(ViewOp):
......@@ -2365,3 +2364,28 @@ def grad_scale(x, multiplier):
0.416...
"""
return GradScale(multiplier)(x)
DEPRECATED_NAMES = [
(
"consider_constant_",
"`consider_constant_` is deprecated; use `zero_grad` or `disconnected_grad` instead.",
ConsiderConstant(),
),
]
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
......@@ -177,10 +177,6 @@ class OptimizationDatabase:
print(" db", self.__db__, file=stream)
# This is deprecated and will be removed.
DB = OptimizationDatabase
class OptimizationQuery:
"""An object that specifies a set of optimizations by tag/name."""
......@@ -293,10 +289,6 @@ class OptimizationQuery:
)
# This is deprecated and will be removed.
Query = OptimizationQuery
class EquilibriumDB(OptimizationDatabase):
"""
A set of potential optimizations which should be applied in an arbitrary
......@@ -550,3 +542,33 @@ class ProxyDB(OptimizationDatabase):
def query(self, *tags, **kwtags):
return self.db.query(*tags, **kwtags)
DEPRECATED_NAMES = [
(
"DB",
"`DB` is deprecated; use `OptimizationDatabase` instead.",
OptimizationDatabase,
),
(
"Query",
"`Query` is deprecated; use `OptimizationQuery` instead.",
OptimizationQuery,
),
]
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
......@@ -22,8 +22,8 @@ from aesara.scalar.basic import (
Clip,
Composite,
Identity,
Inv,
Mul,
Reciprocal,
ScalarOp,
Second,
Switch,
......@@ -236,13 +236,15 @@ def numba_funcify_Second(op, node, **kwargs):
return second
@numba_funcify.register(Inv)
def numba_funcify_Inv(op, node, **kwargs):
@numba_funcify.register(Reciprocal)
def numba_funcify_Reciprocal(op, node, **kwargs):
@numba_basic.numba_njit(inline="always")
def inv(x):
def reciprocal(x):
# TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
# `x` is an `int`
return 1 / x
return inv
return reciprocal
@numba_funcify.register(Sigmoid)
......
import copy
import warnings
from typing import Tuple, Union
import numpy as np
......@@ -435,14 +434,3 @@ class ChoiceFromUniform(MultinomialFromUniform):
pvals[n, m] = 0.0
pvals[n] /= pvals[n].sum()
break
class MultinomialWOReplacementFromUniform(ChoiceFromUniform):
def __init__(self, *args, **kwargs):
warnings.warn(
"MultinomialWOReplacementFromUniform is deprecated, "
"use ChoiceFromUniform instead.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
......@@ -1107,10 +1107,10 @@ class MRG_RandomStream:
**kwargs,
):
warnings.warn(
"MRG_RandomStream.multinomial_wo_replacement is "
"deprecated and will be removed in the next release of "
"Aesara. Please use MRG_RandomStream.choice instead.",
"`MRG_RandomStream.multinomial_wo_replacement` is "
"deprecated; use `MRG_RandomStream.choice` instead.",
DeprecationWarning,
stacklevel=2,
)
assert size is None
return self.choice(
......
......@@ -670,10 +670,6 @@ class ScalarType(CType, HasDataType, HasShape):
return shape_info
# Deprecated alias for backward compatibility
Scalar = ScalarType
def get_scalar_type(dtype, cache: Dict[str, ScalarType] = {}) -> ScalarType:
"""
Return a ScalarType(dtype) object.
......@@ -2903,10 +2899,6 @@ class Reciprocal(UnaryScalarOp):
reciprocal = Reciprocal(upgrade_to_float, name="reciprocal")
# These are deprecated and will be removed
Inv = Reciprocal
inv = reciprocal
class Log(UnaryScalarOp):
"""
......@@ -4455,3 +4447,26 @@ def handle_composite(node, mapping):
Compositef32.special[Composite] = handle_composite
DEPRECATED_NAMES = [
("Inv", "`Inv` is deprecated; use `Reciprocal` instead.", Reciprocal),
("inv", "`inv` is deprecated; use `reciprocal` instead.", reciprocal),
("Scalar", "`Scalar` is deprecated; use `ScalarType` instead.", ScalarType),
]
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
......@@ -2667,7 +2667,7 @@ def is_flat(var, ndim=None, outdim=None):
elif outdim is not None and ndim is not None:
raise ValueError("You should only specify ndim")
elif outdim is not None:
warnings.warn("flatten outdim parameter is deprecated, use ndim instead.")
warnings.warn("outdim` is deprecated; use `ndim` instead.")
ndim = outdim
return var.ndim == ndim
......
......@@ -1048,10 +1048,6 @@ def abs(a):
"""|`a`|"""
# These are deprecated and will be removed
abs_ = abs
pprint.assign(abs, printing.PatternPrinter(("|%(0)s|", -1000)))
......@@ -1080,10 +1076,6 @@ def reciprocal(a):
"""1.0/a"""
# This is deprecated and will be removed
inv = reciprocal
@scalar_elemwise
def log(a):
"""base e logarithm of a"""
......@@ -3024,13 +3016,11 @@ __all__ = [
"invert",
"bitwise_not",
"abs",
"abs_",
"exp",
"exp2",
"expm1",
"neg",
"reciprocal",
"inv",
"log",
"log2",
"log10",
......@@ -3127,3 +3117,28 @@ __all__ = [
"logaddexp",
"logsumexp",
]
DEPRECATED_NAMES = [
("abs_", "`abs_` is deprecated; use `abs` instead.", abs),
("inv", "`inv` is deprecated; use `reciprocal` instead.", reciprocal),
]
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
def __dir__():
return sorted(__all__ + [names[0] for names in DEPRECATED_NAMES])
......@@ -46,12 +46,13 @@ def conv2d(
subsample=(1, 1),
**kargs,
):
"""
Deprecated, old conv2d interface.
This function will build the symbolic graph for convolving a stack of
input images with a set of filters. The implementation is modelled after
Convolutional Neural Networks (CNN). It is simply a wrapper to the ConvOp
but provides a much cleaner interface.
"""Build the symbolic graph for convolving a stack of input images with a set of filters.
The implementation is modelled after Convolutional Neural Networks
(CNN). It is simply a wrapper to the `ConvOp` but provides a much cleaner
interface.
This is deprecated.
Parameters
----------
......@@ -402,8 +403,7 @@ class ConvOp(OpenMPOp):
# with s=1 for mode=='full' and s=-1 for mode=='valid'.
# To support symbolic shapes, we express this with integer arithmetic.
warnings.warn(
"The method `getOutputShape` is deprecated use"
"`get_conv_output_shape` instead.",
"`getOutputShape` is deprecated; use `get_conv_output_shape` instead.",
DeprecationWarning,
stacklevel=2,
)
......
......@@ -101,9 +101,8 @@ class Cholesky(Op):
def conjugate_solve_triangular(outer, inner):
"""Computes L^{-T} P L^{-1} for lower-triangular L."""
return solve_upper_triangular(
outer.T, solve_upper_triangular(outer.T, inner.T).T
)
solve_upper = SolveTriangular(lower=False)
return solve_upper(outer.T, solve_upper(outer.T, inner.T).T)
s = conjugate_solve_triangular(
chol_x, tril_and_halve_diagonal(chol_x.T.dot(dz))
......@@ -507,15 +506,6 @@ def solve(a, b, assume_a="gen", lower=False, check_finite=True):
)(a, b)
# TODO: These are deprecated; emit a warning
solve_lower_triangular = SolveTriangular(lower=True)
solve_upper_triangular = SolveTriangular(lower=False)
solve_symmetric = Solve(assume_a="sym")
# TODO: Optimizations to replace multiplication by matrix inverse
# with solve() Op (still unwritten)
class Eigvalsh(Op):
"""
Generalized eigenvalues of a Hermitian positive definite eigensystem.
......@@ -748,10 +738,45 @@ expm = Expm()
__all__ = [
"cholesky",
"solve",
"solve_lower_triangular",
"solve_upper_triangular",
"solve_symmetric",
"eigvalsh",
"kron",
"expm",
]
DEPRECATED_NAMES = [
(
"solve_lower_triangular",
"`solve_lower_triangular` is deprecated; use `solve` instead.",
SolveTriangular(lower=True),
),
(
"solve_upper_triangular",
"`solve_upper_triangular` is deprecated; use `solve` instead.",
SolveTriangular(lower=False),
),
(
"solve_symmetric",
"`solve_symmetric` is deprecated; use `solve` instead.",
Solve(assume_a="sym"),
),
]
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
def __dir__():
return sorted(__all__ + [names[0] for names in DEPRECATED_NAMES])
......@@ -158,12 +158,18 @@ def deprecated(message: str = ""):
def decorator_wrapper(func):
@wraps(func)
def function_wrapper(*args, **kwargs):
nonlocal message
current_call_source = "|".join(
traceback.format_stack(inspect.currentframe())
)
if current_call_source not in function_wrapper.last_call_source:
if not message:
message = f"Function {func.__name__} is deprecated."
warnings.warn(
"Function {} is now deprecated! {}".format(func.__name__, message),
message,
category=DeprecationWarning,
stacklevel=2,
)
......
......@@ -827,8 +827,8 @@ def test_Cast(v, dtype):
(set_test_value(at.iscalar(), np.array(10, dtype="int32")), aesb.float64),
],
)
def test_Inv(v, dtype):
g = aesb.inv(v)
def test_reciprocal(v, dtype):
g = aesb.reciprocal(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
......
......@@ -157,7 +157,7 @@ class TestFunction:
p = fmatrix()
n = iscalar()
with pytest.warns(DeprecationWarning):
with pytest.deprecated_call():
m = th_rng.multinomial_wo_replacement(pvals=p, n=n)
f = function([p, n], m, allow_input_downcast=True)
......@@ -181,7 +181,7 @@ class TestFunction:
p = fmatrix()
n = iscalar()
with pytest.warns(DeprecationWarning):
with pytest.deprecated_call():
m = th_rng.multinomial_wo_replacement(pvals=p, n=n)
f = function([p, n], m, allow_input_downcast=True)
......
import contextlib
import os
import sys
import time
......@@ -332,12 +333,20 @@ def test_broadcastable():
# the sizes of them are implicitly defined with "pvals" argument.
if distribution in [R.multinomial, R.multinomial_wo_replacement]:
# check when all dimensions are constant
uu = distribution(pvals=pvals_1)
assert uu.broadcastable == (False, True)
context_mgr = (
pytest.deprecated_call()
if distribution == R.multinomial_wo_replacement
else contextlib.suppress()
)
with context_mgr:
uu = distribution(pvals=pvals_1)
assert uu.broadcastable == (False, True)
# check when some dimensions are aesara variables
uu = distribution(pvals=pvals_2)
assert uu.broadcastable == (False, True)
with context_mgr:
uu = distribution(pvals=pvals_2)
assert uu.broadcastable == (False, True)
else:
# check when all dimensions are constant
uu = distribution(size=size1)
......@@ -1109,9 +1118,10 @@ def test_target_parameter():
basic_target_parameter_test(
srng.choice(p=pvals.astype("float32"), replace=False, target="cpu")
)
basic_target_parameter_test(
srng.multinomial_wo_replacement(pvals=pvals.astype("float32"), target="cpu")
)
with pytest.deprecated_call():
basic_target_parameter_test(
srng.multinomial_wo_replacement(pvals=pvals.astype("float32"), target="cpu")
)
@config.change_flags(compute_test_value="off")
......
......@@ -1321,16 +1321,9 @@ class TestJoinAndSplit:
def test_stack_new_interface(self):
# Test the new numpy-like interface: stack(tensors, axis=0).
# Testing against old interface
warnings.simplefilter("always", DeprecationWarning)
a = imatrix("a")
b = imatrix("b")
s1 = stack(a, b)
s2 = stack([a, b])
f = function([a, b], [s1, s2], mode=self.mode)
v1, v2 = f([[1, 2]], [[3, 4]])
assert v1.shape == v2.shape
assert np.all(v1 == v2)
# Testing axis parameter
s3 = stack([a, b], 1)
f = function([a, b], s3, mode=self.mode)
......
......@@ -14,8 +14,6 @@ from aesara.gradient import (
NullTypeGradError,
Rop,
UndefinedGrad,
consider_constant,
consider_constant_,
disconnected_grad,
disconnected_grad_,
grad,
......@@ -769,37 +767,45 @@ def test_subgraph_grad():
class TestConsiderConstant:
def setup_method(self):
self.rng = np.random.default_rng(seed=utt.fetch_seed())
def test_op_removed(self):
from aesara.gradient import ConsiderConstant, consider_constant
x = matrix("x")
y = x * consider_constant(x)
with pytest.deprecated_call():
y = x * consider_constant(x)
f = aesara.function([x], y)
# need to refer to aesara.consider_constant_ here,
# aesara.consider_constant is a wrapper function!
assert consider_constant_ not in [node.op for node in f.maker.fgraph.toposort()]
assert ConsiderConstant not in [
type(node.op) for node in f.maker.fgraph.toposort()
]
def test_grad(self):
a = np.asarray(self.rng.standard_normal((5, 5)), dtype=config.floatX)
from aesara.gradient import consider_constant
x = matrix("x")
rng = np.random.default_rng(seed=utt.fetch_seed())
expressions_gradients = [
(x * consider_constant(x), x),
(x * consider_constant(exp(x)), exp(x)),
(consider_constant(x), at.constant(0.0)),
(x**2 * consider_constant(x), 2 * x**2),
]
a = np.asarray(rng.standard_normal((5, 5)), dtype=config.floatX)
for expr, expr_grad in expressions_gradients:
g = grad(expr.sum(), x)
# gradient according to aesara
f = aesara.function([x], g, on_unused_input="ignore")
# desired gradient
f2 = aesara.function([x], expr_grad, on_unused_input="ignore")
x = matrix("x")
assert np.allclose(f(a), f2(a))
with pytest.deprecated_call():
expressions_gradients = [
(x * consider_constant(x), x),
(x * consider_constant(exp(x)), exp(x)),
(consider_constant(x), at.constant(0.0)),
(x**2 * consider_constant(x), 2 * x**2),
]
for expr, expr_grad in expressions_gradients:
g = grad(expr.sum(), x)
# gradient according to aesara
f = aesara.function([x], g, on_unused_input="ignore")
# desired gradient
f2 = aesara.function([x], expr_grad, on_unused_input="ignore")
assert np.allclose(f(a), f2(a))
class TestZeroGrad:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论