提交 b5a64c77 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Compute pushforward via double application of pullback

Also fixes bug in Scan L_op and Max R_op Co-authored-by: 's avatarAdrian Seyboldt <aseyboldt@users.noreply.github.com>
上级 84c78027
......@@ -506,4 +506,3 @@ These are the function required to work with :func:`pytensor.gradient.grad`.
the outputs) back to their corresponding shapes and return them as the
output of the :meth:`Op.R_op` method.
:ref:`List of op with r op support <R_op_list>`.
.. _libdoc_gradient:
===========================================
:mod:`gradient` -- Symbolic Differentiation
===========================================
.. module:: gradient
:platform: Unix, Windows
:synopsis: low-level automatic differentiation
.. moduleauthor:: LISA
.. testsetup:: *
from pytensor.gradient import *
Symbolic gradient is usually computed from :func:`gradient.grad`, which offers a
more convenient syntax for the common case of wanting the gradient of some
scalar cost with respect to some input expressions. The :func:`grad_sources_inputs`
function does the underlying work, and is more flexible, but is also more
awkward to use when :func:`gradient.grad` can do the job.
Gradient related functions
==========================
.. automodule:: pytensor.gradient
:members:
.. _R_op_list:
List of Implemented R op
========================
See the :ref:`gradient tutorial <tutcomputinggrads>` for the R op documentation.
list of ops that support R-op:
* with test
* SpecifyShape
* MaxAndArgmax
* Subtensor
* IncSubtensor set_subtensor too
* Alloc
* Dot
* Elemwise
* Sum
* Softmax
* Shape
* Join
* Rebroadcast
* Reshape
* DimShuffle
* Scan [In tests/scan/test_basic.test_rop]
* without test
* Split
* ARange
* ScalarFromTensor
* AdvancedSubtensor1
* AdvancedIncSubtensor1
* AdvancedIncSubtensor
Partial list of ops without support for R-op:
* All sparse ops
* All linear algebra ops.
* PermuteRowElements
* AdvancedSubtensor
* TensorDot
* Outer
* Prod
* MulwithoutZeros
* ProdWithoutZeros
* CAReduce(for max,... done for MaxAndArgmax op)
* MaxAndArgmax(only for matrix on axis 0 or 1)
......@@ -1791,5 +1791,3 @@ Gradient / Differentiation
:members: grad
:noindex:
See the :ref:`gradient <libdoc_gradient>` page for complete documentation
of the gradient module.
......@@ -86,9 +86,7 @@ of symbolic differentiation).
``i`` of the output list is the gradient of the first argument of
`pt.grad` with respect to the ``i``-th element of the list given as second argument.
The first argument of `pt.grad` has to be a scalar (a tensor
of size 1). For more information on the semantics of the arguments of
`pt.grad` and details about the implementation, see
:ref:`this<libdoc_gradient>` section of the library.
of size 1).
Additional information on the inner workings of differentiation may also be
found in the more advanced tutorial :ref:`Extending PyTensor<extending>`.
......@@ -204,7 +202,21 @@ you need to do something similar to this:
>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
array([ 2., 2.])
:ref:`List <R_op_list>` of Op that implement Rop.
By default, the R-operator is implemented as a double application of the L_operator
(see `reference <https://j-towns.github.io/2017/06/12/A-new-trick.html>`_).
In most cases this should be as performant as a specialized implementation of the R-operator.
However, PyTensor may sometimes fail to prune dead branches or fuse common expressions within composite operators,
such as Scan and OpFromGraph, that would be more easily avoidable in a direct implentation of the R-operator.
When this is a concern, it is possible to force `Rop` to use the specialized `Op.R_op` methods by passing
`use_op_rop_implementation=True`. Note that this will fail if the graph contains `Op`s that don't implement this method.
>>> JV = pytensor.gradient.Rop(y, W, V, use_op_rop_implementation=True)
>>> f = pytensor.function([W, V, x], JV)
>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
array([ 2., 2.])
L-operator
----------
......@@ -234,7 +246,6 @@ array([[ 0., 0.],
as the input parameter, while the result of the R-operator has a shape similar
to that of the output.
:ref:`List of op with r op support <R_op_list>`.
Hessian times a Vector
======================
......
......@@ -340,6 +340,12 @@ class OpFromGraph(Op, HasInnerGraph):
``None``, this will be used as the connection_pattern for this
:class:`Op`.
.. warning::
rop overrides is ignored when `pytensor.gradient.Rop` is called with
`use_op_rop_implementation=False` (default). In this case the Lop
is used twice to obtain a mathematically equivalent Rop.
strict: bool, default False
If true, it raises when any variables needed to compute the inner graph
are not provided as explici inputs. This can only happen for graphs with
......@@ -641,7 +647,12 @@ class OpFromGraph(Op, HasInnerGraph):
return rop_overrides
eval_points = [inp_t() for inp_t in self.input_types]
fn_rop = partial(Rop, wrt=inner_inputs, eval_points=eval_points)
fn_rop = partial(
Rop,
wrt=inner_inputs,
eval_points=eval_points,
use_op_rop_implementation=True,
)
callable_args = (inner_inputs, eval_points)
if rop_overrides is None:
......
差异被折叠。
......@@ -3165,7 +3165,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
rop_self_outputs = self_outputs
if info.n_shared_outs > 0:
rop_self_outputs = rop_self_outputs[: -info.n_shared_outs]
rop_outs = Rop(rop_self_outputs, rop_of_inputs, inner_eval_points)
rop_outs = Rop(
rop_self_outputs,
rop_of_inputs,
inner_eval_points,
use_op_rop_implementation=True,
)
if not isinstance(rop_outs, list | tuple):
rop_outs = [rop_outs]
# Step 2. Figure out what corresponds to what in the scan
......
......@@ -306,7 +306,8 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
@pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
)
def test_rop(self, cls_ofg):
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
def test_rop(self, cls_ofg, use_op_rop_implementation):
a = vector()
M = matrix()
b = dot(a, M)
......@@ -315,7 +316,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
W = matrix()
y = op_matmul(x, W)
du = vector()
dv = Rop(y, x, du)
dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation)
fn = function([x, W, du], dv)
xval = np.random.random((16,)).astype(config.floatX)
Wval = np.random.random((16, 16)).astype(config.floatX)
......@@ -324,7 +325,8 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
dvval2 = fn(xval, Wval, duval)
np.testing.assert_array_almost_equal(dvval2, dvval, 4)
def test_rop_multiple_outputs(self):
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
def test_rop_multiple_outputs(self, use_op_rop_implementation):
a = vector()
M = matrix()
b = dot(a, M)
......@@ -339,21 +341,21 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
duval = np.random.random((16,)).astype(config.floatX)
y = op_matmul(x, W)[0]
dv = Rop(y, x, du)
dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation)
fn = function([x, W, du], dv)
result_dvval = fn(xval, Wval, duval)
expected_dvval = np.dot(duval, Wval)
np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4)
y = op_matmul(x, W)[1]
dv = Rop(y, x, du)
dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation)
fn = function([x, W, du], dv)
result_dvval = fn(xval, Wval, duval)
expected_dvval = -np.dot(duval, Wval)
np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4)
y = pt.add(*op_matmul(x, W))
dv = Rop(y, x, du)
dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation)
fn = function([x, W, du], dv)
result_dvval = fn(xval, Wval, duval)
expected_dvval = np.zeros_like(np.dot(duval, Wval))
......@@ -362,7 +364,16 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
@pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
)
def test_rop_override(self, cls_ofg):
@pytest.mark.parametrize(
"use_op_rop_implementation",
[
True,
pytest.param(
False, marks=pytest.mark.xfail(reason="Custom ROp is ignored")
),
],
)
def test_rop_override(self, cls_ofg, use_op_rop_implementation):
x, y = vectors("xy")
def ro(inps, epts):
......@@ -380,7 +391,12 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
du, dv = vector("du"), vector("dv")
for op in [op_mul, op_mul2]:
zz = op_mul(xx, yy)
dw = Rop(zz, [xx, yy], [du, dv])
dw = Rop(
zz,
[xx, yy],
[du, dv],
use_op_rop_implementation=use_op_rop_implementation,
)
fn = function([xx, yy, du, dv], dw)
vals = np.random.random((4, 32)).astype(config.floatX)
dwval = fn(*vals)
......
......@@ -1922,7 +1922,8 @@ class TestScan:
fgrad = function([], g_sh)
assert fgrad() == 1
def test_R_op(self):
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
def test_R_op(self, use_op_rop_implementation):
seed = utt.fetch_seed()
rng = np.random.default_rng(seed)
floatX = config.floatX
......@@ -1957,9 +1958,9 @@ class TestScan:
eh0 = vector("eh0")
eW = matrix("eW")
nwo_u = Rop(o, _u, eu)
nwo_h0 = Rop(o, _h0, eh0)
nwo_W = Rop(o, _W, eW)
nwo_u = Rop(o, _u, eu, use_op_rop_implementation=use_op_rop_implementation)
nwo_h0 = Rop(o, _h0, eh0, use_op_rop_implementation=use_op_rop_implementation)
nwo_W = Rop(o, _W, eW, use_op_rop_implementation=use_op_rop_implementation)
fn_rop = function(
[u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W], on_unused_input="ignore"
)
......@@ -1997,7 +1998,8 @@ class TestScan:
np.testing.assert_allclose(vnW, tnW, atol=1e-6)
@pytest.mark.slow
def test_R_op_2(self):
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
def test_R_op_2(self, use_op_rop_implementation):
seed = utt.fetch_seed()
rng = np.random.default_rng(seed)
floatX = config.floatX
......@@ -2040,9 +2042,9 @@ class TestScan:
eh0 = vector("eh0")
eW = matrix("eW")
nwo_u = Rop(o, _u, eu)
nwo_h0 = Rop(o, _h0, eh0)
nwo_W = Rop(o, _W, eW)
nwo_u = Rop(o, _u, eu, use_op_rop_implementation=use_op_rop_implementation)
nwo_h0 = Rop(o, _h0, eh0, use_op_rop_implementation=use_op_rop_implementation)
nwo_W = Rop(o, _W, eW, use_op_rop_implementation=use_op_rop_implementation)
fn_rop = function(
[u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W, o], on_unused_input="ignore"
)
......@@ -2078,7 +2080,8 @@ class TestScan:
np.testing.assert_allclose(vnh0, tnh0, atol=1e-6)
np.testing.assert_allclose(vnW, tnW, atol=2e-6)
def test_R_op_mitmot(self):
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
def test_R_op_mitmot(self, use_op_rop_implementation):
# this test is a copy paste from the script given by Justin Bayer to
# reproduce this bug
# We have 2 parameter groups with the following shapes.
......@@ -2126,7 +2129,12 @@ class TestScan:
p = dvector()
# TODO: We should test something about the Rop!
Rop(d_cost_wrt_pars, pars, p)
Rop(
d_cost_wrt_pars,
pars,
p,
use_op_rop_implementation=use_op_rop_implementation,
)
def test_second_derivative_disconnected_cost_with_mit_mot(self):
# This test is a regression test for a bug that was revealed
......
......@@ -49,9 +49,12 @@ def test_matrix_inverse_rop_lop():
v = vector("v")
y = MatrixInverse()(mx).sum(axis=0)
yv = pytensor.gradient.Rop(y, mx, mv)
yv = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=True)
rop_f = function([mx, mv], yv)
yv_via_lop = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=False)
rop_via_lop_f = function([mx, mv], yv_via_lop)
sy, _ = pytensor.scan(
lambda i, y, x, v: (pytensor.gradient.grad(y[i], x) * v).sum(),
sequences=pt.arange(y.shape[0]),
......@@ -65,10 +68,14 @@ def test_matrix_inverse_rop_lop():
v_ref = scan_f(vx, vv)
np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol)
np.testing.assert_allclose(rop_via_lop_f(vx, vv), v_ref, rtol=rtol)
with pytest.raises(ValueError):
pytensor.gradient.Rop(
pytensor.clone_replace(y, replace={mx: break_op(mx)}), mx, mv
pytensor.clone_replace(y, replace={mx: break_op(mx)}),
mx,
mv,
use_op_rop_implementation=True,
)
vv = np.asarray(rng.uniform(size=(4,)), pytensor.config.floatX)
......
......@@ -88,7 +88,7 @@ class RopLopChecker:
test that an error is raised.
"""
with pytest.raises(ValueError):
Rop(y, x, v)
Rop(y, x, v, use_op_rop_implementation=True)
def check_mat_rop_lop(self, y, out_shape):
"""
......@@ -116,8 +116,14 @@ class RopLopChecker:
vv = np.asarray(
self.rng.uniform(size=self.mat_in_shape), pytensor.config.floatX
)
yv = Rop(y, self.mx, self.mv)
yv = Rop(y, self.mx, self.mv, use_op_rop_implementation=True)
rop_f = function([self.mx, self.mv], yv, on_unused_input="ignore")
yv_through_lop = Rop(y, self.mx, self.mv, use_op_rop_implementation=False)
rop_through_lop_f = function(
[self.mx, self.mv], yv_through_lop, on_unused_input="ignore"
)
sy, _ = pytensor.scan(
lambda i, y, x, v: (grad(y[i], x) * v).sum(),
sequences=pt.arange(y.shape[0]),
......@@ -127,6 +133,7 @@ class RopLopChecker:
v_ref = scan_f(vx, vv)
np.testing.assert_allclose(rop_f(vx, vv), v_ref)
np.testing.assert_allclose(rop_through_lop_f(vx, vv), v_ref)
self.check_nondiff_rop(
pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)}),
......@@ -156,8 +163,14 @@ class RopLopChecker:
vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
vv = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
yv = Rop(y, self.x, self.v)
yv = Rop(y, self.x, self.v, use_op_rop_implementation=True)
rop_f = function([self.x, self.v], yv, on_unused_input="ignore")
yv_through_lop = Rop(y, self.x, self.v, use_op_rop_implementation=False)
rop_through_lop_f = function(
[self.x, self.v], yv_through_lop, on_unused_input="ignore"
)
J, _ = pytensor.scan(
lambda i, y, x: grad(y[i], x),
sequences=pt.arange(y.shape[0]),
......@@ -168,6 +181,7 @@ class RopLopChecker:
v_ref = scan_f(vx, vv)
np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol)
np.testing.assert_allclose(rop_through_lop_f(vx, vv), v_ref, rtol=rtol)
if check_nondiff_rop:
self.check_nondiff_rop(
......@@ -255,12 +269,12 @@ class TestRopLop(RopLopChecker):
insh = self.in_shape[0]
vW = np.asarray(self.rng.uniform(size=(insh, insh)), pytensor.config.floatX)
W = pytensor.shared(vW)
# check_nondiff_rop reveals an error in how Rop handles non-differentiable paths
# check_nondiff_rop reveals an error in how legacy Rop handles non-differentiable paths
# See: test_Rop_partially_differentiable_paths
self.check_rop_lop(dot(self.x, W), self.in_shape, check_nondiff_rop=False)
def test_elemwise0(self):
# check_nondiff_rop reveals an error in how Rop handles non-differentiable paths
# check_nondiff_rop reveals an error in how legacy Rop handles non-differentiable paths
# See: test_Rop_partially_differentiable_paths
self.check_rop_lop((self.x + 1) ** 2, self.in_shape, check_nondiff_rop=False)
......@@ -294,11 +308,18 @@ class TestRopLop(RopLopChecker):
self.mat_in_shape[0] * self.mat_in_shape[1] * self.in_shape[0],
)
def test_invalid_input(self):
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
def test_invalid_input(self, use_op_rop_implementation):
with pytest.raises(ValueError):
Rop(0.0, [matrix()], [vector()])
Rop(
0.0,
[matrix()],
[vector()],
use_op_rop_implementation=use_op_rop_implementation,
)
def test_multiple_outputs(self):
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
def test_multiple_outputs(self, use_op_rop_implementation):
m = matrix("m")
v = vector("v")
m_ = matrix("m_")
......@@ -309,10 +330,20 @@ class TestRopLop(RopLopChecker):
m_val = self.rng.uniform(size=(3, 7)).astype(pytensor.config.floatX)
v_val = self.rng.uniform(size=(7,)).astype(pytensor.config.floatX)
rop_out1 = Rop([m, v, m + v], [m, v], [m_, v_])
rop_out1 = Rop(
[m, v, m + v],
[m, v],
[m_, v_],
use_op_rop_implementation=use_op_rop_implementation,
)
assert isinstance(rop_out1, list)
assert len(rop_out1) == 3
rop_out2 = Rop((m, v, m + v), [m, v], [m_, v_])
rop_out2 = Rop(
(m, v, m + v),
[m, v],
[m_, v_],
use_op_rop_implementation=use_op_rop_implementation,
)
assert isinstance(rop_out2, tuple)
assert len(rop_out2) == 3
......@@ -322,8 +353,11 @@ class TestRopLop(RopLopChecker):
f = pytensor.function([m, v, m_, v_], all_outs)
f(mval, vval, m_val, v_val)
@pytest.mark.xfail()
def test_Rop_partially_differentiable_paths(self):
@pytest.mark.parametrize(
"use_op_rop_implementation",
[pytest.param(True, marks=pytest.mark.xfail()), False],
)
def test_Rop_partially_differentiable_paths(self, use_op_rop_implementation):
# This test refers to a bug reported by Jeremiah Lowin on 18th Oct
# 2013. The bug consists when through a dot operation there is only
# one differentiable path (i.e. there is no gradient wrt to one of
......@@ -336,7 +370,12 @@ class TestRopLop(RopLopChecker):
grad(d, v),
v,
v,
disconnected_outputs="raise",
use_op_rop_implementation=use_op_rop_implementation,
# 2025: This is a tricky case, the gradient of the gradient does not depend on v
# although v still exists in the graph inside a `Second` operator.
# The original test was checking that Rop wouldn't raise an error, but Lop does.
# Since the correct behavior is ambiguous, I let both implementations off the hook.
disconnected_outputs="raise" if use_op_rop_implementation else "ignore",
)
# 2025: Here is an unambiguous test for the original commented issue:
......@@ -348,10 +387,11 @@ class TestRopLop(RopLopChecker):
out,
[x],
[x.type()],
use_op_rop_implementation=use_op_rop_implementation,
disconnected_outputs="raise",
)
# More extensive testing shows that the Rop implementation FAILS to raise when
# More extensive testing shows that the legacy Rop implementation FAILS to raise when
# the cost is linked through strictly non-differentiable paths.
# This is not Dot specific, we would observe the same with any operation where the gradient
# with respect to one of the inputs does not depend on the original input (such as `mul`, `add`, ...)
......@@ -361,6 +401,7 @@ class TestRopLop(RopLopChecker):
out,
[x],
[x.type()],
use_op_rop_implementation=use_op_rop_implementation,
disconnected_outputs="raise",
)
......@@ -371,5 +412,6 @@ class TestRopLop(RopLopChecker):
out,
[x],
[x.type()],
use_op_rop_implementation=use_op_rop_implementation,
disconnected_outputs="raise",
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论