提交 ab304cb9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Deprecate use of "default" and Variable as OpFromGraph overrides

上级 6dfc811f
......@@ -2,10 +2,10 @@
import warnings
from collections import OrderedDict
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from copy import copy
from functools import partial
from typing import cast
from typing import Union, cast
import pytensor.tensor as pt
from pytensor.compile.function import function
......@@ -225,7 +225,7 @@ class OpFromGraph(Op, HasInnerGraph):
e2 = op(x, y, z) + op(z, y, x)
fn = function([x, y, z], [e2])
Example 3 override L_op
Example 3 override second output of L_op
.. code-block:: python
......@@ -241,7 +241,7 @@ class OpFromGraph(Op, HasInnerGraph):
op = OpFromGraph(
[x, y, z],
[e],
lop_overrides=['default', rescale_dy, 'default'],
lop_overrides=[None, rescale_dy, None],
)
e2 = op(x, y, z)
dx, dy, dz = grad(e2, [x, y, z])
......@@ -253,7 +253,7 @@ class OpFromGraph(Op, HasInnerGraph):
TYPE_ERR_MSG = (
"L_op/gradient override should be (single or list of)"
"'default' | OpFromGraph | callable | Variable "
"None | OpFromGraph | callable | Variable "
"with NullType or DisconnectedType, got %s"
)
STYPE_ERR_MSG = (
......@@ -308,9 +308,9 @@ class OpFromGraph(Op, HasInnerGraph):
outputs: list[Variable],
*,
inline: bool = False,
lop_overrides: str = "default",
grad_overrides: str = "default",
rop_overrides: str = "default",
lop_overrides: Union[Callable, "OpFromGraph", None] = None,
grad_overrides: Union[Callable, "OpFromGraph", None] = None,
rop_overrides: Union[Callable, "OpFromGraph", None] = None,
connection_pattern: list[list[bool]] | None = None,
strict: bool = False,
name: str | None = None,
......@@ -333,10 +333,10 @@ class OpFromGraph(Op, HasInnerGraph):
``False`` : will use a pre-compiled function inside.
grad_overrides
Defaults to ``'default'``.
Defaults to ``None``.
This argument is mutually exclusive with ``lop_overrides``.
``'default'`` : Do not override, use default grad() result
``None`` : Do not override, use default grad() result
`OpFromGraph`: Override with another `OpFromGraph`, should
accept inputs as the same order and types of ``inputs`` and ``output_grads``
......@@ -346,14 +346,14 @@ class OpFromGraph(Op, HasInnerGraph):
Each argument is expected to be a list of :class:`Variable`.
Must return a list of :class:`Variable`.
lop_overrides
Defaults to ``'default'``.
Defaults to ``None``.
This argument is mutually exclusive with ``grad_overrides``.
These options are similar to the ``grad_overrides`` above, but for
the :meth:`Op.L_op` method.
``'default'``: Do not override, use the default :meth:`Op.L_op` result
``None``: Do not override, use the default :meth:`Op.L_op` result
`OpFromGraph`: Override with another `OpFromGraph`, should
accept inputs as the same order and types of ``inputs``,
......@@ -373,11 +373,11 @@ class OpFromGraph(Op, HasInnerGraph):
a specific input, length of list must be equal to number of inputs.
rop_overrides
One of ``{'default', OpFromGraph, callable, Variable}``.
One of ``{None, OpFromGraph, callable, list}``.
Defaults to ``'default'``.
Defaults to ``None``.
``'default'``: Do not override, use the default :meth:`Op.R_op` result
``None``: Do not override, use the default :meth:`Op.R_op` result
`OpFromGraph`: Override with another `OpFromGraph`, should
accept inputs as the same order and types of ``inputs`` and ``eval_points``
......@@ -446,19 +446,29 @@ class OpFromGraph(Op, HasInnerGraph):
self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs]
for override in (lop_overrides, grad_overrides, rop_overrides):
if override == "default":
raise ValueError(
"'default' is no longer a valid value for overrides. Use None instead."
)
if isinstance(override, Variable):
raise TypeError(
"Variables are no longer valid types for overrides. Return them in a list for each output instead"
)
self.lop_overrides = lop_overrides
self.grad_overrides = grad_overrides
self.rop_overrides = rop_overrides
if lop_overrides != "default":
if grad_overrides != "default":
if lop_overrides is not None:
if grad_overrides is not None:
raise ValueError(
"lop_overrides and grad_overrides are mutually exclusive"
)
else:
self.set_lop_overrides(lop_overrides)
self._lop_type = "lop"
elif grad_overrides != "default":
elif grad_overrides is not None:
warnings.warn(
"grad_overrides is deprecated in favor of lop_overrides. Using it will lead to an error in the future.",
FutureWarning,
......@@ -466,7 +476,7 @@ class OpFromGraph(Op, HasInnerGraph):
self.set_lop_overrides(grad_overrides)
self._lop_type = "grad"
else:
self.set_lop_overrides("default")
self.set_lop_overrides(None)
self._lop_type = "lop"
self.set_rop_overrides(rop_overrides)
......@@ -546,7 +556,7 @@ class OpFromGraph(Op, HasInnerGraph):
callable_args = (local_inputs, output_grads)
# we need to convert _lop_op into an OfG instance
if lop_op == "default":
if lop_op is None:
gdefaults_l = fn_grad(wrt=local_inputs)
all_grads_l, all_grads_ov_l = zip(
*[
......@@ -556,12 +566,6 @@ class OpFromGraph(Op, HasInnerGraph):
)
all_grads_l = list(all_grads_l)
all_grads_ov_l = list(all_grads_ov_l)
elif isinstance(lop_op, Variable):
if isinstance(lop_op.type, DisconnectedType | NullType):
all_grads_l = [inp.zeros_like() for inp in local_inputs]
all_grads_ov_l = [lop_op.type() for _ in range(inp_len)]
else:
raise ValueError(self.STYPE_ERR_MSG % lop_op.type)
elif isinstance(lop_op, list):
goverrides_l = lop_op
if len(goverrides_l) != inp_len:
......@@ -571,15 +575,13 @@ class OpFromGraph(Op, HasInnerGraph):
)
# compute non-overriding downstream grads from upstream grads
# it's normal that some inputs may be disconnected, thus the 'ignore'
wrt_l = [
lin for lin, gov in zip(local_inputs, goverrides_l) if gov == "default"
]
wrt_l = [lin for lin, gov in zip(local_inputs, goverrides_l) if gov is None]
gdefaults = iter(fn_grad(wrt=wrt_l) if wrt_l else [])
# combine overriding gradients
all_grads_l = []
all_grads_ov_l = []
for inp, fn_gov in zip(local_inputs, goverrides_l):
if fn_gov == "default":
if fn_gov is None:
gnext, gnext_ov = OpFromGraph._filter_grad_var(next(gdefaults), inp)
all_grads_l.append(gnext)
all_grads_ov_l.append(gnext_ov)
......@@ -652,13 +654,13 @@ class OpFromGraph(Op, HasInnerGraph):
fn_rop = partial(Rop, wrt=local_inputs, eval_points=eval_points)
TYPE_ERR_MSG = (
"R_op overrides should be (single or list of)"
"OpFromGraph | 'default' | None | 0 | callable, got %s"
"OpFromGraph, None, a list or a callable, got %s"
)
STYPE_ERR_MSG = (
"Overriding Variable instance can only have type"
" of DisconnectedType or NullType, got %s"
)
if rop_op == "default":
if rop_op is None:
rdefaults_l = fn_rop(f=local_outputs)
all_rops_l, all_rops_ov_l = zip(
*[
......@@ -668,15 +670,6 @@ class OpFromGraph(Op, HasInnerGraph):
)
all_rops_l = list(all_rops_l)
all_rops_ov_l = list(all_rops_ov_l)
elif isinstance(rop_op, Variable):
if isinstance(rop_op.type, NullType):
all_rops_l = [inp.zeros_like() for inp in local_inputs]
all_rops_ov_l = [rop_op.type() for _ in range(out_len)]
elif isinstance(rop_op.type, DisconnectedType):
all_rops_l = [inp.zeros_like() for inp in local_inputs]
all_rops_ov_l = [None] * out_len
else:
raise ValueError(STYPE_ERR_MSG % rop_op.type)
elif isinstance(rop_op, list):
roverrides_l = rop_op
if len(roverrides_l) != out_len:
......@@ -686,7 +679,7 @@ class OpFromGraph(Op, HasInnerGraph):
)
# get outputs that do not have an Rop override
odefaults_l = [
lo for lo, rov in zip(local_outputs, roverrides_l) if rov == "default"
lo for lo, rov in zip(local_outputs, roverrides_l) if rov is None
]
rdefaults_l = fn_rop(f=odefaults_l)
rdefaults = iter(rdefaults_l if odefaults_l else [])
......@@ -694,7 +687,7 @@ class OpFromGraph(Op, HasInnerGraph):
all_rops_l = []
all_rops_ov_l = []
for out, fn_rov in zip(local_outputs, roverrides_l):
if fn_rov == "default":
if fn_rov is None:
rnext, rnext_ov = OpFromGraph._filter_rop_var(next(rdefaults), out)
all_rops_l.append(rnext)
all_rops_ov_l.append(rnext_ov)
......@@ -769,7 +762,6 @@ class OpFromGraph(Op, HasInnerGraph):
self._lop_op = grad_overrides
self._lop_op_is_cached = False
self._lop_type = "grad"
self._lop_is_default = grad_overrides == "default"
def set_lop_overrides(self, lop_overrides):
"""
......@@ -780,7 +772,6 @@ class OpFromGraph(Op, HasInnerGraph):
self._lop_op = lop_overrides
self._lop_op_is_cached = False
self._lop_type = "lop"
self._lop_is_default = lop_overrides == "default"
def set_rop_overrides(self, rop_overrides):
"""
......@@ -790,7 +781,6 @@ class OpFromGraph(Op, HasInnerGraph):
"""
self._rop_op = rop_overrides
self._rop_op_is_cached = False
self._rop_is_default = rop_overrides == "default"
def L_op(self, inputs, outputs, output_grads):
if not self._lop_op_is_cached:
......
......@@ -11,7 +11,7 @@ from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, Rop, disconnected_type, grad
from pytensor.graph.basic import equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.null_type import NullType
from pytensor.graph.null_type import NullType, null_type
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.graph.utils import MissingInputError
from pytensor.printing import debugprint
......@@ -93,6 +93,20 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
assert res.shape == (2, 5)
assert np.all(180.0 == res)
# Regression test for the removed override spellings: for each of the three
# override keyword arguments, constructing an OpFromGraph with the string
# "default" must raise ValueError, and constructing it with the result of
# null_type() must raise TypeError.
# NOTE(review): indentation is shown flattened here, as in the diff rendering.
def test_overrides_deprecated_api(self):
# Minimal one-input / one-output graph; only construction is exercised.
inp = scalar("x")
out = inp + 1
# The same two checks are repeated for every override keyword.
for kwarg in ("lop_overrides", "grad_overrides", "rop_overrides"):
# The legacy "default" sentinel string is rejected with a ValueError.
with pytest.raises(
ValueError, match="'default' is no longer a valid value for overrides"
):
OpFromGraph([inp], [out], **{kwarg: "default"})
# Passing the object returned by null_type() directly as the override is
# rejected with a TypeError (presumably a Variable of NullType; confirm
# against pytensor.graph.null_type).
with pytest.raises(
TypeError, match="Variables are no longer valid types for overrides"
):
OpFromGraph([inp], [out], **{kwarg: null_type()})
@pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
)
......@@ -211,9 +225,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
w, b = vectors("wb")
# we make the 3rd gradient default (no override)
with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
op_linear = cls_ofg(
[x, w, b], [x * w + b], grad_overrides=[go1, go2, "default"]
)
op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2, None])
xx, ww, bb = vector("xx"), vector("yy"), vector("bb")
zz = pt_sum(op_linear(xx, ww, bb))
dx, dw, db = grad(zz, [xx, ww, bb])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论