Commit 49cf9d22 authored by Ricardo Vieira, committed by Ricardo Vieira

Cleanup Rop tests and fix Max Rop implementation

Parent 298bb133
...@@ -431,20 +431,25 @@ class Max(NonZeroDimsCAReduce): ...@@ -431,20 +431,25 @@ class Max(NonZeroDimsCAReduce):
return (g_x,) return (g_x,)
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
[x] = inputs
if eval_points[0] is None: if eval_points[0] is None:
return [None, None] return [None]
if len(self.axis) != 1: axis = tuple(range(x.ndim) if self.axis is None else self.axis)
raise ValueError("R_op supported for max only for one axis!") if isinstance(axis, int):
if self.axis[0] > 1: axis = [axis]
raise ValueError("R_op supported for max only when axis is 0 or 1") if len(axis) != 1:
raise NotImplementedError("R_op supported for max only for one axis!")
if axis[0] > 1:
raise NotImplementedError("R_op supported for max only when axis is 0 or 1")
if inputs[0].ndim != 2: if inputs[0].ndim != 2:
raise ValueError("R_op supported for max only when input is a matrix") raise NotImplementedError(
max_pos = Argmax(self.axis).make_node(*inputs).outputs "R_op supported for max only when input is a matrix"
# print(eval_points[0].eval()) )
max_pos = Argmax(self.axis)(*inputs)
if self.axis[0] == 0: if self.axis[0] == 0:
return [eval_points[0][max_pos, arange(eval_points[0].shape[1])], None] return [eval_points[0][max_pos, arange(eval_points[0].shape[1])]]
else: else:
return [eval_points[0][arange(eval_points[0].shape[0]), max_pos], None] return [eval_points[0][arange(eval_points[0].shape[0]), max_pos]]
class Min(NonZeroDimsCAReduce): class Min(NonZeroDimsCAReduce):
......
...@@ -1992,9 +1992,9 @@ class TestScan: ...@@ -1992,9 +1992,9 @@ class TestScan:
vnu, vnh0, vnW = fn_rop(v_u, v_h0, v_W, v_eu, v_eh0, v_eW) vnu, vnh0, vnW = fn_rop(v_u, v_h0, v_W, v_eu, v_eh0, v_eW)
tnu, tnh0, tnW = fn_test(v_u, v_h0, v_W, v_eu, v_eh0, v_eW) tnu, tnh0, tnW = fn_test(v_u, v_h0, v_W, v_eu, v_eh0, v_eW)
utt.assert_allclose(vnu, tnu, atol=1e-6) np.testing.assert_allclose(vnu, tnu, atol=1e-6)
utt.assert_allclose(vnh0, tnh0, atol=1e-6) np.testing.assert_allclose(vnh0, tnh0, atol=1e-6)
utt.assert_allclose(vnW, tnW, atol=1e-6) np.testing.assert_allclose(vnW, tnW, atol=1e-6)
@pytest.mark.slow @pytest.mark.slow
def test_R_op_2(self): def test_R_op_2(self):
...@@ -2074,9 +2074,9 @@ class TestScan: ...@@ -2074,9 +2074,9 @@ class TestScan:
) )
tnu, tnh0, tnW, tno = fn_test(v_u, v_h0, v_W, v_eu, v_eh0, v_eW) tnu, tnh0, tnW, tno = fn_test(v_u, v_h0, v_W, v_eu, v_eh0, v_eW)
utt.assert_allclose(vnu, tnu, atol=1e-6) np.testing.assert_allclose(vnu, tnu, atol=1e-6)
utt.assert_allclose(vnh0, tnh0, atol=1e-6) np.testing.assert_allclose(vnh0, tnh0, atol=1e-6)
utt.assert_allclose(vnW, tnW, atol=2e-6) np.testing.assert_allclose(vnW, tnW, atol=2e-6)
def test_R_op_mitmot(self): def test_R_op_mitmot(self):
# this test is a copy paste from the script given by Justin Bayer to # this test is a copy paste from the script given by Justin Bayer to
...@@ -2094,13 +2094,10 @@ class TestScan: ...@@ -2094,13 +2094,10 @@ class TestScan:
W1 = pars[:3].reshape(W1shape) W1 = pars[:3].reshape(W1shape)
W2 = pars[3:].reshape(W2shape) W2 = pars[3:].reshape(W2shape)
# Define recurrent model. We are using a model where each input is a # Define recurrent model. We are using a model where each input
# tensor # is a tensor of shape (T, B, D) where T is the number of timesteps,
# of shape (T, B, D) where T is the number of timesteps, B is the # B is the number of sequences iterated over in parallel and
# number of # D is the dimensionality of each item at a timestep.
# sequences iterated over in parallel and D is the dimensionality of
# each
# item at a timestep.
inpt = tensor3("inpt") inpt = tensor3("inpt")
target = tensor3("target") target = tensor3("target")
...@@ -2128,6 +2125,7 @@ class TestScan: ...@@ -2128,6 +2125,7 @@ class TestScan:
d_cost_wrt_pars = grad(cost, pars) d_cost_wrt_pars = grad(cost, pars)
p = dvector() p = dvector()
# TODO: We should test something about the Rop!
Rop(d_cost_wrt_pars, pars, p) Rop(d_cost_wrt_pars, pars, p)
......
...@@ -14,7 +14,7 @@ from pytensor.graph.rewriting.utils import rewrite_graph ...@@ -14,7 +14,7 @@ from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor import swapaxes from pytensor.tensor import swapaxes
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose, dot, matmul from pytensor.tensor.math import dot, matmul
from pytensor.tensor.nlinalg import ( from pytensor.tensor.nlinalg import (
SVD, SVD,
Det, Det,
...@@ -42,7 +42,8 @@ from tests import unittest_tools as utt ...@@ -42,7 +42,8 @@ from tests import unittest_tools as utt
from tests.test_rop import break_op from tests.test_rop import break_op
def test_rop_lop(): def test_matrix_inverse_rop_lop():
rtol = 1e-7 if config.floatX == "float64" else 1e-5
mx = matrix("mx") mx = matrix("mx")
mv = matrix("mv") mv = matrix("mv")
v = vector("v") v = vector("v")
...@@ -62,23 +63,13 @@ def test_rop_lop(): ...@@ -62,23 +63,13 @@ def test_rop_lop():
vx = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX) vx = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX)
vv = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX) vv = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX)
v1 = rop_f(vx, vv) v_ref = scan_f(vx, vv)
v2 = scan_f(vx, vv) np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol)
assert _allclose(v1, v2), f"ROP mismatch: {v1} {v2}" with pytest.raises(ValueError):
raised = False
try:
pytensor.gradient.Rop( pytensor.gradient.Rop(
pytensor.clone_replace(y, replace={mx: break_op(mx)}), mx, mv pytensor.clone_replace(y, replace={mx: break_op(mx)}), mx, mv
) )
except ValueError:
raised = True
if not raised:
raise Exception(
"Op did not raised an error even though the function"
" is not differentiable"
)
vv = np.asarray(rng.uniform(size=(4,)), pytensor.config.floatX) vv = np.asarray(rng.uniform(size=(4,)), pytensor.config.floatX)
yv = pytensor.gradient.Lop(y, mx, v) yv = pytensor.gradient.Lop(y, mx, v)
...@@ -87,9 +78,9 @@ def test_rop_lop(): ...@@ -87,9 +78,9 @@ def test_rop_lop():
sy = pytensor.gradient.grad((v * y).sum(), mx) sy = pytensor.gradient.grad((v * y).sum(), mx)
scan_f = function([mx, v], sy) scan_f = function([mx, v], sy)
v1 = lop_f(vx, vv) v_ref = scan_f(vx, vv)
v2 = scan_f(vx, vv) v = lop_f(vx, vv)
assert _allclose(v1, v2), f"LOP mismatch: {v1} {v2}" np.testing.assert_allclose(v, v_ref, rtol=rtol)
def test_transinv_to_invtrans(): def test_transinv_to_invtrans():
......
...@@ -603,7 +603,7 @@ class TestSpecifyBroadcastable: ...@@ -603,7 +603,7 @@ class TestSpecifyBroadcastable:
class TestRopLop(RopLopChecker): class TestRopLop(RopLopChecker):
def test_shape(self): def test_shape(self):
self.check_nondiff_rop(self.x.shape[0]) self.check_nondiff_rop(self.x.shape[0], self.x, self.v)
def test_specifyshape(self): def test_specifyshape(self):
self.check_rop_lop(specify_shape(self.x, self.in_shape), self.in_shape) self.check_rop_lop(specify_shape(self.x, self.in_shape), self.in_shape)
......
...@@ -16,8 +16,14 @@ import pytest ...@@ -16,8 +16,14 @@ import pytest
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import function from pytensor import config, function
from pytensor.gradient import Lop, Rop, grad, grad_undefined from pytensor.gradient import (
Lop,
NullTypeGradError,
Rop,
grad,
grad_undefined,
)
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.tensor.math import argmax, dot from pytensor.tensor.math import argmax, dot
...@@ -61,6 +67,10 @@ class RopLopChecker: ...@@ -61,6 +67,10 @@ class RopLopChecker:
Rop to class that inherit from it. Rop to class that inherit from it.
""" """
@staticmethod
def rtol():
return 1e-7 if config.floatX == "float64" else 1e-5
def setup_method(self): def setup_method(self):
# Using vectors make things a lot simpler for generating the same # Using vectors make things a lot simpler for generating the same
# computations using scan # computations using scan
...@@ -72,13 +82,13 @@ class RopLopChecker: ...@@ -72,13 +82,13 @@ class RopLopChecker:
self.mv = matrix("mv") self.mv = matrix("mv")
self.mat_in_shape = (5 + self.rng.integers(3), 5 + self.rng.integers(3)) self.mat_in_shape = (5 + self.rng.integers(3), 5 + self.rng.integers(3))
def check_nondiff_rop(self, y): def check_nondiff_rop(self, y, x, v):
""" """
If your op is not differentiable(so you can't define Rop) If your op is not differentiable(so you can't define Rop)
test that an error is raised. test that an error is raised.
""" """
with pytest.raises(ValueError): with pytest.raises(ValueError):
Rop(y, self.x, self.v) Rop(y, x, v)
def check_mat_rop_lop(self, y, out_shape): def check_mat_rop_lop(self, y, out_shape):
""" """
...@@ -115,13 +125,13 @@ class RopLopChecker: ...@@ -115,13 +125,13 @@ class RopLopChecker:
) )
scan_f = function([self.mx, self.mv], sy, on_unused_input="ignore") scan_f = function([self.mx, self.mv], sy, on_unused_input="ignore")
v1 = rop_f(vx, vv) v_ref = scan_f(vx, vv)
v2 = scan_f(vx, vv) np.testing.assert_allclose(rop_f(vx, vv), v_ref)
assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}"
self.check_nondiff_rop( self.check_nondiff_rop(
pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)}) pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)}),
self.mx,
self.mv,
) )
vv = np.asarray(self.rng.uniform(size=out_shape), pytensor.config.floatX) vv = np.asarray(self.rng.uniform(size=out_shape), pytensor.config.floatX)
...@@ -131,15 +141,17 @@ class RopLopChecker: ...@@ -131,15 +141,17 @@ class RopLopChecker:
sy = grad((self.v * y).sum(), self.mx) sy = grad((self.v * y).sum(), self.mx)
scan_f = function([self.mx, self.v], sy) scan_f = function([self.mx, self.v], sy)
v1 = lop_f(vx, vv) v = lop_f(vx, vv)
v2 = scan_f(vx, vv) v_ref = scan_f(vx, vv)
assert np.allclose(v1, v2), f"LOP mismatch: {v1} {v2}" np.testing.assert_allclose(v, v_ref)
def check_rop_lop(self, y, out_shape): def check_rop_lop(self, y, out_shape, check_nondiff_rop: bool = True):
""" """
As check_mat_rop_lop, except the input is self.x which is a As check_mat_rop_lop, except the input is self.x which is a
vector. The output is still a vector. vector. The output is still a vector.
""" """
rtol = self.rtol()
# TEST ROP # TEST ROP
vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX) vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
vv = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX) vv = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
...@@ -152,24 +164,17 @@ class RopLopChecker: ...@@ -152,24 +164,17 @@ class RopLopChecker:
non_sequences=[y, self.x], non_sequences=[y, self.x],
) )
sy = dot(J, self.v) sy = dot(J, self.v)
scan_f = function([self.x, self.v], sy, on_unused_input="ignore") scan_f = function([self.x, self.v], sy, on_unused_input="ignore")
v1 = rop_f(vx, vv) v_ref = scan_f(vx, vv)
v2 = scan_f(vx, vv) np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol)
assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}"
try: if check_nondiff_rop:
Rop( self.check_nondiff_rop(
pytensor.clone_replace(y, replace={self.x: break_op(self.x)}), pytensor.clone_replace(y, replace={self.x: break_op(self.x)}),
self.x, self.x,
self.v, self.v,
) )
except ValueError:
pytest.skip(
"Rop does not handle non-differentiable inputs "
"correctly. Bug exposed by fixing Add.grad method."
)
vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX) vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
vv = np.asarray(self.rng.uniform(size=out_shape), pytensor.config.floatX) vv = np.asarray(self.rng.uniform(size=out_shape), pytensor.config.floatX)
...@@ -182,22 +187,20 @@ class RopLopChecker: ...@@ -182,22 +187,20 @@ class RopLopChecker:
non_sequences=[y, self.x], non_sequences=[y, self.x],
) )
sy = dot(self.v, J) sy = dot(self.v, J)
scan_f = function([self.x, self.v], sy) scan_f = function([self.x, self.v], sy)
v1 = lop_f(vx, vv) v = lop_f(vx, vv)
v2 = scan_f(vx, vv) v_ref = scan_f(vx, vv)
assert np.allclose(v1, v2), f"LOP mismatch: {v1} {v2}" np.testing.assert_allclose(v, v_ref, rtol=rtol)
class TestRopLop(RopLopChecker): class TestRopLop(RopLopChecker):
def test_max(self): def test_max(self):
# self.check_mat_rop_lop(pt_max(self.mx, axis=[0,1])[0], ())
self.check_mat_rop_lop(pt_max(self.mx, axis=0), (self.mat_in_shape[1],)) self.check_mat_rop_lop(pt_max(self.mx, axis=0), (self.mat_in_shape[1],))
self.check_mat_rop_lop(pt_max(self.mx, axis=1), (self.mat_in_shape[0],)) self.check_mat_rop_lop(pt_max(self.mx, axis=1), (self.mat_in_shape[0],))
def test_argmax(self): def test_argmax(self):
self.check_nondiff_rop(argmax(self.mx, axis=1)) self.check_nondiff_rop(argmax(self.mx, axis=1), self.mx, self.mv)
def test_subtensor(self): def test_subtensor(self):
self.check_rop_lop(self.x[:4], (4,)) self.check_rop_lop(self.x[:4], (4,))
...@@ -252,10 +255,14 @@ class TestRopLop(RopLopChecker): ...@@ -252,10 +255,14 @@ class TestRopLop(RopLopChecker):
insh = self.in_shape[0] insh = self.in_shape[0]
vW = np.asarray(self.rng.uniform(size=(insh, insh)), pytensor.config.floatX) vW = np.asarray(self.rng.uniform(size=(insh, insh)), pytensor.config.floatX)
W = pytensor.shared(vW) W = pytensor.shared(vW)
self.check_rop_lop(dot(self.x, W), self.in_shape) # check_nondiff_rop reveals an error in how Rop handles non-differentiable paths
# See: test_Rop_partially_differentiable_paths
self.check_rop_lop(dot(self.x, W), self.in_shape, check_nondiff_rop=False)
def test_elemwise0(self): def test_elemwise0(self):
self.check_rop_lop((self.x + 1) ** 2, self.in_shape) # check_nondiff_rop reveals an error in how Rop handles non-differentiable paths
# See: test_Rop_partially_differentiable_paths
self.check_rop_lop((self.x + 1) ** 2, self.in_shape, check_nondiff_rop=False)
def test_elemwise1(self): def test_elemwise1(self):
self.check_rop_lop(self.x + pt.cast(self.x, "int32"), self.in_shape) self.check_rop_lop(self.x + pt.cast(self.x, "int32"), self.in_shape)
...@@ -288,15 +295,8 @@ class TestRopLop(RopLopChecker): ...@@ -288,15 +295,8 @@ class TestRopLop(RopLopChecker):
) )
def test_invalid_input(self): def test_invalid_input(self):
success = False with pytest.raises(ValueError):
try:
Rop(0.0, [matrix()], [vector()]) Rop(0.0, [matrix()], [vector()])
success = True
except ValueError:
pass
assert not success
def test_multiple_outputs(self): def test_multiple_outputs(self):
m = matrix("m") m = matrix("m")
...@@ -322,12 +322,54 @@ class TestRopLop(RopLopChecker): ...@@ -322,12 +322,54 @@ class TestRopLop(RopLopChecker):
f = pytensor.function([m, v, m_, v_], all_outs) f = pytensor.function([m, v, m_, v_], all_outs)
f(mval, vval, m_val, v_val) f(mval, vval, m_val, v_val)
def test_Rop_dot_bug_18Oct2013_Jeremiah(self): @pytest.mark.xfail()
def test_Rop_partially_differentiable_paths(self):
# This test refers to a bug reported by Jeremiah Lowin on 18th Oct # This test refers to a bug reported by Jeremiah Lowin on 18th Oct
# 2013. The bug consists when through a dot operation there is only # 2013. The bug consists when through a dot operation there is only
# one differentiable path (i.e. there is no gradient wrt to one of # one differentiable path (i.e. there is no gradient wrt to one of
# the inputs). # the inputs).
x = pt.arange(20.0).reshape([1, 20]) x = pt.arange(20.0).reshape([1, 20])
v = pytensor.shared(np.ones([20])) v = pytensor.shared(np.ones([20]), name="v")
d = dot(x, v).sum() d = dot(x, v).sum()
Rop(grad(d, v), v, v)
Rop(
grad(d, v),
v,
v,
disconnected_outputs="raise",
)
# 2025: Here is an unambiguous test for the original commented issue:
x = pt.matrix("x")
y = pt.matrix("y")
out = dot(x, break_op(y)).sum()
# Should not raise an error
Rop(
out,
[x],
[x.type()],
disconnected_outputs="raise",
)
# More extensive testing shows that the Rop implementation FAILS to raise when
# the cost is linked through strictly non-differentiable paths.
# This is not Dot specific, we would observe the same with any operation where the gradient
# with respect to one of the inputs does not depend on the original input (such as `mul`, `add`, ...)
out = dot(break_op(x), y).sum()
with pytest.raises((ValueError, NullTypeGradError)):
Rop(
out,
[x],
[x.type()],
disconnected_outputs="raise",
)
# Only when both paths are non-differentiable is an error correctly raised again.
out = dot(break_op(x), break_op(y)).sum()
with pytest.raises((ValueError, NullTypeGradError)):
Rop(
out,
[x],
[x.type()],
disconnected_outputs="raise",
)
Markdown is supported
0%
You are about to add 0 people to the discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment