Commit b5a64c77 authored by Ricardo Vieira, committed by Ricardo Vieira

Compute pushforward via double application of pullback

Also fixes bugs in Scan L_op and Max R_op

Co-authored-by: Adrian Seyboldt <aseyboldt@users.noreply.github.com>
Parent 84c78027
@@ -506,4 +506,3 @@ These are the functions required to work with :func:`pytensor.gradient.grad`.
 the outputs) back to their corresponding shapes and return them as the
 output of the :meth:`Op.R_op` method.
-:ref:`List of op with r op support <R_op_list>`.
.. _libdoc_gradient:

===========================================
:mod:`gradient` -- Symbolic Differentiation
===========================================

.. module:: gradient
   :platform: Unix, Windows
   :synopsis: low-level automatic differentiation

.. moduleauthor:: LISA

.. testsetup:: *

   from pytensor.gradient import *

Symbolic gradient is usually computed from :func:`gradient.grad`, which offers a
more convenient syntax for the common case of wanting the gradient of some
scalar cost with respect to some input expressions. The :func:`grad_sources_inputs`
function does the underlying work and is more flexible, but it is also more
awkward to use when :func:`gradient.grad` can do the job.

Gradient related functions
==========================

.. automodule:: pytensor.gradient
   :members:

.. _R_op_list:

List of Implemented R op
========================

See the :ref:`gradient tutorial <tutcomputinggrads>` for the R op documentation.

List of ops that support R-op:

* with test

  * SpecifyShape
  * MaxAndArgmax
  * Subtensor
  * IncSubtensor (set_subtensor too)
  * Alloc
  * Dot
  * Elemwise
  * Sum
  * Softmax
  * Shape
  * Join
  * Rebroadcast
  * Reshape
  * DimShuffle
  * Scan [in tests/scan/test_basic.test_rop]

* without test

  * Split
  * ARange
  * ScalarFromTensor
  * AdvancedSubtensor1
  * AdvancedIncSubtensor1
  * AdvancedIncSubtensor

Partial list of ops without support for R-op:

* All sparse ops
* All linear algebra ops
* PermuteRowElements
* AdvancedSubtensor
* TensorDot
* Outer
* Prod
* MulwithoutZeros
* ProdWithoutZeros
* CAReduce (for max, ...; done for the MaxAndArgmax op)
* MaxAndArgmax (only for matrices on axis 0 or 1)
@@ -1791,5 +1791,3 @@ Gradient / Differentiation
    :members: grad
    :noindex:

-See the :ref:`gradient <libdoc_gradient>` page for complete documentation
-of the gradient module.
@@ -86,9 +86,7 @@ of symbolic differentiation).
 ``i`` of the output list is the gradient of the first argument of
 `pt.grad` with respect to the ``i``-th element of the list given as second argument.
 The first argument of `pt.grad` has to be a scalar (a tensor
-of size 1). For more information on the semantics of the arguments of
-`pt.grad` and details about the implementation, see
-:ref:`this<libdoc_gradient>` section of the library.
+of size 1).

 Additional information on the inner workings of differentiation may also be
 found in the more advanced tutorial :ref:`Extending PyTensor<extending>`.
@@ -204,7 +202,21 @@ you need to do something similar to this:
 >>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
 array([ 2.,  2.])
-:ref:`List <R_op_list>` of Op that implement Rop.
+By default, the R-operator is implemented as a double application of the L-operator
+(see `reference <https://j-towns.github.io/2017/06/12/A-new-trick.html>`_).
+In most cases this should be as performant as a specialized implementation of the R-operator.
+However, PyTensor may sometimes fail to prune dead branches or fuse common expressions within composite operators,
+such as Scan and OpFromGraph, that would be more easily avoided in a direct implementation of the R-operator.
+When this is a concern, it is possible to force `Rop` to use the specialized `Op.R_op` methods by passing
+`use_op_rop_implementation=True`. Note that this will fail if the graph contains `Op`s that don't implement this method.
+
+>>> JV = pytensor.gradient.Rop(y, W, V, use_op_rop_implementation=True)
+>>> f = pytensor.function([W, V, x], JV)
+>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
+array([ 2.,  2.])
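The "double application of the L-operator" trick above can be sketched in plain NumPy. The names `vjp` and `jvp` here are illustrative, not PyTensor API: for a linear map f(x) = W x the pullback is u -> Wᵀu, and pulling back through *that* map recovers the pushforward v -> W v.

```python
import numpy as np

rng = np.random.default_rng(42)
W = rng.normal(size=(3, 4))
v = rng.normal(size=4)  # tangent vector, same shape as the input of f

def vjp(u):
    """Pullback (L-operator) of f(x) = W @ x: maps u to J^T u = W.T @ u."""
    return W.T @ u

# The pullback u -> W.T @ u is itself a linear map whose Jacobian is W.T.
# Differentiating u -> vjp(u) @ v with respect to u (a second pullback)
# therefore yields (W.T).T @ v = W @ v, i.e. the pushforward J @ v.
# Because the inner map is linear, that gradient can be read off one
# standard basis vector at a time:
jvp = np.array([vjp(e) @ v for e in np.eye(3)])

np.testing.assert_allclose(jvp, W @ v)  # matches the direct JVP
```

This is why no dedicated `R_op` implementation is strictly needed: any graph with a working L-operator gets a pushforward for free.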
 L-operator
 ----------
@@ -234,7 +246,6 @@ array([[ 0.,  0.],
 as the input parameter, while the result of the R-operator has a shape similar
 to that of the output.
-:ref:`List of op with r op support <R_op_list>`.
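The shape rule above can be checked concretely in plain NumPy, assuming a hypothetical linear map f(x) = W x whose Jacobian is W:

```python
import numpy as np

W = np.ones((3, 5))   # f: R^5 -> R^3, so the Jacobian J is 3x5
x_like = np.zeros(5)  # something shaped like the input
y_like = np.zeros(3)  # something shaped like the output

lop_result = W.T @ y_like  # L-operator u^T J: shaped like the input
rop_result = W @ x_like    # R-operator J v: shaped like the output

assert lop_result.shape == x_like.shape  # (5,)
assert rop_result.shape == y_like.shape  # (3,)
```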
 Hessian times a Vector
 ======================
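For orientation, the quantity this section computes can be sketched numerically in pure NumPy; in PyTensor the Hessian-vector product would be built symbolically, e.g. as the R-operator of the gradient, rather than by the finite difference used here:

```python
import numpy as np

rng = np.random.default_rng(7)
A = rng.normal(size=(4, 4))
A = A + A.T  # symmetrize, so A is a valid Hessian
x = rng.normal(size=4)
v = rng.normal(size=4)

# Quadratic cost c(x) = 0.5 * x^T A x  =>  grad c(x) = A x, Hessian = A.
grad = lambda z: A @ z

# Hessian-times-vector without ever forming the Hessian:
# directional derivative of the gradient along v (forward difference here;
# Rop(grad, x, v) plays this role symbolically).
eps = 1e-6
hv_numeric = (grad(x + eps * v) - grad(x)) / eps

np.testing.assert_allclose(hv_numeric, A @ v, rtol=1e-4)
```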
......
@@ -340,6 +340,12 @@ class OpFromGraph(Op, HasInnerGraph):
         ``None``, this will be used as the connection_pattern for this
         :class:`Op`.
+
+        .. warning::
+            `rop_overrides` is ignored when `pytensor.gradient.Rop` is called with
+            `use_op_rop_implementation=False` (default). In this case the L_op
+            is used twice to obtain a mathematically equivalent Rop.
+
     strict: bool, default False
         If true, it raises an error when any variables needed to compute the inner graph
         are not provided as explicit inputs. This can only happen for graphs with
@@ -641,7 +647,12 @@ class OpFromGraph(Op, HasInnerGraph):
             return rop_overrides

         eval_points = [inp_t() for inp_t in self.input_types]
-        fn_rop = partial(Rop, wrt=inner_inputs, eval_points=eval_points)
+        fn_rop = partial(
+            Rop,
+            wrt=inner_inputs,
+            eval_points=eval_points,
+            use_op_rop_implementation=True,
+        )
         callable_args = (inner_inputs, eval_points)
         if rop_overrides is None:
......
Diff collapsed.
@@ -3165,7 +3165,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
         rop_self_outputs = self_outputs
         if info.n_shared_outs > 0:
             rop_self_outputs = rop_self_outputs[: -info.n_shared_outs]
-        rop_outs = Rop(rop_self_outputs, rop_of_inputs, inner_eval_points)
+        rop_outs = Rop(
+            rop_self_outputs,
+            rop_of_inputs,
+            inner_eval_points,
+            use_op_rop_implementation=True,
+        )
         if not isinstance(rop_outs, list | tuple):
             rop_outs = [rop_outs]
         # Step 2. Figure out what corresponds to what in the scan
......
@@ -306,7 +306,8 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
     @pytest.mark.parametrize(
         "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
     )
-    def test_rop(self, cls_ofg):
+    @pytest.mark.parametrize("use_op_rop_implementation", [True, False])
+    def test_rop(self, cls_ofg, use_op_rop_implementation):
         a = vector()
         M = matrix()
         b = dot(a, M)
@@ -315,7 +316,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
         W = matrix()
         y = op_matmul(x, W)
         du = vector()
-        dv = Rop(y, x, du)
+        dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation)
         fn = function([x, W, du], dv)
         xval = np.random.random((16,)).astype(config.floatX)
         Wval = np.random.random((16, 16)).astype(config.floatX)
@@ -324,7 +325,8 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
         dvval2 = fn(xval, Wval, duval)
         np.testing.assert_array_almost_equal(dvval2, dvval, 4)

-    def test_rop_multiple_outputs(self):
+    @pytest.mark.parametrize("use_op_rop_implementation", [True, False])
+    def test_rop_multiple_outputs(self, use_op_rop_implementation):
         a = vector()
         M = matrix()
         b = dot(a, M)
@@ -339,21 +341,21 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
         duval = np.random.random((16,)).astype(config.floatX)

         y = op_matmul(x, W)[0]
-        dv = Rop(y, x, du)
+        dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation)
         fn = function([x, W, du], dv)
         result_dvval = fn(xval, Wval, duval)
         expected_dvval = np.dot(duval, Wval)
         np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4)

         y = op_matmul(x, W)[1]
-        dv = Rop(y, x, du)
+        dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation)
         fn = function([x, W, du], dv)
         result_dvval = fn(xval, Wval, duval)
         expected_dvval = -np.dot(duval, Wval)
         np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4)

         y = pt.add(*op_matmul(x, W))
-        dv = Rop(y, x, du)
+        dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation)
         fn = function([x, W, du], dv)
         result_dvval = fn(xval, Wval, duval)
         expected_dvval = np.zeros_like(np.dot(duval, Wval))
@@ -362,7 +364,16 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
     @pytest.mark.parametrize(
         "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
     )
-    def test_rop_override(self, cls_ofg):
+    @pytest.mark.parametrize(
+        "use_op_rop_implementation",
+        [
+            True,
+            pytest.param(
+                False, marks=pytest.mark.xfail(reason="Custom ROp is ignored")
+            ),
+        ],
+    )
+    def test_rop_override(self, cls_ofg, use_op_rop_implementation):
         x, y = vectors("xy")

         def ro(inps, epts):
@@ -380,7 +391,12 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
         du, dv = vector("du"), vector("dv")
         for op in [op_mul, op_mul2]:
             zz = op_mul(xx, yy)
-            dw = Rop(zz, [xx, yy], [du, dv])
+            dw = Rop(
+                zz,
+                [xx, yy],
+                [du, dv],
+                use_op_rop_implementation=use_op_rop_implementation,
+            )
             fn = function([xx, yy, du, dv], dw)
             vals = np.random.random((4, 32)).astype(config.floatX)
             dwval = fn(*vals)
......
@@ -1922,7 +1922,8 @@ class TestScan:
         fgrad = function([], g_sh)
         assert fgrad() == 1

-    def test_R_op(self):
+    @pytest.mark.parametrize("use_op_rop_implementation", [True, False])
+    def test_R_op(self, use_op_rop_implementation):
         seed = utt.fetch_seed()
         rng = np.random.default_rng(seed)
         floatX = config.floatX
@@ -1957,9 +1958,9 @@ class TestScan:
         eh0 = vector("eh0")
         eW = matrix("eW")

-        nwo_u = Rop(o, _u, eu)
-        nwo_h0 = Rop(o, _h0, eh0)
-        nwo_W = Rop(o, _W, eW)
+        nwo_u = Rop(o, _u, eu, use_op_rop_implementation=use_op_rop_implementation)
+        nwo_h0 = Rop(o, _h0, eh0, use_op_rop_implementation=use_op_rop_implementation)
+        nwo_W = Rop(o, _W, eW, use_op_rop_implementation=use_op_rop_implementation)
         fn_rop = function(
             [u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W], on_unused_input="ignore"
         )
@@ -1997,7 +1998,8 @@ class TestScan:
         np.testing.assert_allclose(vnW, tnW, atol=1e-6)

     @pytest.mark.slow
-    def test_R_op_2(self):
+    @pytest.mark.parametrize("use_op_rop_implementation", [True, False])
+    def test_R_op_2(self, use_op_rop_implementation):
         seed = utt.fetch_seed()
         rng = np.random.default_rng(seed)
         floatX = config.floatX
@@ -2040,9 +2042,9 @@ class TestScan:
         eh0 = vector("eh0")
         eW = matrix("eW")

-        nwo_u = Rop(o, _u, eu)
-        nwo_h0 = Rop(o, _h0, eh0)
-        nwo_W = Rop(o, _W, eW)
+        nwo_u = Rop(o, _u, eu, use_op_rop_implementation=use_op_rop_implementation)
+        nwo_h0 = Rop(o, _h0, eh0, use_op_rop_implementation=use_op_rop_implementation)
+        nwo_W = Rop(o, _W, eW, use_op_rop_implementation=use_op_rop_implementation)
         fn_rop = function(
             [u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W, o], on_unused_input="ignore"
         )
@@ -2078,7 +2080,8 @@ class TestScan:
         np.testing.assert_allclose(vnh0, tnh0, atol=1e-6)
         np.testing.assert_allclose(vnW, tnW, atol=2e-6)

-    def test_R_op_mitmot(self):
+    @pytest.mark.parametrize("use_op_rop_implementation", [True, False])
+    def test_R_op_mitmot(self, use_op_rop_implementation):
         # this test is a copy paste from the script given by Justin Bayer to
         # reproduce this bug
         # We have 2 parameter groups with the following shapes.
@@ -2126,7 +2129,12 @@ class TestScan:
         p = dvector()

         # TODO: We should test something about the Rop!
-        Rop(d_cost_wrt_pars, pars, p)
+        Rop(
+            d_cost_wrt_pars,
+            pars,
+            p,
+            use_op_rop_implementation=use_op_rop_implementation,
+        )

     def test_second_derivative_disconnected_cost_with_mit_mot(self):
         # This test is a regression test for a bug that was revealed
......
@@ -49,9 +49,12 @@ def test_matrix_inverse_rop_lop():
     v = vector("v")
     y = MatrixInverse()(mx).sum(axis=0)

-    yv = pytensor.gradient.Rop(y, mx, mv)
+    yv = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=True)
     rop_f = function([mx, mv], yv)

+    yv_via_lop = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=False)
+    rop_via_lop_f = function([mx, mv], yv_via_lop)
+
     sy, _ = pytensor.scan(
         lambda i, y, x, v: (pytensor.gradient.grad(y[i], x) * v).sum(),
         sequences=pt.arange(y.shape[0]),
@@ -65,10 +68,14 @@ def test_matrix_inverse_rop_lop():
     v_ref = scan_f(vx, vv)
     np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol)
+    np.testing.assert_allclose(rop_via_lop_f(vx, vv), v_ref, rtol=rtol)

     with pytest.raises(ValueError):
         pytensor.gradient.Rop(
-            pytensor.clone_replace(y, replace={mx: break_op(mx)}), mx, mv
+            pytensor.clone_replace(y, replace={mx: break_op(mx)}),
+            mx,
+            mv,
+            use_op_rop_implementation=True,
         )

     vv = np.asarray(rng.uniform(size=(4,)), pytensor.config.floatX)
......
@@ -88,7 +88,7 @@ class RopLopChecker:
         test that an error is raised.
         """
         with pytest.raises(ValueError):
-            Rop(y, x, v)
+            Rop(y, x, v, use_op_rop_implementation=True)

     def check_mat_rop_lop(self, y, out_shape):
         """
@@ -116,8 +116,14 @@ class RopLopChecker:
         vv = np.asarray(
             self.rng.uniform(size=self.mat_in_shape), pytensor.config.floatX
         )
-        yv = Rop(y, self.mx, self.mv)
+        yv = Rop(y, self.mx, self.mv, use_op_rop_implementation=True)
         rop_f = function([self.mx, self.mv], yv, on_unused_input="ignore")

+        yv_through_lop = Rop(y, self.mx, self.mv, use_op_rop_implementation=False)
+        rop_through_lop_f = function(
+            [self.mx, self.mv], yv_through_lop, on_unused_input="ignore"
+        )
+
         sy, _ = pytensor.scan(
             lambda i, y, x, v: (grad(y[i], x) * v).sum(),
             sequences=pt.arange(y.shape[0]),
@@ -127,6 +133,7 @@ class RopLopChecker:
         v_ref = scan_f(vx, vv)
         np.testing.assert_allclose(rop_f(vx, vv), v_ref)
+        np.testing.assert_allclose(rop_through_lop_f(vx, vv), v_ref)

         self.check_nondiff_rop(
             pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)}),
@@ -156,8 +163,14 @@ class RopLopChecker:
         vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
         vv = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)

-        yv = Rop(y, self.x, self.v)
+        yv = Rop(y, self.x, self.v, use_op_rop_implementation=True)
         rop_f = function([self.x, self.v], yv, on_unused_input="ignore")

+        yv_through_lop = Rop(y, self.x, self.v, use_op_rop_implementation=False)
+        rop_through_lop_f = function(
+            [self.x, self.v], yv_through_lop, on_unused_input="ignore"
+        )
+
         J, _ = pytensor.scan(
             lambda i, y, x: grad(y[i], x),
             sequences=pt.arange(y.shape[0]),
@@ -168,6 +181,7 @@ class RopLopChecker:
         v_ref = scan_f(vx, vv)
         np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol)
+        np.testing.assert_allclose(rop_through_lop_f(vx, vv), v_ref, rtol=rtol)

         if check_nondiff_rop:
             self.check_nondiff_rop(
@@ -255,12 +269,12 @@ class TestRopLop(RopLopChecker):
         insh = self.in_shape[0]
         vW = np.asarray(self.rng.uniform(size=(insh, insh)), pytensor.config.floatX)
         W = pytensor.shared(vW)
-        # check_nondiff_rop reveals an error in how Rop handles non-differentiable paths
+        # check_nondiff_rop reveals an error in how legacy Rop handles non-differentiable paths
         # See: test_Rop_partially_differentiable_paths
         self.check_rop_lop(dot(self.x, W), self.in_shape, check_nondiff_rop=False)

     def test_elemwise0(self):
-        # check_nondiff_rop reveals an error in how Rop handles non-differentiable paths
+        # check_nondiff_rop reveals an error in how legacy Rop handles non-differentiable paths
         # See: test_Rop_partially_differentiable_paths
         self.check_rop_lop((self.x + 1) ** 2, self.in_shape, check_nondiff_rop=False)
@@ -294,11 +308,18 @@ class TestRopLop(RopLopChecker):
             self.mat_in_shape[0] * self.mat_in_shape[1] * self.in_shape[0],
         )

-    def test_invalid_input(self):
+    @pytest.mark.parametrize("use_op_rop_implementation", [True, False])
+    def test_invalid_input(self, use_op_rop_implementation):
         with pytest.raises(ValueError):
-            Rop(0.0, [matrix()], [vector()])
+            Rop(
+                0.0,
+                [matrix()],
+                [vector()],
+                use_op_rop_implementation=use_op_rop_implementation,
+            )

-    def test_multiple_outputs(self):
+    @pytest.mark.parametrize("use_op_rop_implementation", [True, False])
+    def test_multiple_outputs(self, use_op_rop_implementation):
         m = matrix("m")
         v = vector("v")
         m_ = matrix("m_")
@@ -309,10 +330,20 @@ class TestRopLop(RopLopChecker):
         m_val = self.rng.uniform(size=(3, 7)).astype(pytensor.config.floatX)
         v_val = self.rng.uniform(size=(7,)).astype(pytensor.config.floatX)

-        rop_out1 = Rop([m, v, m + v], [m, v], [m_, v_])
+        rop_out1 = Rop(
+            [m, v, m + v],
+            [m, v],
+            [m_, v_],
+            use_op_rop_implementation=use_op_rop_implementation,
+        )
         assert isinstance(rop_out1, list)
         assert len(rop_out1) == 3

-        rop_out2 = Rop((m, v, m + v), [m, v], [m_, v_])
+        rop_out2 = Rop(
+            (m, v, m + v),
+            [m, v],
+            [m_, v_],
+            use_op_rop_implementation=use_op_rop_implementation,
+        )
         assert isinstance(rop_out2, tuple)
         assert len(rop_out2) == 3
@@ -322,8 +353,11 @@ class TestRopLop(RopLopChecker):
         f = pytensor.function([m, v, m_, v_], all_outs)
         f(mval, vval, m_val, v_val)

-    @pytest.mark.xfail()
-    def test_Rop_partially_differentiable_paths(self):
+    @pytest.mark.parametrize(
+        "use_op_rop_implementation",
+        [pytest.param(True, marks=pytest.mark.xfail()), False],
+    )
+    def test_Rop_partially_differentiable_paths(self, use_op_rop_implementation):
         # This test refers to a bug reported by Jeremiah Lowin on 18th Oct
         # 2013. The bug consists when through a dot operation there is only
         # one differentiable path (i.e. there is no gradient wrt to one of
@@ -336,7 +370,12 @@ class TestRopLop(RopLopChecker):
             grad(d, v),
             v,
             v,
-            disconnected_outputs="raise",
+            use_op_rop_implementation=use_op_rop_implementation,
+            # 2025: This is a tricky case; the gradient of the gradient does not depend on v,
+            # although v still exists in the graph inside a `Second` operator.
+            # The original test checked that Rop wouldn't raise an error, but Lop does.
+            # Since the correct behavior is ambiguous, we let both implementations off the hook.
+            disconnected_outputs="raise" if use_op_rop_implementation else "ignore",
         )
         # 2025: Here is an unambiguous test for the original commented issue:
@@ -348,10 +387,11 @@ class TestRopLop(RopLopChecker):
             out,
             [x],
             [x.type()],
+            use_op_rop_implementation=use_op_rop_implementation,
             disconnected_outputs="raise",
         )
-        # More extensive testing shows that the Rop implementation FAILS to raise when
+        # More extensive testing shows that the legacy Rop implementation FAILS to raise when
         # the cost is linked through strictly non-differentiable paths.
         # This is not Dot specific, we would observe the same with any operation where the gradient
         # with respect to one of the inputs does not depend on the original input (such as `mul`, `add`, ...)
@@ -361,6 +401,7 @@ class TestRopLop(RopLopChecker):
             out,
             [x],
             [x.type()],
+            use_op_rop_implementation=use_op_rop_implementation,
             disconnected_outputs="raise",
         )
@@ -371,5 +412,6 @@ class TestRopLop(RopLopChecker):
             out,
             [x],
             [x.type()],
+            use_op_rop_implementation=use_op_rop_implementation,
             disconnected_outputs="raise",
         )