提交 ef5bcb50 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Test SVD and Eig(h): allow benign sign change

上级 ebe518fa
......@@ -165,10 +165,33 @@ class TestSvd(utt.InferShapeTester):
pt_outputs = fn(a)
np_outputs = np_outputs if isinstance(np_outputs, tuple) else [np_outputs]
if compute_uv:
# In this case we sometimes get a sign flip on some columns in one impl and not the thore
# The results are both correct, and we test that by reconstructing the original input
U, S, Vh = pt_outputs
S_diag = np.expand_dims(S, -2) * np.eye(S.shape[-1])
diff = a.shape[-2] - a.shape[-1]
if full_matrix:
if diff > 0:
# tall
S_diag = np.pad(S_diag, [(0, 0), (0, diff), (0, 0)][-a.ndim :])
elif diff < 0:
# wide
S_diag = np.pad(S_diag, [(0, 0), (0, 0), (0, -diff)][-a.ndim :])
a_r = U @ S_diag @ Vh
rtol = 1e-3 if config.floatX == "float32" else 1e-7
np.testing.assert_allclose(a_r, a, rtol=rtol)
for np_val, pt_val in zip(np_outputs, pt_outputs, strict=True):
# Check values are equivalent up to sign change
np.testing.assert_allclose(np.abs(np_val), np.abs(pt_val), rtol=rtol)
rtol = 1e-5 if config.floatX == "float32" else 1e-7
for np_val, pt_val in zip(np_outputs, pt_outputs, strict=True):
np.testing.assert_allclose(np_val, pt_val, rtol=rtol)
else:
rtol = 1e-5 if config.floatX == "float32" else 1e-7
for np_val, pt_val in zip(np_outputs, pt_outputs, strict=True):
np.testing.assert_allclose(np_val, pt_val, rtol=rtol)
def test_svd_infer_shape(self):
self.validate_shape((4, 4), full_matrices=True, compute_uv=True)
......@@ -428,8 +451,8 @@ class TestEig(utt.InferShapeTester):
w, v = fn(A_val)
w_np, v_np = np.linalg.eig(A_val)
np.testing.assert_allclose(w, w_np)
np.testing.assert_allclose(v, v_np)
np.testing.assert_allclose(np.abs(w), np.abs(w_np))
np.testing.assert_allclose(np.abs(v), np.abs(v_np))
assert_array_almost_equal(np.dot(A_val, v), w * v)
# Asymmetric input (real eigenvalues)
......@@ -438,16 +461,16 @@ class TestEig(utt.InferShapeTester):
w, v = fn(A_val)
w_np, v_np = np.linalg.eig(A_val)
np.testing.assert_allclose(w, w_np)
np.testing.assert_allclose(v, v_np)
np.testing.assert_allclose(np.abs(w), np.abs(w_np))
np.testing.assert_allclose(np.abs(v), np.abs(v_np))
assert_array_almost_equal(np.dot(A_val, v), w * v)
# Asymmetric input (complex eigenvalues)
A_val = self.rng.normal(size=(5, 5))
w, v = fn(A_val)
w_np, v_np = np.linalg.eig(A_val)
np.testing.assert_allclose(w, w_np)
np.testing.assert_allclose(v, v_np)
np.testing.assert_allclose(np.abs(w), np.abs(w_np))
np.testing.assert_allclose(np.abs(v), np.abs(v_np))
assert_array_almost_equal(np.dot(A_val, v), w * v)
......@@ -464,11 +487,14 @@ class TestEigh(TestEig):
w, v = fn(A_val)
w_np, v_np = np.linalg.eigh(A_val)
np.testing.assert_allclose(w, w_np)
np.testing.assert_allclose(v, v_np)
# There are multiple valid results up to some sign changes
# Check we can reconstruct input
rtol = 1e-2 if self.dtype == "float32" else 1e-7
np.testing.assert_allclose(np.dot(A_val, v), w * v, rtol=rtol)
np.testing.assert_allclose(np.abs(w), np.abs(w_np), rtol=rtol)
np.testing.assert_allclose(np.abs(v), np.abs(v_np), rtol=rtol)
def test_uplo(self):
S = self.S
a = matrix(dtype=self.dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论