提交 ab304cb9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Deprecate use of "default" and Variable as OpFromGraph overrides

上级 6dfc811f
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Sequence from collections.abc import Callable, Sequence
from copy import copy from copy import copy
from functools import partial from functools import partial
from typing import cast from typing import Union, cast
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.compile.function import function from pytensor.compile.function import function
...@@ -225,7 +225,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -225,7 +225,7 @@ class OpFromGraph(Op, HasInnerGraph):
e2 = op(x, y, z) + op(z, y, x) e2 = op(x, y, z) + op(z, y, x)
fn = function([x, y, z], [e2]) fn = function([x, y, z], [e2])
Example 3 override L_op Example 3 override second output of L_op
.. code-block:: python .. code-block:: python
...@@ -241,7 +241,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -241,7 +241,7 @@ class OpFromGraph(Op, HasInnerGraph):
op = OpFromGraph( op = OpFromGraph(
[x, y, z], [x, y, z],
[e], [e],
lop_overrides=['default', rescale_dy, 'default'], lop_overrides=[None, rescale_dy, None],
) )
e2 = op(x, y, z) e2 = op(x, y, z)
dx, dy, dz = grad(e2, [x, y, z]) dx, dy, dz = grad(e2, [x, y, z])
...@@ -253,7 +253,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -253,7 +253,7 @@ class OpFromGraph(Op, HasInnerGraph):
TYPE_ERR_MSG = ( TYPE_ERR_MSG = (
"L_op/gradient override should be (single or list of)" "L_op/gradient override should be (single or list of)"
"'default' | OpFromGraph | callable | Variable " "None | OpFromGraph | callable | Variable "
"with NullType or DisconnectedType, got %s" "with NullType or DisconnectedType, got %s"
) )
STYPE_ERR_MSG = ( STYPE_ERR_MSG = (
...@@ -308,9 +308,9 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -308,9 +308,9 @@ class OpFromGraph(Op, HasInnerGraph):
outputs: list[Variable], outputs: list[Variable],
*, *,
inline: bool = False, inline: bool = False,
lop_overrides: str = "default", lop_overrides: Union[Callable, "OpFromGraph", None] = None,
grad_overrides: str = "default", grad_overrides: Union[Callable, "OpFromGraph", None] = None,
rop_overrides: str = "default", rop_overrides: Union[Callable, "OpFromGraph", None] = None,
connection_pattern: list[list[bool]] | None = None, connection_pattern: list[list[bool]] | None = None,
strict: bool = False, strict: bool = False,
name: str | None = None, name: str | None = None,
...@@ -333,10 +333,10 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -333,10 +333,10 @@ class OpFromGraph(Op, HasInnerGraph):
``False`` : will use a pre-compiled function inside. ``False`` : will use a pre-compiled function inside.
grad_overrides grad_overrides
Defaults to ``'default'``. Defaults to ``None``.
This argument is mutually exclusive with ``lop_overrides``. This argument is mutually exclusive with ``lop_overrides``.
``'default'`` : Do not override, use default grad() result ``None`` : Do not override, use default grad() result
`OpFromGraph`: Override with another `OpFromGraph`, should `OpFromGraph`: Override with another `OpFromGraph`, should
accept inputs as the same order and types of ``inputs`` and ``output_grads`` accept inputs as the same order and types of ``inputs`` and ``output_grads``
...@@ -346,14 +346,14 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -346,14 +346,14 @@ class OpFromGraph(Op, HasInnerGraph):
Each argument is expected to be a list of :class:`Variable`. Each argument is expected to be a list of :class:`Variable`.
Must return list of :class:`Variable`. Must return list of :class:`Variable`.
lop_overrides lop_overrides
Defaults to ``'default'``. Defaults to ``None``.
This argument is mutually exclusive with ``grad_overrides``. This argument is mutually exclusive with ``grad_overrides``.
These options are similar to the ``grad_overrides`` above, but for These options are similar to the ``grad_overrides`` above, but for
the :meth:`Op.L_op` method. the :meth:`Op.L_op` method.
``'default'``: Do not override, use the default :meth:`Op.L_op` result ``None``: Do not override, use the default :meth:`Op.L_op` result
`OpFromGraph`: Override with another `OpFromGraph`, should `OpFromGraph`: Override with another `OpFromGraph`, should
accept inputs as the same order and types of ``inputs``, accept inputs as the same order and types of ``inputs``,
...@@ -373,11 +373,11 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -373,11 +373,11 @@ class OpFromGraph(Op, HasInnerGraph):
a specific input, length of list must be equal to number of inputs. a specific input, length of list must be equal to number of inputs.
rop_overrides rop_overrides
One of ``{'default', OpFromGraph, callable, Variable}``. One of ``{None, OpFromGraph, callable, Variable}``.
Defaults to ``'default'``. Defaults to ``None``.
``'default'``: Do not override, use the default :meth:`Op.R_op` result ``None``: Do not override, use the default :meth:`Op.R_op` result
`OpFromGraph`: Override with another `OpFromGraph`, should `OpFromGraph`: Override with another `OpFromGraph`, should
accept inputs as the same order and types of ``inputs`` and ``eval_points`` accept inputs as the same order and types of ``inputs`` and ``eval_points``
...@@ -446,19 +446,29 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -446,19 +446,29 @@ class OpFromGraph(Op, HasInnerGraph):
self.input_types = [inp.type for inp in inputs] self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs] self.output_types = [out.type for out in outputs]
for override in (lop_overrides, grad_overrides, rop_overrides):
if override == "default":
raise ValueError(
"'default' is no longer a valid value for overrides. Use None instead."
)
if isinstance(override, Variable):
raise TypeError(
"Variables are no longer valid types for overrides. Return them in a list for each output instead"
)
self.lop_overrides = lop_overrides self.lop_overrides = lop_overrides
self.grad_overrides = grad_overrides self.grad_overrides = grad_overrides
self.rop_overrides = rop_overrides self.rop_overrides = rop_overrides
if lop_overrides != "default": if lop_overrides is not None:
if grad_overrides != "default": if grad_overrides is not None:
raise ValueError( raise ValueError(
"lop_overrides and grad_overrides are mutually exclusive" "lop_overrides and grad_overrides are mutually exclusive"
) )
else: else:
self.set_lop_overrides(lop_overrides) self.set_lop_overrides(lop_overrides)
self._lop_type = "lop" self._lop_type = "lop"
elif grad_overrides != "default": elif grad_overrides is not None:
warnings.warn( warnings.warn(
"grad_overrides is deprecated in favor of lop_overrides. Using it will lead to an error in the future.", "grad_overrides is deprecated in favor of lop_overrides. Using it will lead to an error in the future.",
FutureWarning, FutureWarning,
...@@ -466,7 +476,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -466,7 +476,7 @@ class OpFromGraph(Op, HasInnerGraph):
self.set_lop_overrides(grad_overrides) self.set_lop_overrides(grad_overrides)
self._lop_type = "grad" self._lop_type = "grad"
else: else:
self.set_lop_overrides("default") self.set_lop_overrides(None)
self._lop_type = "lop" self._lop_type = "lop"
self.set_rop_overrides(rop_overrides) self.set_rop_overrides(rop_overrides)
...@@ -546,7 +556,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -546,7 +556,7 @@ class OpFromGraph(Op, HasInnerGraph):
callable_args = (local_inputs, output_grads) callable_args = (local_inputs, output_grads)
# we need to convert _lop_op into an OfG instance # we need to convert _lop_op into an OfG instance
if lop_op == "default": if lop_op is None:
gdefaults_l = fn_grad(wrt=local_inputs) gdefaults_l = fn_grad(wrt=local_inputs)
all_grads_l, all_grads_ov_l = zip( all_grads_l, all_grads_ov_l = zip(
*[ *[
...@@ -556,12 +566,6 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -556,12 +566,6 @@ class OpFromGraph(Op, HasInnerGraph):
) )
all_grads_l = list(all_grads_l) all_grads_l = list(all_grads_l)
all_grads_ov_l = list(all_grads_ov_l) all_grads_ov_l = list(all_grads_ov_l)
elif isinstance(lop_op, Variable):
if isinstance(lop_op.type, DisconnectedType | NullType):
all_grads_l = [inp.zeros_like() for inp in local_inputs]
all_grads_ov_l = [lop_op.type() for _ in range(inp_len)]
else:
raise ValueError(self.STYPE_ERR_MSG % lop_op.type)
elif isinstance(lop_op, list): elif isinstance(lop_op, list):
goverrides_l = lop_op goverrides_l = lop_op
if len(goverrides_l) != inp_len: if len(goverrides_l) != inp_len:
...@@ -571,15 +575,13 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -571,15 +575,13 @@ class OpFromGraph(Op, HasInnerGraph):
) )
# compute non-overriding downstream grads from upstream grads # compute non-overriding downstream grads from upstream grads
# it's normal some input may be disconnected, thus the 'ignore' # it's normal some input may be disconnected, thus the 'ignore'
wrt_l = [ wrt_l = [lin for lin, gov in zip(local_inputs, goverrides_l) if gov is None]
lin for lin, gov in zip(local_inputs, goverrides_l) if gov == "default"
]
gdefaults = iter(fn_grad(wrt=wrt_l) if wrt_l else []) gdefaults = iter(fn_grad(wrt=wrt_l) if wrt_l else [])
# combine overriding gradients # combine overriding gradients
all_grads_l = [] all_grads_l = []
all_grads_ov_l = [] all_grads_ov_l = []
for inp, fn_gov in zip(local_inputs, goverrides_l): for inp, fn_gov in zip(local_inputs, goverrides_l):
if fn_gov == "default": if fn_gov is None:
gnext, gnext_ov = OpFromGraph._filter_grad_var(next(gdefaults), inp) gnext, gnext_ov = OpFromGraph._filter_grad_var(next(gdefaults), inp)
all_grads_l.append(gnext) all_grads_l.append(gnext)
all_grads_ov_l.append(gnext_ov) all_grads_ov_l.append(gnext_ov)
...@@ -652,13 +654,13 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -652,13 +654,13 @@ class OpFromGraph(Op, HasInnerGraph):
fn_rop = partial(Rop, wrt=local_inputs, eval_points=eval_points) fn_rop = partial(Rop, wrt=local_inputs, eval_points=eval_points)
TYPE_ERR_MSG = ( TYPE_ERR_MSG = (
"R_op overrides should be (single or list of)" "R_op overrides should be (single or list of)"
"OpFromGraph | 'default' | None | 0 | callable, got %s" "OpFromGraph, None, a list or a callable, got %s"
) )
STYPE_ERR_MSG = ( STYPE_ERR_MSG = (
"Overriding Variable instance can only have type" "Overriding Variable instance can only have type"
" of DisconnectedType or NullType, got %s" " of DisconnectedType or NullType, got %s"
) )
if rop_op == "default": if rop_op is None:
rdefaults_l = fn_rop(f=local_outputs) rdefaults_l = fn_rop(f=local_outputs)
all_rops_l, all_rops_ov_l = zip( all_rops_l, all_rops_ov_l = zip(
*[ *[
...@@ -668,15 +670,6 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -668,15 +670,6 @@ class OpFromGraph(Op, HasInnerGraph):
) )
all_rops_l = list(all_rops_l) all_rops_l = list(all_rops_l)
all_rops_ov_l = list(all_rops_ov_l) all_rops_ov_l = list(all_rops_ov_l)
elif isinstance(rop_op, Variable):
if isinstance(rop_op.type, NullType):
all_rops_l = [inp.zeros_like() for inp in local_inputs]
all_rops_ov_l = [rop_op.type() for _ in range(out_len)]
elif isinstance(rop_op.type, DisconnectedType):
all_rops_l = [inp.zeros_like() for inp in local_inputs]
all_rops_ov_l = [None] * out_len
else:
raise ValueError(STYPE_ERR_MSG % rop_op.type)
elif isinstance(rop_op, list): elif isinstance(rop_op, list):
roverrides_l = rop_op roverrides_l = rop_op
if len(roverrides_l) != out_len: if len(roverrides_l) != out_len:
...@@ -686,7 +679,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -686,7 +679,7 @@ class OpFromGraph(Op, HasInnerGraph):
) )
# get outputs that do not have an Rop override # get outputs that do not have an Rop override
odefaults_l = [ odefaults_l = [
lo for lo, rov in zip(local_outputs, roverrides_l) if rov == "default" lo for lo, rov in zip(local_outputs, roverrides_l) if rov is None
] ]
rdefaults_l = fn_rop(f=odefaults_l) rdefaults_l = fn_rop(f=odefaults_l)
rdefaults = iter(rdefaults_l if odefaults_l else []) rdefaults = iter(rdefaults_l if odefaults_l else [])
...@@ -694,7 +687,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -694,7 +687,7 @@ class OpFromGraph(Op, HasInnerGraph):
all_rops_l = [] all_rops_l = []
all_rops_ov_l = [] all_rops_ov_l = []
for out, fn_rov in zip(local_outputs, roverrides_l): for out, fn_rov in zip(local_outputs, roverrides_l):
if fn_rov == "default": if fn_rov is None:
rnext, rnext_ov = OpFromGraph._filter_rop_var(next(rdefaults), out) rnext, rnext_ov = OpFromGraph._filter_rop_var(next(rdefaults), out)
all_rops_l.append(rnext) all_rops_l.append(rnext)
all_rops_ov_l.append(rnext_ov) all_rops_ov_l.append(rnext_ov)
...@@ -769,7 +762,6 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -769,7 +762,6 @@ class OpFromGraph(Op, HasInnerGraph):
self._lop_op = grad_overrides self._lop_op = grad_overrides
self._lop_op_is_cached = False self._lop_op_is_cached = False
self._lop_type = "grad" self._lop_type = "grad"
self._lop_is_default = grad_overrides == "default"
def set_lop_overrides(self, lop_overrides): def set_lop_overrides(self, lop_overrides):
""" """
...@@ -780,7 +772,6 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -780,7 +772,6 @@ class OpFromGraph(Op, HasInnerGraph):
self._lop_op = lop_overrides self._lop_op = lop_overrides
self._lop_op_is_cached = False self._lop_op_is_cached = False
self._lop_type = "lop" self._lop_type = "lop"
self._lop_is_default = lop_overrides == "default"
def set_rop_overrides(self, rop_overrides): def set_rop_overrides(self, rop_overrides):
""" """
...@@ -790,7 +781,6 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -790,7 +781,6 @@ class OpFromGraph(Op, HasInnerGraph):
""" """
self._rop_op = rop_overrides self._rop_op = rop_overrides
self._rop_op_is_cached = False self._rop_op_is_cached = False
self._rop_is_default = rop_overrides == "default"
def L_op(self, inputs, outputs, output_grads): def L_op(self, inputs, outputs, output_grads):
if not self._lop_op_is_cached: if not self._lop_op_is_cached:
......
...@@ -11,7 +11,7 @@ from pytensor.configdefaults import config ...@@ -11,7 +11,7 @@ from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, Rop, disconnected_type, grad from pytensor.gradient import DisconnectedType, Rop, disconnected_type, grad
from pytensor.graph.basic import equal_computations from pytensor.graph.basic import equal_computations
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.null_type import NullType from pytensor.graph.null_type import NullType, null_type
from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.graph.utils import MissingInputError from pytensor.graph.utils import MissingInputError
from pytensor.printing import debugprint from pytensor.printing import debugprint
...@@ -93,6 +93,20 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -93,6 +93,20 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
assert res.shape == (2, 5) assert res.shape == (2, 5)
assert np.all(180.0 == res) assert np.all(180.0 == res)
def test_overrides_deprecated_api(self):
    """Check that the removed override spellings are rejected at construction.

    For each of ``lop_overrides``, ``grad_overrides`` and ``rop_overrides``:
    passing the string ``"default"`` must raise ``ValueError``, and passing
    a ``Variable`` (here a fresh ``null_type()`` instance) must raise
    ``TypeError``.
    """
    x = scalar("x")
    graph_out = x + 1

    # Message fragments the constructor is expected to emit.
    default_msg = "'default' is no longer a valid value for overrides"
    variable_msg = "Variables are no longer valid types for overrides"

    for override_name in ("lop_overrides", "grad_overrides", "rop_overrides"):
        # The legacy "default" sentinel is no longer accepted.
        with pytest.raises(ValueError, match=default_msg):
            OpFromGraph([x], [graph_out], **{override_name: "default"})

        # A bare Variable is no longer accepted as an override.
        with pytest.raises(TypeError, match=variable_msg):
            OpFromGraph([x], [graph_out], **{override_name: null_type()})
@pytest.mark.parametrize( @pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)] "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
) )
...@@ -211,9 +225,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -211,9 +225,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
w, b = vectors("wb") w, b = vectors("wb")
# we make the 3rd gradient default (no override) # we make the 3rd gradient default (no override)
with pytest.warns(FutureWarning, match="grad_overrides is deprecated"): with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
op_linear = cls_ofg( op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2, None])
[x, w, b], [x * w + b], grad_overrides=[go1, go2, "default"]
)
xx, ww, bb = vector("xx"), vector("yy"), vector("bb") xx, ww, bb = vector("xx"), vector("yy"), vector("bb")
zz = pt_sum(op_linear(xx, ww, bb)) zz = pt_sum(op_linear(xx, ww, bb))
dx, dw, db = grad(zz, [xx, ww, bb]) dx, dw, db = grad(zz, [xx, ww, bb])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论