提交 c736927b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up deprecations

This commit introduces the use of module-level `__getattr__` overrides to emit deprecation warnings for renamed objects. It also adds some missing `pytest.deprecated_call` checks.
上级 e0d91807
...@@ -62,12 +62,6 @@ for p in sys.path: ...@@ -62,12 +62,6 @@ for p in sys.path:
raise RuntimeError("You have the aesara directory in your Python path.") raise RuntimeError("You have the aesara directory in your Python path.")
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.utils import deprecated
change_flags = deprecated("Use aesara.config.change_flags instead!")(
config.change_flags
)
# This is the api version for ops that generate C code. External ops # This is the api version for ops that generate C code. External ops
...@@ -178,3 +172,27 @@ from aesara.scan.views import foldl, foldr, map, reduce ...@@ -178,3 +172,27 @@ from aesara.scan.views import foldl, foldr, map, reduce
# imports were executed, we can warn about remaining flags provided by the user # imports were executed, we can warn about remaining flags provided by the user
# through AESARA_FLAGS. # through AESARA_FLAGS.
config.warn_unused_flags() config.warn_unused_flags()
DEPRECATED_NAMES = [
(
"change_flags",
"`aesara.change_flags` is deprecated: use `aesara.config.change_flags` instead.",
config.change_flags,
),
]
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
...@@ -14,7 +14,7 @@ from functools import wraps ...@@ -14,7 +14,7 @@ from functools import wraps
from io import StringIO from io import StringIO
from typing import Callable, Dict, Optional, Sequence, Union from typing import Callable, Dict, Optional, Sequence, Union
from aesara.utils import deprecated, hash_from_code from aesara.utils import hash_from_code
_logger = logging.getLogger("aesara.configparser") _logger = logging.getLogger("aesara.configparser")
...@@ -582,8 +582,7 @@ class _ConfigProxy: ...@@ -582,8 +582,7 @@ class _ConfigProxy:
if attr == "_actual": if attr == "_actual":
return _ConfigProxy._actual return _ConfigProxy._actual
warnings.warn( warnings.warn(
"Accessing config through `aesara.configparser.config` is deprecated. " "`aesara.configparser.config` is deprecated; use `aesara.config` instead.",
"Use `aesara.config` instead.",
DeprecationWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
...@@ -593,8 +592,7 @@ class _ConfigProxy: ...@@ -593,8 +592,7 @@ class _ConfigProxy:
if attr == "_actual": if attr == "_actual":
return setattr(_ConfigProxy._actual, attr, value) return setattr(_ConfigProxy._actual, attr, value)
warnings.warn( warnings.warn(
"Accessing config through `aesara.configparser.config` is deprecated. " "`aesara.configparser.config` is deprecated; use `aesara.config` instead.",
"Use `aesara.config` instead.",
DeprecationWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
...@@ -609,12 +607,37 @@ _config = _create_default_config() ...@@ -609,12 +607,37 @@ _config = _create_default_config()
# These imports/accesses should be replaced with `aesara.config`, so this wraps # These imports/accesses should be replaced with `aesara.config`, so this wraps
# it with warnings: # it with warnings:
config = _ConfigProxy(_config) config = _ConfigProxy(_config)
# We can't alias the methods of the `config` variable above without already
# triggering the warning. Instead, we wrap the methods of the actual instance DEPRECATED_NAMES = [
# with warnings: (
change_flags = deprecated("Use aesara.config.change_flags instead!")( "change_flags",
_config.change_flags "`change_flags` is deprecated; use `aesara.config.change_flags` instead.",
) _config.change_flags,
_config_print = deprecated("Use aesara.config.config_print instead!")( ),
_config.config_print (
) "_change_flags",
"`_change_flags` is deprecated; use `aesara.config.change_flags` instead.",
_config.change_flags,
),
(
"_config_print",
"`_config_print` is deprecated; use `aesara.config.config_print` instead.",
_config.config_print,
),
]
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
...@@ -2129,10 +2129,9 @@ consider_constant_ = ConsiderConstant() ...@@ -2129,10 +2129,9 @@ consider_constant_ = ConsiderConstant()
def consider_constant(x): def consider_constant(x):
""" """Consider an expression constant when computing gradients.
DEPRECATED: use zero_grad() or disconnected_grad() instead.
Consider an expression constant when computing gradients. DEPRECATED: use `zero_grad` or `disconnected_grad` instead.
The expression itself is unaffected, but when its gradient is The expression itself is unaffected, but when its gradient is
computed, or the gradient of another expression that this computed, or the gradient of another expression that this
...@@ -2149,14 +2148,14 @@ def consider_constant(x): ...@@ -2149,14 +2148,14 @@ def consider_constant(x):
""" """
warnings.warn( warnings.warn(
( (
"consider_constant() is deprecated, use zero_grad() or " "`ConsiderConstant` is deprecated; use `zero_grad` or "
"disconnected_grad() instead." "`disconnected_grad` instead."
), ),
category=DeprecationWarning, category=DeprecationWarning,
stacklevel=3, stacklevel=3,
) )
return consider_constant_(x) return ConsiderConstant()(x)
class ZeroGrad(ViewOp): class ZeroGrad(ViewOp):
...@@ -2365,3 +2364,28 @@ def grad_scale(x, multiplier): ...@@ -2365,3 +2364,28 @@ def grad_scale(x, multiplier):
0.416... 0.416...
""" """
return GradScale(multiplier)(x) return GradScale(multiplier)(x)
DEPRECATED_NAMES = [
(
"consider_constant_",
"`consider_constant_` is deprecated; use `zero_grad` or `disconnected_grad` instead.",
ConsiderConstant(),
),
]
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
...@@ -177,10 +177,6 @@ class OptimizationDatabase: ...@@ -177,10 +177,6 @@ class OptimizationDatabase:
print(" db", self.__db__, file=stream) print(" db", self.__db__, file=stream)
# This is deprecated and will be removed.
DB = OptimizationDatabase
class OptimizationQuery: class OptimizationQuery:
"""An object that specifies a set of optimizations by tag/name.""" """An object that specifies a set of optimizations by tag/name."""
...@@ -293,10 +289,6 @@ class OptimizationQuery: ...@@ -293,10 +289,6 @@ class OptimizationQuery:
) )
# This is deprecated and will be removed.
Query = OptimizationQuery
class EquilibriumDB(OptimizationDatabase): class EquilibriumDB(OptimizationDatabase):
""" """
A set of potential optimizations which should be applied in an arbitrary A set of potential optimizations which should be applied in an arbitrary
...@@ -550,3 +542,33 @@ class ProxyDB(OptimizationDatabase): ...@@ -550,3 +542,33 @@ class ProxyDB(OptimizationDatabase):
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
return self.db.query(*tags, **kwtags) return self.db.query(*tags, **kwtags)
DEPRECATED_NAMES = [
(
"DB",
"`DB` is deprecated; use `OptimizationDatabase` instead.",
OptimizationDatabase,
),
(
"Query",
"`Query` is deprecated; use `OptimizationQuery` instead.",
OptimizationQuery,
),
]
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
...@@ -22,8 +22,8 @@ from aesara.scalar.basic import ( ...@@ -22,8 +22,8 @@ from aesara.scalar.basic import (
Clip, Clip,
Composite, Composite,
Identity, Identity,
Inv,
Mul, Mul,
Reciprocal,
ScalarOp, ScalarOp,
Second, Second,
Switch, Switch,
...@@ -236,13 +236,15 @@ def numba_funcify_Second(op, node, **kwargs): ...@@ -236,13 +236,15 @@ def numba_funcify_Second(op, node, **kwargs):
return second return second
@numba_funcify.register(Inv) @numba_funcify.register(Reciprocal)
def numba_funcify_Inv(op, node, **kwargs): def numba_funcify_Reciprocal(op, node, **kwargs):
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
def inv(x): def reciprocal(x):
# TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
# `x` is an `int`
return 1 / x return 1 / x
return inv return reciprocal
@numba_funcify.register(Sigmoid) @numba_funcify.register(Sigmoid)
......
import copy import copy
import warnings
from typing import Tuple, Union from typing import Tuple, Union
import numpy as np import numpy as np
...@@ -435,14 +434,3 @@ class ChoiceFromUniform(MultinomialFromUniform): ...@@ -435,14 +434,3 @@ class ChoiceFromUniform(MultinomialFromUniform):
pvals[n, m] = 0.0 pvals[n, m] = 0.0
pvals[n] /= pvals[n].sum() pvals[n] /= pvals[n].sum()
break break
class MultinomialWOReplacementFromUniform(ChoiceFromUniform):
def __init__(self, *args, **kwargs):
warnings.warn(
"MultinomialWOReplacementFromUniform is deprecated, "
"use ChoiceFromUniform instead.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
...@@ -1107,10 +1107,10 @@ class MRG_RandomStream: ...@@ -1107,10 +1107,10 @@ class MRG_RandomStream:
**kwargs, **kwargs,
): ):
warnings.warn( warnings.warn(
"MRG_RandomStream.multinomial_wo_replacement is " "`MRG_RandomStream.multinomial_wo_replacement` is "
"deprecated and will be removed in the next release of " "deprecated; use `MRG_RandomStream.choice` instead.",
"Aesara. Please use MRG_RandomStream.choice instead.",
DeprecationWarning, DeprecationWarning,
stacklevel=2,
) )
assert size is None assert size is None
return self.choice( return self.choice(
......
...@@ -670,10 +670,6 @@ class ScalarType(CType, HasDataType, HasShape): ...@@ -670,10 +670,6 @@ class ScalarType(CType, HasDataType, HasShape):
return shape_info return shape_info
# Deprecated alias for backward compatibility
Scalar = ScalarType
def get_scalar_type(dtype, cache: Dict[str, ScalarType] = {}) -> ScalarType: def get_scalar_type(dtype, cache: Dict[str, ScalarType] = {}) -> ScalarType:
""" """
Return a ScalarType(dtype) object. Return a ScalarType(dtype) object.
...@@ -2903,10 +2899,6 @@ class Reciprocal(UnaryScalarOp): ...@@ -2903,10 +2899,6 @@ class Reciprocal(UnaryScalarOp):
reciprocal = Reciprocal(upgrade_to_float, name="reciprocal") reciprocal = Reciprocal(upgrade_to_float, name="reciprocal")
# These are deprecated and will be removed
Inv = Reciprocal
inv = reciprocal
class Log(UnaryScalarOp): class Log(UnaryScalarOp):
""" """
...@@ -4455,3 +4447,26 @@ def handle_composite(node, mapping): ...@@ -4455,3 +4447,26 @@ def handle_composite(node, mapping):
Compositef32.special[Composite] = handle_composite Compositef32.special[Composite] = handle_composite
DEPRECATED_NAMES = [
("Inv", "`Inv` is deprecated; use `Reciprocal` instead.", Reciprocal),
("inv", "`inv` is deprecated; use `reciprocal` instead.", reciprocal),
("Scalar", "`Scalar` is deprecated; use `ScalarType` instead.", ScalarType),
]
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
...@@ -2667,7 +2667,7 @@ def is_flat(var, ndim=None, outdim=None): ...@@ -2667,7 +2667,7 @@ def is_flat(var, ndim=None, outdim=None):
elif outdim is not None and ndim is not None: elif outdim is not None and ndim is not None:
raise ValueError("You should only specify ndim") raise ValueError("You should only specify ndim")
elif outdim is not None: elif outdim is not None:
warnings.warn("flatten outdim parameter is deprecated, use ndim instead.") warnings.warn("outdim` is deprecated; use `ndim` instead.")
ndim = outdim ndim = outdim
return var.ndim == ndim return var.ndim == ndim
......
...@@ -1048,10 +1048,6 @@ def abs(a): ...@@ -1048,10 +1048,6 @@ def abs(a):
"""|`a`|""" """|`a`|"""
# These are deprecated and will be removed
abs_ = abs
pprint.assign(abs, printing.PatternPrinter(("|%(0)s|", -1000))) pprint.assign(abs, printing.PatternPrinter(("|%(0)s|", -1000)))
...@@ -1080,10 +1076,6 @@ def reciprocal(a): ...@@ -1080,10 +1076,6 @@ def reciprocal(a):
"""1.0/a""" """1.0/a"""
# This is deprecated and will be removed
inv = reciprocal
@scalar_elemwise @scalar_elemwise
def log(a): def log(a):
"""base e logarithm of a""" """base e logarithm of a"""
...@@ -3024,13 +3016,11 @@ __all__ = [ ...@@ -3024,13 +3016,11 @@ __all__ = [
"invert", "invert",
"bitwise_not", "bitwise_not",
"abs", "abs",
"abs_",
"exp", "exp",
"exp2", "exp2",
"expm1", "expm1",
"neg", "neg",
"reciprocal", "reciprocal",
"inv",
"log", "log",
"log2", "log2",
"log10", "log10",
...@@ -3127,3 +3117,28 @@ __all__ = [ ...@@ -3127,3 +3117,28 @@ __all__ = [
"logaddexp", "logaddexp",
"logsumexp", "logsumexp",
] ]
DEPRECATED_NAMES = [
("abs_", "`abs_` is deprecated; use `abs` instead.", abs),
("inv", "`inv` is deprecated; use `reciprocal` instead.", reciprocal),
]
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
def __dir__():
return sorted(__all__ + [names[0] for names in DEPRECATED_NAMES])
...@@ -46,12 +46,13 @@ def conv2d( ...@@ -46,12 +46,13 @@ def conv2d(
subsample=(1, 1), subsample=(1, 1),
**kargs, **kargs,
): ):
""" """Build the symbolic graph for convolving a stack of input images with a set of filters.
Deprecated, old conv2d interface.
This function will build the symbolic graph for convolving a stack of The implementation is modelled after Convolutional Neural Networks
input images with a set of filters. The implementation is modelled after (CNN). It is simply a wrapper to the `ConvOp` but provides a much cleaner
Convolutional Neural Networks (CNN). It is simply a wrapper to the ConvOp interface.
but provides a much cleaner interface.
This is deprecated.
Parameters Parameters
---------- ----------
...@@ -402,8 +403,7 @@ class ConvOp(OpenMPOp): ...@@ -402,8 +403,7 @@ class ConvOp(OpenMPOp):
# with s=1 for mode=='full' and s=-1 for mode=='valid'. # with s=1 for mode=='full' and s=-1 for mode=='valid'.
# To support symbolic shapes, we express this with integer arithmetic. # To support symbolic shapes, we express this with integer arithmetic.
warnings.warn( warnings.warn(
"The method `getOutputShape` is deprecated use" "`getOutputShape` is deprecated; use `get_conv_output_shape` instead.",
"`get_conv_output_shape` instead.",
DeprecationWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
......
...@@ -101,9 +101,8 @@ class Cholesky(Op): ...@@ -101,9 +101,8 @@ class Cholesky(Op):
def conjugate_solve_triangular(outer, inner): def conjugate_solve_triangular(outer, inner):
"""Computes L^{-T} P L^{-1} for lower-triangular L.""" """Computes L^{-T} P L^{-1} for lower-triangular L."""
return solve_upper_triangular( solve_upper = SolveTriangular(lower=False)
outer.T, solve_upper_triangular(outer.T, inner.T).T return solve_upper(outer.T, solve_upper(outer.T, inner.T).T)
)
s = conjugate_solve_triangular( s = conjugate_solve_triangular(
chol_x, tril_and_halve_diagonal(chol_x.T.dot(dz)) chol_x, tril_and_halve_diagonal(chol_x.T.dot(dz))
...@@ -507,15 +506,6 @@ def solve(a, b, assume_a="gen", lower=False, check_finite=True): ...@@ -507,15 +506,6 @@ def solve(a, b, assume_a="gen", lower=False, check_finite=True):
)(a, b) )(a, b)
# TODO: These are deprecated; emit a warning
solve_lower_triangular = SolveTriangular(lower=True)
solve_upper_triangular = SolveTriangular(lower=False)
solve_symmetric = Solve(assume_a="sym")
# TODO: Optimizations to replace multiplication by matrix inverse
# with solve() Op (still unwritten)
class Eigvalsh(Op): class Eigvalsh(Op):
""" """
Generalized eigenvalues of a Hermitian positive definite eigensystem. Generalized eigenvalues of a Hermitian positive definite eigensystem.
...@@ -748,10 +738,45 @@ expm = Expm() ...@@ -748,10 +738,45 @@ expm = Expm()
__all__ = [ __all__ = [
"cholesky", "cholesky",
"solve", "solve",
"solve_lower_triangular",
"solve_upper_triangular",
"solve_symmetric",
"eigvalsh", "eigvalsh",
"kron", "kron",
"expm", "expm",
] ]
DEPRECATED_NAMES = [
(
"solve_lower_triangular",
"`solve_lower_triangular` is deprecated; use `solve` instead.",
SolveTriangular(lower=True),
),
(
"solve_upper_triangular",
"`solve_upper_triangular` is deprecated; use `solve` instead.",
SolveTriangular(lower=False),
),
(
"solve_symmetric",
"`solve_symmetric` is deprecated; use `solve` instead.",
Solve(assume_a="sym"),
),
]
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
def __dir__():
return sorted(__all__ + [names[0] for names in DEPRECATED_NAMES])
...@@ -158,12 +158,18 @@ def deprecated(message: str = ""): ...@@ -158,12 +158,18 @@ def deprecated(message: str = ""):
def decorator_wrapper(func): def decorator_wrapper(func):
@wraps(func) @wraps(func)
def function_wrapper(*args, **kwargs): def function_wrapper(*args, **kwargs):
nonlocal message
current_call_source = "|".join( current_call_source = "|".join(
traceback.format_stack(inspect.currentframe()) traceback.format_stack(inspect.currentframe())
) )
if current_call_source not in function_wrapper.last_call_source: if current_call_source not in function_wrapper.last_call_source:
if not message:
message = f"Function {func.__name__} is deprecated."
warnings.warn( warnings.warn(
"Function {} is now deprecated! {}".format(func.__name__, message), message,
category=DeprecationWarning, category=DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
......
...@@ -827,8 +827,8 @@ def test_Cast(v, dtype): ...@@ -827,8 +827,8 @@ def test_Cast(v, dtype):
(set_test_value(at.iscalar(), np.array(10, dtype="int32")), aesb.float64), (set_test_value(at.iscalar(), np.array(10, dtype="int32")), aesb.float64),
], ],
) )
def test_Inv(v, dtype): def test_reciprocal(v, dtype):
g = aesb.inv(v) g = aesb.reciprocal(v)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, g_fg,
......
...@@ -157,7 +157,7 @@ class TestFunction: ...@@ -157,7 +157,7 @@ class TestFunction:
p = fmatrix() p = fmatrix()
n = iscalar() n = iscalar()
with pytest.warns(DeprecationWarning): with pytest.deprecated_call():
m = th_rng.multinomial_wo_replacement(pvals=p, n=n) m = th_rng.multinomial_wo_replacement(pvals=p, n=n)
f = function([p, n], m, allow_input_downcast=True) f = function([p, n], m, allow_input_downcast=True)
...@@ -181,7 +181,7 @@ class TestFunction: ...@@ -181,7 +181,7 @@ class TestFunction:
p = fmatrix() p = fmatrix()
n = iscalar() n = iscalar()
with pytest.warns(DeprecationWarning): with pytest.deprecated_call():
m = th_rng.multinomial_wo_replacement(pvals=p, n=n) m = th_rng.multinomial_wo_replacement(pvals=p, n=n)
f = function([p, n], m, allow_input_downcast=True) f = function([p, n], m, allow_input_downcast=True)
......
import contextlib
import os import os
import sys import sys
import time import time
...@@ -332,12 +333,20 @@ def test_broadcastable(): ...@@ -332,12 +333,20 @@ def test_broadcastable():
# the sizes of them are implicitly defined with "pvals" argument. # the sizes of them are implicitly defined with "pvals" argument.
if distribution in [R.multinomial, R.multinomial_wo_replacement]: if distribution in [R.multinomial, R.multinomial_wo_replacement]:
# check when all dimensions are constant # check when all dimensions are constant
uu = distribution(pvals=pvals_1) context_mgr = (
assert uu.broadcastable == (False, True) pytest.deprecated_call()
if distribution == R.multinomial_wo_replacement
else contextlib.suppress()
)
with context_mgr:
uu = distribution(pvals=pvals_1)
assert uu.broadcastable == (False, True)
# check when some dimensions are aesara variables # check when some dimensions are aesara variables
uu = distribution(pvals=pvals_2) with context_mgr:
assert uu.broadcastable == (False, True) uu = distribution(pvals=pvals_2)
assert uu.broadcastable == (False, True)
else: else:
# check when all dimensions are constant # check when all dimensions are constant
uu = distribution(size=size1) uu = distribution(size=size1)
...@@ -1109,9 +1118,10 @@ def test_target_parameter(): ...@@ -1109,9 +1118,10 @@ def test_target_parameter():
basic_target_parameter_test( basic_target_parameter_test(
srng.choice(p=pvals.astype("float32"), replace=False, target="cpu") srng.choice(p=pvals.astype("float32"), replace=False, target="cpu")
) )
basic_target_parameter_test( with pytest.deprecated_call():
srng.multinomial_wo_replacement(pvals=pvals.astype("float32"), target="cpu") basic_target_parameter_test(
) srng.multinomial_wo_replacement(pvals=pvals.astype("float32"), target="cpu")
)
@config.change_flags(compute_test_value="off") @config.change_flags(compute_test_value="off")
......
...@@ -1321,16 +1321,9 @@ class TestJoinAndSplit: ...@@ -1321,16 +1321,9 @@ class TestJoinAndSplit:
def test_stack_new_interface(self): def test_stack_new_interface(self):
# Test the new numpy-like interface: stack(tensors, axis=0). # Test the new numpy-like interface: stack(tensors, axis=0).
# Testing against old interface
warnings.simplefilter("always", DeprecationWarning)
a = imatrix("a") a = imatrix("a")
b = imatrix("b") b = imatrix("b")
s1 = stack(a, b)
s2 = stack([a, b])
f = function([a, b], [s1, s2], mode=self.mode)
v1, v2 = f([[1, 2]], [[3, 4]])
assert v1.shape == v2.shape
assert np.all(v1 == v2)
# Testing axis parameter # Testing axis parameter
s3 = stack([a, b], 1) s3 = stack([a, b], 1)
f = function([a, b], s3, mode=self.mode) f = function([a, b], s3, mode=self.mode)
......
...@@ -14,8 +14,6 @@ from aesara.gradient import ( ...@@ -14,8 +14,6 @@ from aesara.gradient import (
NullTypeGradError, NullTypeGradError,
Rop, Rop,
UndefinedGrad, UndefinedGrad,
consider_constant,
consider_constant_,
disconnected_grad, disconnected_grad,
disconnected_grad_, disconnected_grad_,
grad, grad,
...@@ -769,37 +767,45 @@ def test_subgraph_grad(): ...@@ -769,37 +767,45 @@ def test_subgraph_grad():
class TestConsiderConstant: class TestConsiderConstant:
def setup_method(self):
self.rng = np.random.default_rng(seed=utt.fetch_seed())
def test_op_removed(self): def test_op_removed(self):
from aesara.gradient import ConsiderConstant, consider_constant
x = matrix("x") x = matrix("x")
y = x * consider_constant(x)
with pytest.deprecated_call():
y = x * consider_constant(x)
f = aesara.function([x], y) f = aesara.function([x], y)
# need to refer to aesara.consider_constant_ here,
# aesara.consider_constant is a wrapper function! assert ConsiderConstant not in [
assert consider_constant_ not in [node.op for node in f.maker.fgraph.toposort()] type(node.op) for node in f.maker.fgraph.toposort()
]
def test_grad(self): def test_grad(self):
a = np.asarray(self.rng.standard_normal((5, 5)), dtype=config.floatX) from aesara.gradient import consider_constant
x = matrix("x") rng = np.random.default_rng(seed=utt.fetch_seed())
expressions_gradients = [ a = np.asarray(rng.standard_normal((5, 5)), dtype=config.floatX)
(x * consider_constant(x), x),
(x * consider_constant(exp(x)), exp(x)),
(consider_constant(x), at.constant(0.0)),
(x**2 * consider_constant(x), 2 * x**2),
]
for expr, expr_grad in expressions_gradients: x = matrix("x")
g = grad(expr.sum(), x)
# gradient according to aesara
f = aesara.function([x], g, on_unused_input="ignore")
# desired gradient
f2 = aesara.function([x], expr_grad, on_unused_input="ignore")
assert np.allclose(f(a), f2(a)) with pytest.deprecated_call():
expressions_gradients = [
(x * consider_constant(x), x),
(x * consider_constant(exp(x)), exp(x)),
(consider_constant(x), at.constant(0.0)),
(x**2 * consider_constant(x), 2 * x**2),
]
for expr, expr_grad in expressions_gradients:
g = grad(expr.sum(), x)
# gradient according to aesara
f = aesara.function([x], g, on_unused_input="ignore")
# desired gradient
f2 = aesara.function([x], expr_grad, on_unused_input="ignore")
assert np.allclose(f(a), f2(a))
class TestZeroGrad: class TestZeroGrad:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论