提交 f67638b6 authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: Maxim Kochurov

remove ConsiderConstant

上级 09010219
...@@ -2108,44 +2108,6 @@ def _is_zero(x): ...@@ -2108,44 +2108,6 @@ def _is_zero(x):
return "yes" return "yes"
class ConsiderConstant(ViewOp):
def grad(self, args, g_outs):
return [g_out.zeros_like(g_out) for g_out in g_outs]
consider_constant_ = ConsiderConstant()
def consider_constant(x):
"""Consider an expression constant when computing gradients.
DEPRECATED: use `zero_grad` or `disconnected_grad` instead.
The expression itself is unaffected, but when its gradient is
computed, or the gradient of another expression that this
expression is a subexpression of, it will not be backpropagated
through. In other words, the gradient of the expression is
truncated to 0.
:param x: A PyTensor expression whose gradient should be truncated.
:return: The expression is returned unmodified, but its gradient
is now truncated to 0.
.. versionadded:: 0.7
"""
warnings.warn(
(
"`ConsiderConstant` is deprecated; use `zero_grad` or "
"`disconnected_grad` instead."
),
category=DeprecationWarning,
stacklevel=3,
)
return ConsiderConstant()(x)
class ZeroGrad(ViewOp): class ZeroGrad(ViewOp):
def grad(self, args, g_outs): def grad(self, args, g_outs):
return [g_out.zeros_like(g_out) for g_out in g_outs] return [g_out.zeros_like(g_out) for g_out in g_outs]
...@@ -2352,28 +2314,3 @@ def grad_scale(x, multiplier): ...@@ -2352,28 +2314,3 @@ def grad_scale(x, multiplier):
0.416... 0.416...
""" """
return GradScale(multiplier)(x) return GradScale(multiplier)(x)
DEPRECATED_NAMES = [
(
"consider_constant_",
"`consider_constant_` is deprecated; use `zero_grad` or `disconnected_grad` instead.",
ConsiderConstant(),
),
]
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
...@@ -99,7 +99,7 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int: ...@@ -99,7 +99,7 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int:
import pytensor.tensor.exceptions # noqa import pytensor.tensor.exceptions # noqa
from pytensor.gradient import consider_constant, grad, hessian, jacobian # noqa from pytensor.gradient import grad, hessian, jacobian # noqa
# adds shared-variable constructors # adds shared-variable constructors
from pytensor.tensor import sharedvar # noqa from pytensor.tensor import sharedvar # noqa
......
...@@ -766,48 +766,6 @@ def test_subgraph_grad(): ...@@ -766,48 +766,6 @@ def test_subgraph_grad():
assert np.sum(np.abs(true_grad - pgrad)) < 0.00001 assert np.sum(np.abs(true_grad - pgrad)) < 0.00001
class TestConsiderConstant:
def test_op_removed(self):
from pytensor.gradient import ConsiderConstant, consider_constant
x = matrix("x")
with pytest.deprecated_call():
y = x * consider_constant(x)
f = pytensor.function([x], y)
assert ConsiderConstant not in [
type(node.op) for node in f.maker.fgraph.toposort()
]
def test_grad(self):
from pytensor.gradient import consider_constant
rng = np.random.default_rng(seed=utt.fetch_seed())
a = np.asarray(rng.standard_normal((5, 5)), dtype=config.floatX)
x = matrix("x")
with pytest.deprecated_call():
expressions_gradients = [
(x * consider_constant(x), x),
(x * consider_constant(exp(x)), exp(x)),
(consider_constant(x), at.constant(0.0)),
(x**2 * consider_constant(x), 2 * x**2),
]
for expr, expr_grad in expressions_gradients:
g = grad(expr.sum(), x)
# gradient according to pytensor
f = pytensor.function([x], g, on_unused_input="ignore")
# desired gradient
f2 = pytensor.function([x], expr_grad, on_unused_input="ignore")
assert np.allclose(f(a), f2(a))
class TestZeroGrad: class TestZeroGrad:
def setup_method(self): def setup_method(self):
self.rng = np.random.default_rng(seed=utt.fetch_seed()) self.rng = np.random.default_rng(seed=utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论