Commit b5a64c77 authored by Ricardo Vieira, committed by Ricardo Vieira

Compute pushforward via double application of pullback

Also fixes a bug in Scan L_op and Max R_op. Co-authored-by: Adrian Seyboldt <aseyboldt@users.noreply.github.com>
Parent 84c78027
......@@ -506,4 +506,3 @@ These are the function required to work with :func:`pytensor.gradient.grad`.
the outputs) back to their corresponding shapes and return them as the
output of the :meth:`Op.R_op` method.
:ref:`List of Ops with R_op support <R_op_list>`.
.. _libdoc_gradient:
===========================================
:mod:`gradient` -- Symbolic Differentiation
===========================================
.. module:: gradient
:platform: Unix, Windows
:synopsis: low-level automatic differentiation
.. moduleauthor:: LISA
.. testsetup:: *
from pytensor.gradient import *
Symbolic gradient is usually computed from :func:`gradient.grad`, which offers a
more convenient syntax for the common case of wanting the gradient of some
scalar cost with respect to some input expressions. The :func:`grad_sources_inputs`
function does the underlying work, and is more flexible, but is also more
awkward to use when :func:`gradient.grad` can do the job.
Gradient related functions
==========================
.. automodule:: pytensor.gradient
:members:
.. _R_op_list:
List of Implemented R-ops
=========================
See the :ref:`gradient tutorial <tutcomputinggrads>` for the R op documentation.
List of Ops that support the R-op:
* with test
* SpecifyShape
* MaxAndArgmax
* Subtensor
* IncSubtensor set_subtensor too
* Alloc
* Dot
* Elemwise
* Sum
* Softmax
* Shape
* Join
* Rebroadcast
* Reshape
* DimShuffle
* Scan [In tests/scan/test_basic.test_rop]
* without test
* Split
* ARange
* ScalarFromTensor
* AdvancedSubtensor1
* AdvancedIncSubtensor1
* AdvancedIncSubtensor
Partial list of ops without support for R-op:
* All sparse ops
* All linear algebra ops.
* PermuteRowElements
* AdvancedSubtensor
* TensorDot
* Outer
* Prod
* MulwithoutZeros
* ProdWithoutZeros
* CAReduce (for max, ...; done for the MaxAndArgmax Op)
* MaxAndArgmax (only for matrices on axis 0 or 1)
......@@ -1791,5 +1791,3 @@ Gradient / Differentiation
:members: grad
:noindex:
See the :ref:`gradient <libdoc_gradient>` page for complete documentation
of the gradient module.
......@@ -86,9 +86,7 @@ of symbolic differentiation).
``i`` of the output list is the gradient of the first argument of
`pt.grad` with respect to the ``i``-th element of the list given as second argument.
The first argument of `pt.grad` has to be a scalar (a tensor
of size 1). For more information on the semantics of the arguments of
`pt.grad` and details about the implementation, see
:ref:`this<libdoc_gradient>` section of the library.
of size 1).
Additional information on the inner workings of differentiation may also be
found in the more advanced tutorial :ref:`Extending PyTensor<extending>`.
......@@ -204,7 +202,21 @@ you need to do something similar to this:
>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
array([ 2., 2.])
:ref:`List <R_op_list>` of Ops that implement the R-op.
By default, the R-operator is implemented as a double application of the L_operator
(see `reference <https://j-towns.github.io/2017/06/12/A-new-trick.html>`_).
In most cases this should be as performant as a specialized implementation of the R-operator.
However, PyTensor may sometimes fail to prune dead branches or fuse common expressions within composite operators,
such as Scan and OpFromGraph, which would be easier to avoid in a direct implementation of the R-operator.
When this is a concern, it is possible to force `Rop` to use the specialized `Op.R_op` methods by passing
`use_op_rop_implementation=True`. Note that this will fail if the graph contains `Op`s that don't implement this method.
>>> JV = pytensor.gradient.Rop(y, W, V, use_op_rop_implementation=True)
>>> f = pytensor.function([W, V, x], JV)
>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
array([ 2., 2.])
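The identity behind this double application can be checked outside PyTensor. Below is a minimal NumPy-only sketch (the function names are illustrative, not part of the PyTensor API): the pullback of a map is linear in its cotangent argument, and transposing that linear map, which is itself a pullback, recovers the pushforward.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 4))
x = rng.normal(size=4)
u = rng.normal(size=4)          # tangent, shaped like the input

def f(x):
    # toy map R^4 -> R^3 with an x-dependent Jacobian J(x) = A @ diag(1 - tanh(x)^2)
    return A @ np.tanh(x)

def vjp(x, v):
    # pullback (L-operator): v -> J(x)^T v
    return (1 - np.tanh(x) ** 2) * (A.T @ v)

def jvp_via_double_vjp(x, u):
    # v -> vjp(x, v) is linear in v; the transpose of that linear map is
    # exactly the pushforward u -> J(x) u.  With D = diag(1 - tanh(x)^2)
    # the pullback is v -> D A^T v, so its transpose is u -> A D u.
    return A @ ((1 - np.tanh(x) ** 2) * u)

jvp = jvp_via_double_vjp(x, u)
# cross-check against a central finite difference of f along u
fd = (f(x + 1e-6 * u) - f(x - 1e-6 * u)) / 2e-6
assert np.allclose(jvp, fd, atol=1e-5)
```

In PyTensor the second "transposing" pullback is what `Lop` applied to the cotangent graph computes.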
L-operator
----------
......@@ -234,7 +246,6 @@ array([[ 0., 0.],
as the input parameter, while the result of the R-operator has a shape similar
to that of the output.
:ref:`List of Ops with R_op support <R_op_list>`.
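For a concrete linear map the two shape conventions are easy to verify. A minimal NumPy sketch (not PyTensor code; the matrix here is an illustrative stand-in for the Jacobian):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 5))      # f(x) = A @ x maps R^5 -> R^3, so J = A

x = rng.normal(size=5)
cotangent = rng.normal(size=3)   # shaped like the output of f
tangent = rng.normal(size=5)     # shaped like the input of f

lop = A.T @ cotangent            # L-operator: v^T J, shaped like the input
rop = A @ tangent                # R-operator: J v, shaped like the output

assert lop.shape == x.shape      # (5,)
assert rop.shape == (3,)         # same shape as f(x)
```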
Hessian times a Vector
======================
......
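The Hessian-vector product is the directional derivative of the gradient along the vector, which is why it can be obtained by applying the R-operator (or, by symmetry of the Hessian, the L-operator) to a gradient graph. A minimal NumPy sketch with an assumed quadratic cost:

```python
import numpy as np

rng = np.random.default_rng(1)
M = rng.normal(size=(4, 4))
x = rng.normal(size=4)
v = rng.normal(size=4)

def grad_phi(x):
    # gradient of the toy cost phi(x) = 0.5 * x @ M @ x
    return 0.5 * (M + M.T) @ x

# Hessian-vector product: for this cost H = 0.5 * (M + M.T), so
hv = 0.5 * (M + M.T) @ v
# which matches the directional derivative of the gradient along v
fd = (grad_phi(x + 1e-6 * v) - grad_phi(x - 1e-6 * v)) / 2e-6
assert np.allclose(hv, fd, atol=1e-5)
```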
......@@ -340,6 +340,12 @@ class OpFromGraph(Op, HasInnerGraph):
``None``, this will be used as the connection_pattern for this
:class:`Op`.
.. warning::
`rop_overrides` is ignored when `pytensor.gradient.Rop` is called with
`use_op_rop_implementation=False` (the default). In that case the L_op
is applied twice to obtain a mathematically equivalent R_op.
strict: bool, default False
If True, it raises an error when any variable needed to compute the inner graph
is not provided as an explicit input. This can only happen for graphs with
......@@ -641,7 +647,12 @@ class OpFromGraph(Op, HasInnerGraph):
return rop_overrides
eval_points = [inp_t() for inp_t in self.input_types]
fn_rop = partial(Rop, wrt=inner_inputs, eval_points=eval_points)
fn_rop = partial(
Rop,
wrt=inner_inputs,
eval_points=eval_points,
use_op_rop_implementation=True,
)
callable_args = (inner_inputs, eval_points)
if rop_overrides is None:
......
......@@ -142,13 +142,50 @@ class DisconnectedType(Type):
disconnected_type = DisconnectedType()
def Rop(
f: Variable | Sequence[Variable],
wrt: Variable | Sequence[Variable],
eval_points: Variable | Sequence[Variable],
def pushforward_through_pullback(
outputs: Sequence[Variable],
inputs: Sequence[Variable],
tangents: Sequence[Variable],
disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise",
return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
) -> Variable | None | Sequence[Variable | None]:
) -> Sequence[Variable | None]:
"""Compute the pushforward (Rop) through two applications of a pullback (Lop) operation.
References
----------
.. [1] J. Towns, "A new trick for calculating Jacobian vector products", 2017.
Available: https://j-towns.github.io/2017/06/12/A-new-trick.html
"""
# Cotangents are just auxiliary variables that should be pruned from the final graph,
# but that would require a graph rewrite before the user tries to compile a pytensor function.
# To avoid trouble we use .zeros_like() instead of .type(), which does not create a new root variable.
cotangents = [out.zeros_like(dtype=config.floatX) for out in outputs] # type: ignore
input_cotangents = Lop(
f=outputs,
wrt=inputs,
eval_points=cotangents,
disconnected_inputs=disconnected_outputs,
return_disconnected="zero",
)
return Lop(
f=input_cotangents, # type: ignore
wrt=cotangents,
eval_points=tangents,
disconnected_inputs="ignore",
return_disconnected=return_disconnected,
)
def _rop_legacy(
f: Sequence[Variable],
wrt: Sequence[Variable],
eval_points: Sequence[Variable],
disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise",
return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
) -> Sequence[Variable | None]:
"""Computes the R-operator applied to `f` with respect to `wrt` at `eval_points`.
Mathematically this stands for the Jacobian of `f` right multiplied by the
......@@ -190,38 +227,6 @@ def Rop(
If `f` is a list/tuple, then return a list/tuple with the results.
"""
if not isinstance(wrt, list | tuple):
_wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
else:
_wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]
if not isinstance(eval_points, list | tuple):
_eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
else:
_eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points]
if not isinstance(f, list | tuple):
_f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)]
else:
_f = [pytensor.tensor.as_tensor_variable(x) for x in f]
if len(_wrt) != len(_eval_points):
raise ValueError("`wrt` must be the same length as `eval_points`.")
# Check that each element of wrt corresponds to an element
# of eval_points with the same dimensionality.
for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points, strict=True)):
try:
if wrt_elem.type.ndim != eval_point.type.ndim:
raise ValueError(
f"Elements {i} of `wrt` and `eval_point` have mismatched dimensionalities: "
f"{wrt_elem.type.ndim} and {eval_point.type.ndim}"
)
except AttributeError:
# wrt_elem and eval_point don't always have ndim like random type
# Tensor, Sparse have the ndim attribute
pass
seen_nodes: dict[Apply, Sequence[Variable]] = {}
def _traverse(node):
......@@ -237,8 +242,8 @@ def Rop(
# inputs of the node
local_eval_points = []
for inp in inputs:
if inp in _wrt:
local_eval_points.append(_eval_points[_wrt.index(inp)])
if inp in wrt:
local_eval_points.append(eval_points[wrt.index(inp)])
elif inp.owner is None:
try:
local_eval_points.append(inp.zeros_like())
......@@ -292,13 +297,13 @@ def Rop(
# end _traverse
# Populate the dictionary
for out in _f:
for out in f:
_traverse(out.owner)
rval: list[Variable | None] = []
for out in _f:
if out in _wrt:
rval.append(_eval_points[_wrt.index(out)])
for out in f:
if out in wrt:
rval.append(eval_points[wrt.index(out)])
elif (
seen_nodes.get(out.owner, None) is None
or seen_nodes[out.owner][out.owner.outputs.index(out)] is None
......@@ -337,6 +342,116 @@ def Rop(
else:
rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)])
return rval
def Rop(
f: Variable | Sequence[Variable],
wrt: Variable | Sequence[Variable],
eval_points: Variable | Sequence[Variable],
disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise",
return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
use_op_rop_implementation: bool = False,
) -> Variable | None | Sequence[Variable | None]:
"""Computes the R-operator applied to `f` with respect to `wrt` at `eval_points`.
Mathematically this stands for the Jacobian of `f` right multiplied by the
`eval_points`.
By default, the R-operator is implemented as a double application of the L_operator [1]_.
In most cases this should be as performant as a specialized implementation of the R-operator.
However, PyTensor may sometimes fail to prune dead branches or fuse common expressions within composite operators,
such as Scan and OpFromGraph, which would be easier to avoid in a direct implementation of the R-operator.
When this is a concern, it is possible to force `Rop` to use the specialized `Op.R_op` methods by passing
`use_op_rop_implementation=True`. Note that this will fail if the graph contains `Op`s that don't implement this method.
Parameters
----------
f
The outputs of the computational graph to which the R-operator is
applied.
wrt
Variables for which the R-operator of `f` is computed.
eval_points
Points at which to evaluate each of the variables in `wrt`.
disconnected_outputs
Defines the behaviour if some of the variables in `f`
have no dependency on any of the variable in `wrt` (or if
all links are non-differentiable). The possible values are:
- ``'ignore'``: considers that the gradient on these parameters is zero.
- ``'warn'``: consider the gradient zero, and print a warning.
- ``'raise'``: raise `DisconnectedInputError`.
return_disconnected
- ``'zero'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
``wrt[i].zeros_like()``.
- ``'none'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
``None``
- ``'disconnected'`` : returns variables of type `DisconnectedType`
use_op_rop_implementation: bool, default=False
If `False` (the default), the R-operator is obtained via double application of the L-operator.
If `True`, the legacy `Op.R_op` implementation is used. The number of Ops that support this form
is much more restricted, and the generated graphs may be less optimized.
Returns
-------
:class:`~pytensor.graph.basic.Variable` or list/tuple of Variables
A symbolic expression obeying
``R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]``,
where the indices in that expression are magic multidimensional
indices that specify both the position within a list and all
coordinates of the tensor elements.
If `f` is a list/tuple, then return a list/tuple with the results.
References
----------
.. [1] J. Towns, "A new trick for calculating Jacobian vector products", 2017.
Available: https://j-towns.github.io/2017/06/12/A-new-trick.html
"""
if not isinstance(wrt, list | tuple):
_wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
else:
_wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]
if not isinstance(eval_points, list | tuple):
_eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
else:
_eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points]
if not isinstance(f, list | tuple):
_f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)]
else:
_f = [pytensor.tensor.as_tensor_variable(x) for x in f]
if len(_wrt) != len(_eval_points):
raise ValueError("`wrt` must be the same length as `eval_points`.")
# Check that each element of wrt corresponds to an element
# of eval_points with the same dimensionality.
for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points, strict=True)):
try:
if wrt_elem.type.ndim != eval_point.type.ndim:
raise ValueError(
f"Elements {i} of `wrt` and `eval_point` have mismatched dimensionalities: "
f"{wrt_elem.type.ndim} and {eval_point.type.ndim}"
)
except AttributeError:
# wrt_elem and eval_point don't always have ndim like random type
# Tensor, Sparse have the ndim attribute
pass
if use_op_rop_implementation:
rval = _rop_legacy(
_f, _wrt, _eval_points, disconnected_outputs, return_disconnected
)
else:
rval = pushforward_through_pullback(
_f, _wrt, _eval_points, disconnected_outputs, return_disconnected
)
using_list = isinstance(f, list)
using_tuple = isinstance(f, tuple)
return as_list_or_tuple(using_list, using_tuple, rval)
......@@ -348,6 +463,7 @@ def Lop(
eval_points: Variable | Sequence[Variable],
consider_constant: Sequence[Variable] | None = None,
disconnected_inputs: Literal["ignore", "warn", "raise"] = "raise",
return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
) -> Variable | None | Sequence[Variable | None]:
"""Computes the L-operator applied to `f` with respect to `wrt` at `eval_points`.
......@@ -404,6 +520,7 @@ def Lop(
consider_constant=consider_constant,
wrt=_wrt,
disconnected_inputs=disconnected_inputs,
return_disconnected=return_disconnected,
)
using_list = isinstance(wrt, list)
......
......@@ -3165,7 +3165,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
rop_self_outputs = self_outputs
if info.n_shared_outs > 0:
rop_self_outputs = rop_self_outputs[: -info.n_shared_outs]
rop_outs = Rop(rop_self_outputs, rop_of_inputs, inner_eval_points)
rop_outs = Rop(
rop_self_outputs,
rop_of_inputs,
inner_eval_points,
use_op_rop_implementation=True,
)
if not isinstance(rop_outs, list | tuple):
rop_outs = [rop_outs]
# Step 2. Figure out what corresponds to what in the scan
......
......@@ -306,7 +306,8 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
@pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
)
def test_rop(self, cls_ofg):
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
def test_rop(self, cls_ofg, use_op_rop_implementation):
a = vector()
M = matrix()
b = dot(a, M)
......@@ -315,7 +316,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
W = matrix()
y = op_matmul(x, W)
du = vector()
dv = Rop(y, x, du)
dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation)
fn = function([x, W, du], dv)
xval = np.random.random((16,)).astype(config.floatX)
Wval = np.random.random((16, 16)).astype(config.floatX)
......@@ -324,7 +325,8 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
dvval2 = fn(xval, Wval, duval)
np.testing.assert_array_almost_equal(dvval2, dvval, 4)
def test_rop_multiple_outputs(self):
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
def test_rop_multiple_outputs(self, use_op_rop_implementation):
a = vector()
M = matrix()
b = dot(a, M)
......@@ -339,21 +341,21 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
duval = np.random.random((16,)).astype(config.floatX)
y = op_matmul(x, W)[0]
dv = Rop(y, x, du)
dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation)
fn = function([x, W, du], dv)
result_dvval = fn(xval, Wval, duval)
expected_dvval = np.dot(duval, Wval)
np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4)
y = op_matmul(x, W)[1]
dv = Rop(y, x, du)
dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation)
fn = function([x, W, du], dv)
result_dvval = fn(xval, Wval, duval)
expected_dvval = -np.dot(duval, Wval)
np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4)
y = pt.add(*op_matmul(x, W))
dv = Rop(y, x, du)
dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation)
fn = function([x, W, du], dv)
result_dvval = fn(xval, Wval, duval)
expected_dvval = np.zeros_like(np.dot(duval, Wval))
......@@ -362,7 +364,16 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
@pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
)
def test_rop_override(self, cls_ofg):
@pytest.mark.parametrize(
"use_op_rop_implementation",
[
True,
pytest.param(
False, marks=pytest.mark.xfail(reason="Custom ROp is ignored")
),
],
)
def test_rop_override(self, cls_ofg, use_op_rop_implementation):
x, y = vectors("xy")
def ro(inps, epts):
......@@ -380,7 +391,12 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
du, dv = vector("du"), vector("dv")
for op in [op_mul, op_mul2]:
zz = op_mul(xx, yy)
dw = Rop(zz, [xx, yy], [du, dv])
dw = Rop(
zz,
[xx, yy],
[du, dv],
use_op_rop_implementation=use_op_rop_implementation,
)
fn = function([xx, yy, du, dv], dw)
vals = np.random.random((4, 32)).astype(config.floatX)
dwval = fn(*vals)
......
......@@ -1922,7 +1922,8 @@ class TestScan:
fgrad = function([], g_sh)
assert fgrad() == 1
def test_R_op(self):
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
def test_R_op(self, use_op_rop_implementation):
seed = utt.fetch_seed()
rng = np.random.default_rng(seed)
floatX = config.floatX
......@@ -1957,9 +1958,9 @@ class TestScan:
eh0 = vector("eh0")
eW = matrix("eW")
nwo_u = Rop(o, _u, eu)
nwo_h0 = Rop(o, _h0, eh0)
nwo_W = Rop(o, _W, eW)
nwo_u = Rop(o, _u, eu, use_op_rop_implementation=use_op_rop_implementation)
nwo_h0 = Rop(o, _h0, eh0, use_op_rop_implementation=use_op_rop_implementation)
nwo_W = Rop(o, _W, eW, use_op_rop_implementation=use_op_rop_implementation)
fn_rop = function(
[u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W], on_unused_input="ignore"
)
......@@ -1997,7 +1998,8 @@ class TestScan:
np.testing.assert_allclose(vnW, tnW, atol=1e-6)
@pytest.mark.slow
def test_R_op_2(self):
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
def test_R_op_2(self, use_op_rop_implementation):
seed = utt.fetch_seed()
rng = np.random.default_rng(seed)
floatX = config.floatX
......@@ -2040,9 +2042,9 @@ class TestScan:
eh0 = vector("eh0")
eW = matrix("eW")
nwo_u = Rop(o, _u, eu)
nwo_h0 = Rop(o, _h0, eh0)
nwo_W = Rop(o, _W, eW)
nwo_u = Rop(o, _u, eu, use_op_rop_implementation=use_op_rop_implementation)
nwo_h0 = Rop(o, _h0, eh0, use_op_rop_implementation=use_op_rop_implementation)
nwo_W = Rop(o, _W, eW, use_op_rop_implementation=use_op_rop_implementation)
fn_rop = function(
[u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W, o], on_unused_input="ignore"
)
......@@ -2078,7 +2080,8 @@ class TestScan:
np.testing.assert_allclose(vnh0, tnh0, atol=1e-6)
np.testing.assert_allclose(vnW, tnW, atol=2e-6)
def test_R_op_mitmot(self):
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
def test_R_op_mitmot(self, use_op_rop_implementation):
# this test is a copy paste from the script given by Justin Bayer to
# reproduce this bug
# We have 2 parameter groups with the following shapes.
......@@ -2126,7 +2129,12 @@ class TestScan:
p = dvector()
# TODO: We should test something about the Rop!
Rop(d_cost_wrt_pars, pars, p)
Rop(
d_cost_wrt_pars,
pars,
p,
use_op_rop_implementation=use_op_rop_implementation,
)
def test_second_derivative_disconnected_cost_with_mit_mot(self):
# This test is a regression test for a bug that was revealed
......
......@@ -49,9 +49,12 @@ def test_matrix_inverse_rop_lop():
v = vector("v")
y = MatrixInverse()(mx).sum(axis=0)
yv = pytensor.gradient.Rop(y, mx, mv)
yv = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=True)
rop_f = function([mx, mv], yv)
yv_via_lop = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=False)
rop_via_lop_f = function([mx, mv], yv_via_lop)
sy, _ = pytensor.scan(
lambda i, y, x, v: (pytensor.gradient.grad(y[i], x) * v).sum(),
sequences=pt.arange(y.shape[0]),
......@@ -65,10 +68,14 @@ def test_matrix_inverse_rop_lop():
v_ref = scan_f(vx, vv)
np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol)
np.testing.assert_allclose(rop_via_lop_f(vx, vv), v_ref, rtol=rtol)
with pytest.raises(ValueError):
pytensor.gradient.Rop(
pytensor.clone_replace(y, replace={mx: break_op(mx)}), mx, mv
pytensor.clone_replace(y, replace={mx: break_op(mx)}),
mx,
mv,
use_op_rop_implementation=True,
)
vv = np.asarray(rng.uniform(size=(4,)), pytensor.config.floatX)
......
......@@ -88,7 +88,7 @@ class RopLopChecker:
test that an error is raised.
"""
with pytest.raises(ValueError):
Rop(y, x, v)
Rop(y, x, v, use_op_rop_implementation=True)
def check_mat_rop_lop(self, y, out_shape):
"""
......@@ -116,8 +116,14 @@ class RopLopChecker:
vv = np.asarray(
self.rng.uniform(size=self.mat_in_shape), pytensor.config.floatX
)
yv = Rop(y, self.mx, self.mv)
yv = Rop(y, self.mx, self.mv, use_op_rop_implementation=True)
rop_f = function([self.mx, self.mv], yv, on_unused_input="ignore")
yv_through_lop = Rop(y, self.mx, self.mv, use_op_rop_implementation=False)
rop_through_lop_f = function(
[self.mx, self.mv], yv_through_lop, on_unused_input="ignore"
)
sy, _ = pytensor.scan(
lambda i, y, x, v: (grad(y[i], x) * v).sum(),
sequences=pt.arange(y.shape[0]),
......@@ -127,6 +133,7 @@ class RopLopChecker:
v_ref = scan_f(vx, vv)
np.testing.assert_allclose(rop_f(vx, vv), v_ref)
np.testing.assert_allclose(rop_through_lop_f(vx, vv), v_ref)
self.check_nondiff_rop(
pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)}),
......@@ -156,8 +163,14 @@ class RopLopChecker:
vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
vv = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
yv = Rop(y, self.x, self.v)
yv = Rop(y, self.x, self.v, use_op_rop_implementation=True)
rop_f = function([self.x, self.v], yv, on_unused_input="ignore")
yv_through_lop = Rop(y, self.x, self.v, use_op_rop_implementation=False)
rop_through_lop_f = function(
[self.x, self.v], yv_through_lop, on_unused_input="ignore"
)
J, _ = pytensor.scan(
lambda i, y, x: grad(y[i], x),
sequences=pt.arange(y.shape[0]),
......@@ -168,6 +181,7 @@ class RopLopChecker:
v_ref = scan_f(vx, vv)
np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol)
np.testing.assert_allclose(rop_through_lop_f(vx, vv), v_ref, rtol=rtol)
if check_nondiff_rop:
self.check_nondiff_rop(
......@@ -255,12 +269,12 @@ class TestRopLop(RopLopChecker):
insh = self.in_shape[0]
vW = np.asarray(self.rng.uniform(size=(insh, insh)), pytensor.config.floatX)
W = pytensor.shared(vW)
# check_nondiff_rop reveals an error in how Rop handles non-differentiable paths
# check_nondiff_rop reveals an error in how legacy Rop handles non-differentiable paths
# See: test_Rop_partially_differentiable_paths
self.check_rop_lop(dot(self.x, W), self.in_shape, check_nondiff_rop=False)
def test_elemwise0(self):
# check_nondiff_rop reveals an error in how Rop handles non-differentiable paths
# check_nondiff_rop reveals an error in how legacy Rop handles non-differentiable paths
# See: test_Rop_partially_differentiable_paths
self.check_rop_lop((self.x + 1) ** 2, self.in_shape, check_nondiff_rop=False)
......@@ -294,11 +308,18 @@ class TestRopLop(RopLopChecker):
self.mat_in_shape[0] * self.mat_in_shape[1] * self.in_shape[0],
)
def test_invalid_input(self):
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
def test_invalid_input(self, use_op_rop_implementation):
with pytest.raises(ValueError):
Rop(0.0, [matrix()], [vector()])
Rop(
0.0,
[matrix()],
[vector()],
use_op_rop_implementation=use_op_rop_implementation,
)
def test_multiple_outputs(self):
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
def test_multiple_outputs(self, use_op_rop_implementation):
m = matrix("m")
v = vector("v")
m_ = matrix("m_")
......@@ -309,10 +330,20 @@ class TestRopLop(RopLopChecker):
m_val = self.rng.uniform(size=(3, 7)).astype(pytensor.config.floatX)
v_val = self.rng.uniform(size=(7,)).astype(pytensor.config.floatX)
rop_out1 = Rop([m, v, m + v], [m, v], [m_, v_])
rop_out1 = Rop(
[m, v, m + v],
[m, v],
[m_, v_],
use_op_rop_implementation=use_op_rop_implementation,
)
assert isinstance(rop_out1, list)
assert len(rop_out1) == 3
rop_out2 = Rop((m, v, m + v), [m, v], [m_, v_])
rop_out2 = Rop(
(m, v, m + v),
[m, v],
[m_, v_],
use_op_rop_implementation=use_op_rop_implementation,
)
assert isinstance(rop_out2, tuple)
assert len(rop_out2) == 3
......@@ -322,8 +353,11 @@ class TestRopLop(RopLopChecker):
f = pytensor.function([m, v, m_, v_], all_outs)
f(mval, vval, m_val, v_val)
@pytest.mark.xfail()
def test_Rop_partially_differentiable_paths(self):
@pytest.mark.parametrize(
"use_op_rop_implementation",
[pytest.param(True, marks=pytest.mark.xfail()), False],
)
def test_Rop_partially_differentiable_paths(self, use_op_rop_implementation):
# This test refers to a bug reported by Jeremiah Lowin on 18th Oct
# 2013. The bug consists when through a dot operation there is only
# one differentiable path (i.e. there is no gradient wrt to one of
......@@ -336,7 +370,12 @@ class TestRopLop(RopLopChecker):
grad(d, v),
v,
v,
disconnected_outputs="raise",
use_op_rop_implementation=use_op_rop_implementation,
# 2025: This is a tricky case, the gradient of the gradient does not depend on v
# although v still exists in the graph inside a `Second` operator.
# The original test was checking that Rop wouldn't raise an error, but Lop does.
# Since the correct behavior is ambiguous, I let both implementations off the hook.
disconnected_outputs="raise" if use_op_rop_implementation else "ignore",
)
# 2025: Here is an unambiguous test for the original commented issue:
......@@ -348,10 +387,11 @@ class TestRopLop(RopLopChecker):
out,
[x],
[x.type()],
use_op_rop_implementation=use_op_rop_implementation,
disconnected_outputs="raise",
)
# More extensive testing shows that the Rop implementation FAILS to raise when
# More extensive testing shows that the legacy Rop implementation FAILS to raise when
# the cost is linked through strictly non-differentiable paths.
# This is not Dot specific, we would observe the same with any operation where the gradient
# with respect to one of the inputs does not depend on the original input (such as `mul`, `add`, ...)
......@@ -361,6 +401,7 @@ class TestRopLop(RopLopChecker):
out,
[x],
[x.type()],
use_op_rop_implementation=use_op_rop_implementation,
disconnected_outputs="raise",
)
......@@ -371,5 +412,6 @@ class TestRopLop(RopLopChecker):
out,
[x],
[x.type()],
use_op_rop_implementation=use_op_rop_implementation,
disconnected_outputs="raise",
)