Unverified 提交 bf628c97 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Allow `transposed` argument in `linalg.solve` (#1231)

* Add transposed argument to `solve` and `solve_triangular` * Expand test coverage for `Solve` and `SolveTriangular`
上级 757a10cd
...@@ -53,7 +53,6 @@ def jax_funcify_Solve(op, **kwargs): ...@@ -53,7 +53,6 @@ def jax_funcify_Solve(op, **kwargs):
@jax_funcify.register(SolveTriangular) @jax_funcify.register(SolveTriangular)
def jax_funcify_SolveTriangular(op, **kwargs): def jax_funcify_SolveTriangular(op, **kwargs):
lower = op.lower lower = op.lower
trans = op.trans
unit_diagonal = op.unit_diagonal unit_diagonal = op.unit_diagonal
check_finite = op.check_finite check_finite = op.check_finite
...@@ -62,7 +61,7 @@ def jax_funcify_SolveTriangular(op, **kwargs): ...@@ -62,7 +61,7 @@ def jax_funcify_SolveTriangular(op, **kwargs):
A, A,
b, b,
lower=lower, lower=lower,
trans=trans, trans=0, # this is handled by explicitly transposing A, so it will always be 0 when we get to here.
unit_diagonal=unit_diagonal, unit_diagonal=unit_diagonal,
check_finite=check_finite, check_finite=check_finite,
) )
......
...@@ -180,7 +180,6 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b ...@@ -180,7 +180,6 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
@numba_funcify.register(SolveTriangular) @numba_funcify.register(SolveTriangular)
def numba_funcify_SolveTriangular(op, node, **kwargs): def numba_funcify_SolveTriangular(op, node, **kwargs):
trans = bool(op.trans)
lower = op.lower lower = op.lower
unit_diagonal = op.unit_diagonal unit_diagonal = op.unit_diagonal
check_finite = op.check_finite check_finite = op.check_finite
...@@ -208,7 +207,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): ...@@ -208,7 +207,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
res = _solve_triangular( res = _solve_triangular(
a, a,
b, b,
trans=trans, trans=0, # transposing is handled explicitly on the graph, so we never use this argument
lower=lower, lower=lower,
unit_diagonal=unit_diagonal, unit_diagonal=unit_diagonal,
overwrite_b=overwrite_b, overwrite_b=overwrite_b,
......
...@@ -296,13 +296,12 @@ class SolveBase(Op): ...@@ -296,13 +296,12 @@ class SolveBase(Op):
# We need to return (dC/d[inv(A)], dC/db) # We need to return (dC/d[inv(A)], dC/db)
c_bar = output_gradients[0] c_bar = output_gradients[0]
trans_solve_op = type(self)( props_dict = self._props_dict()
**{ props_dict["lower"] = not self.lower
k: (not getattr(self, k) if k == "lower" else getattr(self, k))
for k in self.__props__ solve_op = type(self)(**props_dict)
}
) b_bar = solve_op(A.T, c_bar)
b_bar = trans_solve_op(A.T, c_bar)
# force outer product if vector second input # force outer product if vector second input
A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T) A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
...@@ -385,7 +384,6 @@ class SolveTriangular(SolveBase): ...@@ -385,7 +384,6 @@ class SolveTriangular(SolveBase):
"""Solve a system of linear equations.""" """Solve a system of linear equations."""
__props__ = ( __props__ = (
"trans",
"unit_diagonal", "unit_diagonal",
"lower", "lower",
"check_finite", "check_finite",
...@@ -393,11 +391,10 @@ class SolveTriangular(SolveBase): ...@@ -393,11 +391,10 @@ class SolveTriangular(SolveBase):
"overwrite_b", "overwrite_b",
) )
def __init__(self, *, trans=0, unit_diagonal=False, **kwargs): def __init__(self, *, unit_diagonal=False, **kwargs):
if kwargs.get("overwrite_a", False): if kwargs.get("overwrite_a", False):
raise ValueError("overwrite_a is not supported for SolverTriangulare") raise ValueError("overwrite_a is not supported for SolverTriangulare")
super().__init__(**kwargs) super().__init__(**kwargs)
self.trans = trans
self.unit_diagonal = unit_diagonal self.unit_diagonal = unit_diagonal
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -406,7 +403,7 @@ class SolveTriangular(SolveBase): ...@@ -406,7 +403,7 @@ class SolveTriangular(SolveBase):
A, A,
b, b,
lower=self.lower, lower=self.lower,
trans=self.trans, trans=0,
unit_diagonal=self.unit_diagonal, unit_diagonal=self.unit_diagonal,
check_finite=self.check_finite, check_finite=self.check_finite,
overwrite_b=self.overwrite_b, overwrite_b=self.overwrite_b,
...@@ -445,9 +442,9 @@ def solve_triangular( ...@@ -445,9 +442,9 @@ def solve_triangular(
Parameters Parameters
---------- ----------
a a: TensorVariable
Square input data Square input data
b b: TensorVariable
Input data for the right hand side. Input data for the right hand side.
lower : bool, optional lower : bool, optional
Use only data contained in the lower triangle of `a`. Default is to use upper triangle. Use only data contained in the lower triangle of `a`. Default is to use upper triangle.
...@@ -468,10 +465,17 @@ def solve_triangular( ...@@ -468,10 +465,17 @@ def solve_triangular(
This will influence how batched dimensions are interpreted. This will influence how batched dimensions are interpreted.
""" """
b_ndim = _default_b_ndim(b, b_ndim) b_ndim = _default_b_ndim(b, b_ndim)
if trans in [1, "T", True]:
a = a.mT
lower = not lower
if trans in [2, "C"]:
a = a.conj().mT
lower = not lower
ret = Blockwise( ret = Blockwise(
SolveTriangular( SolveTriangular(
lower=lower, lower=lower,
trans=trans,
unit_diagonal=unit_diagonal, unit_diagonal=unit_diagonal,
check_finite=check_finite, check_finite=check_finite,
b_ndim=b_ndim, b_ndim=b_ndim,
...@@ -534,6 +538,7 @@ def solve( ...@@ -534,6 +538,7 @@ def solve(
*, *,
assume_a="gen", assume_a="gen",
lower=False, lower=False,
transposed=False,
check_finite=True, check_finite=True,
b_ndim: int | None = None, b_ndim: int | None = None,
): ):
...@@ -564,8 +569,10 @@ def solve( ...@@ -564,8 +569,10 @@ def solve(
b : (..., N, NRHS) array_like b : (..., N, NRHS) array_like
Input data for the right hand side. Input data for the right hand side.
lower : bool, optional lower : bool, optional
If True, only the data contained in the lower triangle of `a`. Default If True, use only the data contained in the lower triangle of `a`. Default
is to use upper triangle. (ignored for ``'gen'``) is to use upper triangle. (ignored for ``'gen'``)
transposed: bool, optional
If True, solves the system A^T x = b. Default is False.
check_finite : bool, optional check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers. Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems Disabling may give a performance gain, but may result in problems
...@@ -577,6 +584,11 @@ def solve( ...@@ -577,6 +584,11 @@ def solve(
This will influence how batched dimensions are interpreted. This will influence how batched dimensions are interpreted.
""" """
b_ndim = _default_b_ndim(b, b_ndim) b_ndim = _default_b_ndim(b, b_ndim)
if transposed:
a = a.mT
lower = not lower
return Blockwise( return Blockwise(
Solve( Solve(
lower=lower, lower=lower,
......
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
import pytest import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
import tests.unittest_tools as utt
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.tensor import nlinalg as pt_nlinalg from pytensor.tensor import nlinalg as pt_nlinalg
from pytensor.tensor import slinalg as pt_slinalg from pytensor.tensor import slinalg as pt_slinalg
...@@ -103,28 +104,41 @@ def test_jax_basic(): ...@@ -103,28 +104,41 @@ def test_jax_basic():
) )
@pytest.mark.parametrize("check_finite", [False, True]) def test_jax_solve():
@pytest.mark.parametrize("lower", [False, True]) rng = np.random.default_rng(utt.fetch_seed())
@pytest.mark.parametrize("trans", [0, 1, 2])
def test_jax_SolveTriangular(trans, lower, check_finite): A = pt.tensor("A", shape=(5, 5))
x = matrix("x") b = pt.tensor("B", shape=(5, 5))
b = vector("b")
out = pt_slinalg.solve(A, b, lower=False, transposed=False)
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=(5, 5)).astype(config.floatX)
out = pt_slinalg.solve_triangular(
x,
b,
trans=trans,
lower=lower,
check_finite=check_finite,
)
compare_jax_and_py( compare_jax_and_py(
[x, b], [A, b],
[out], [out],
[ [A_val, b_val],
np.eye(10).astype(config.floatX), )
np.arange(10).astype(config.floatX),
],
def test_jax_SolveTriangular():
rng = np.random.default_rng(utt.fetch_seed())
A = pt.tensor("A", shape=(5, 5))
b = pt.tensor("B", shape=(5, 5))
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=(5, 5)).astype(config.floatX)
out = pt_slinalg.solve_triangular(
A,
b,
trans=0,
lower=True,
unit_diagonal=False,
) )
compare_jax_and_py([A, b], [out], [A_val, b_val])
def test_jax_block_diag(): def test_jax_block_diag():
......
...@@ -5,7 +5,6 @@ from typing import Literal ...@@ -5,7 +5,6 @@ from typing import Literal
import numpy as np import numpy as np
import pytest import pytest
from numpy.testing import assert_allclose from numpy.testing import assert_allclose
from scipy import linalg as scipy_linalg
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
...@@ -26,9 +25,9 @@ def transpose_func(x, trans): ...@@ -26,9 +25,9 @@ def transpose_func(x, trans):
if trans == 0: if trans == 0:
return x return x
if trans == 1: if trans == 1:
return x.conj().T
if trans == 2:
return x.T return x.T
if trans == 2:
return x.conj().T
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -59,18 +58,18 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl ...@@ -59,18 +58,18 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
def A_func(x): def A_func(x):
x = x @ x.conj().T x = x @ x.conj().T
x_tri = scipy_linalg.cholesky(x, lower=lower).astype(dtype) x_tri = pt.linalg.cholesky(x, lower=lower).astype(dtype)
if unit_diag: if unit_diag:
x_tri[np.diag_indices_from(x_tri)] = 1.0 x_tri = pt.fill_diagonal(x_tri, 1.0)
return x_tri.astype(dtype) return x_tri
solve_op = partial( solve_op = partial(
pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag pt.linalg.solve_triangular, lower=lower, trans=trans, unit_diagonal=unit_diag
) )
X = solve_op(A, b) X = solve_op(A_func(A), b)
f = pytensor.function([A, b], X, mode="NUMBA") f = pytensor.function([A, b], X, mode="NUMBA")
A_val = np.random.normal(size=(5, 5)) A_val = np.random.normal(size=(5, 5))
...@@ -80,20 +79,20 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl ...@@ -80,20 +79,20 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
A_val = A_val + np.random.normal(size=(5, 5)) * 1j A_val = A_val + np.random.normal(size=(5, 5)) * 1j
b_val = b_val + np.random.normal(size=b_shape) * 1j b_val = b_val + np.random.normal(size=b_shape) * 1j
X_np = f(A_func(A_val), b_val) X_np = f(A_val.copy(), b_val.copy())
A_val_transformed = transpose_func(A_func(A_val), trans).eval()
test_input = transpose_func(A_func(A_val), trans) np.testing.assert_allclose(
A_val_transformed @ X_np,
ATOL = 1e-8 if floatX.endswith("64") else 1e-4 b_val,
RTOL = 1e-8 if floatX.endswith("64") else 1e-4 atol=1e-8 if floatX.endswith("64") else 1e-4,
rtol=1e-8 if floatX.endswith("64") else 1e-4,
np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL) )
compiled_fgraph = f.maker.fgraph compiled_fgraph = f.maker.fgraph
compare_numba_and_py( compare_numba_and_py(
compiled_fgraph.inputs, compiled_fgraph.inputs,
compiled_fgraph.outputs, compiled_fgraph.outputs,
[A_func(A_val), b_val], [A_val, b_val],
) )
...@@ -145,7 +144,6 @@ def test_solve_triangular_overwrite_b_correct(overwrite_b): ...@@ -145,7 +144,6 @@ def test_solve_triangular_overwrite_b_correct(overwrite_b):
b_test_nb = b_test_py.copy(order="F") b_test_nb = b_test_py.copy(order="F")
op = SolveTriangular( op = SolveTriangular(
trans=0,
unit_diagonal=False, unit_diagonal=False,
lower=False, lower=False,
check_finite=True, check_finite=True,
......
...@@ -214,7 +214,38 @@ def test_solve_raises_on_invalid_A(): ...@@ -214,7 +214,38 @@ def test_solve_raises_on_invalid_A():
Solve(assume_a="test", b_ndim=2) Solve(assume_a="test", b_ndim=2)
solve_test_cases = [
("gen", False, False),
("gen", False, True),
("sym", False, False),
("sym", True, False),
("sym", True, True),
("pos", False, False),
("pos", True, False),
("pos", True, True),
]
solve_test_ids = [
f'{assume_a}_{"lower" if lower else "upper"}_{"A^T" if transposed else "A"}'
for assume_a, lower, transposed in solve_test_cases
]
class TestSolve(utt.InferShapeTester): class TestSolve(utt.InferShapeTester):
@staticmethod
def A_func(x, assume_a):
if assume_a == "pos":
return x @ x.T
elif assume_a == "sym":
return (x + x.T) / 2
else:
return x
@staticmethod
def T(x, transposed):
if transposed:
return x.T
return x
@pytest.mark.parametrize("b_shape", [(5, 1), (5,)]) @pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
def test_infer_shape(self, b_shape): def test_infer_shape(self, b_shape):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
...@@ -235,8 +266,12 @@ class TestSolve(utt.InferShapeTester): ...@@ -235,8 +266,12 @@ class TestSolve(utt.InferShapeTester):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"] "b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"]
) )
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) @pytest.mark.parametrize(
def test_solve_correctness(self, b_size: tuple[int], assume_a: str): "assume_a, lower, transposed", solve_test_cases, ids=solve_test_ids
)
def test_solve_correctness(
self, b_size: tuple[int], assume_a: str, lower: bool, transposed: bool
):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
A = pt.tensor("A", shape=(5, 5)) A = pt.tensor("A", shape=(5, 5))
b = pt.tensor("b", shape=b_size) b = pt.tensor("b", shape=b_size)
...@@ -244,19 +279,18 @@ class TestSolve(utt.InferShapeTester): ...@@ -244,19 +279,18 @@ class TestSolve(utt.InferShapeTester):
A_val = rng.normal(size=(5, 5)).astype(config.floatX) A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=b_size).astype(config.floatX) b_val = rng.normal(size=b_size).astype(config.floatX)
solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size)) A_func = functools.partial(self.A_func, assume_a=assume_a)
T = functools.partial(self.T, transposed=transposed)
def A_func(x): y = solve(
if assume_a == "pos": A_func(A),
return x @ x.T b,
elif assume_a == "sym": assume_a=assume_a,
return (x + x.T) / 2 lower=lower,
else: transposed=transposed,
return x b_ndim=len(b_size),
)
solve_input_val = A_func(A_val)
y = solve_op(A_func(A), b)
solve_func = pytensor.function([A, b], y) solve_func = pytensor.function([A, b], y)
X_np = solve_func(A_val.copy(), b_val.copy()) X_np = solve_func(A_val.copy(), b_val.copy())
...@@ -264,22 +298,34 @@ class TestSolve(utt.InferShapeTester): ...@@ -264,22 +298,34 @@ class TestSolve(utt.InferShapeTester):
RTOL = 1e-8 if config.floatX.endswith("64") else 1e-4 RTOL = 1e-8 if config.floatX.endswith("64") else 1e-4
np.testing.assert_allclose( np.testing.assert_allclose(
scipy.linalg.solve(solve_input_val, b_val, assume_a=assume_a), scipy.linalg.solve(
A_func(A_val),
b_val,
assume_a=assume_a,
transposed=transposed,
lower=lower,
),
X_np, X_np,
atol=ATOL, atol=ATOL,
rtol=RTOL, rtol=RTOL,
) )
np.testing.assert_allclose(A_func(A_val) @ X_np, b_val, atol=ATOL, rtol=RTOL) np.testing.assert_allclose(T(A_func(A_val)) @ X_np, b_val, atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"] "b_size", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"]
) )
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) @pytest.mark.parametrize(
"assume_a, lower, transposed",
solve_test_cases,
ids=solve_test_ids,
)
@pytest.mark.skipif( @pytest.mark.skipif(
config.floatX == "float32", reason="Gradients not numerically stable in float32" config.floatX == "float32", reason="Gradients not numerically stable in float32"
) )
def test_solve_gradient(self, b_size: tuple[int], assume_a: str): def test_solve_gradient(
self, b_size: tuple[int], assume_a: str, lower: bool, transposed: bool
):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
eps = 2e-8 if config.floatX == "float64" else None eps = 2e-8 if config.floatX == "float64" else None
...@@ -287,15 +333,8 @@ class TestSolve(utt.InferShapeTester): ...@@ -287,15 +333,8 @@ class TestSolve(utt.InferShapeTester):
A_val = rng.normal(size=(5, 5)).astype(config.floatX) A_val = rng.normal(size=(5, 5)).astype(config.floatX)
b_val = rng.normal(size=b_size).astype(config.floatX) b_val = rng.normal(size=b_size).astype(config.floatX)
def A_func(x):
if assume_a == "pos":
return x @ x.T
elif assume_a == "sym":
return (x + x.T) / 2
else:
return x
solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size)) solve_op = functools.partial(solve, assume_a=assume_a, b_ndim=len(b_size))
A_func = functools.partial(self.A_func, assume_a=assume_a)
# To correctly check the gradients, we need to include a transformation from the space of unconstrained matrices # To correctly check the gradients, we need to include a transformation from the space of unconstrained matrices
# (A) to a valid input matrix for the given solver. This is done by the A_func function. If this isn't included, # (A) to a valid input matrix for the given solver. This is done by the A_func function. If this isn't included,
...@@ -307,11 +346,27 @@ class TestSolve(utt.InferShapeTester): ...@@ -307,11 +346,27 @@ class TestSolve(utt.InferShapeTester):
class TestSolveTriangular(utt.InferShapeTester): class TestSolveTriangular(utt.InferShapeTester):
@staticmethod
def A_func(x, lower, unit_diagonal):
x = x @ x.T
x = pt.linalg.cholesky(x, lower=lower)
if unit_diagonal:
x = pt.fill_diagonal(x, 1)
return x
@staticmethod
def T(x, trans):
if trans == 1:
return x.T
elif trans == 2:
return x.conj().T
return x
@pytest.mark.parametrize("b_shape", [(5, 1), (5,)]) @pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
def test_infer_shape(self, b_shape): def test_infer_shape(self, b_shape):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
A = matrix() A = matrix()
b_val = np.asarray(rng.random(b_shape), dtype=config.floatX) b_val = rng.random(b_shape).astype(config.floatX)
b = pt.as_tensor_variable(b_val).type() b = pt.as_tensor_variable(b_val).type()
self._compile_and_check( self._compile_and_check(
[A, b], [A, b],
...@@ -324,56 +379,78 @@ class TestSolveTriangular(utt.InferShapeTester): ...@@ -324,56 +379,78 @@ class TestSolveTriangular(utt.InferShapeTester):
warn=False, warn=False,
) )
@pytest.mark.parametrize(
"b_shape", [(5, 1), (5,), (5, 5)], ids=["b_col_vec", "b_vec", "b_matrix"]
)
@pytest.mark.parametrize("lower", [True, False]) @pytest.mark.parametrize("lower", [True, False])
def test_correctness(self, lower): @pytest.mark.parametrize("trans", [0, 1, 2])
@pytest.mark.parametrize("unit_diagonal", [True, False])
def test_correctness(self, b_shape: tuple[int], lower, trans, unit_diagonal):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
A = pt.tensor("A", shape=(5, 5))
b = pt.tensor("b", shape=b_shape)
b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX) A_val = rng.random((5, 5)).astype(config.floatX)
b_val = rng.random(b_shape).astype(config.floatX)
A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX) A_func = functools.partial(
A_val = np.dot(A_val.transpose(), A_val) self.A_func, lower=lower, unit_diagonal=unit_diagonal
)
C_val = scipy.linalg.cholesky(A_val, lower=lower) x = solve_triangular(
A_func(A),
b,
lower=lower,
trans=trans,
unit_diagonal=unit_diagonal,
b_ndim=len(b_shape),
)
A = matrix() f = pytensor.function([A, b], x)
b = matrix()
cholesky = Cholesky(lower=lower) x_pt = f(A_val, b_val)
C = cholesky(A) x_sp = scipy.linalg.solve_triangular(
y_lower = solve_triangular(C, b, lower=lower) A_func(A_val).eval(),
lower_solve_func = pytensor.function([C, b], y_lower) b_val,
lower=lower,
trans=trans,
unit_diagonal=unit_diagonal,
)
assert np.allclose( np.testing.assert_allclose(
scipy.linalg.solve_triangular(C_val, b_val, lower=lower), x_pt,
lower_solve_func(C_val, b_val), x_sp,
atol=1e-8 if config.floatX == "float64" else 1e-4,
rtol=1e-8 if config.floatX == "float64" else 1e-4,
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"m, n, lower", "b_shape", [(5, 1), (5,), (5, 5)], ids=["b_col_vec", "b_vec", "b_matrix"]
[
(5, None, False),
(5, None, True),
(4, 2, False),
(4, 2, True),
],
) )
def test_solve_grad(self, m, n, lower): @pytest.mark.parametrize("lower", [True, False])
rng = np.random.default_rng(utt.fetch_seed()) @pytest.mark.parametrize("trans", [0, 1])
@pytest.mark.parametrize("unit_diagonal", [True, False])
def test_solve_triangular_grad(self, b_shape, lower, trans, unit_diagonal):
if config.floatX == "float32":
pytest.skip(reason="Not enough precision in float32 to get a good gradient")
# Ensure diagonal elements of `A` are relatively large to avoid rng = np.random.default_rng(utt.fetch_seed())
# numerical precision issues A_val = rng.normal(size=(5, 5)).astype(config.floatX)
A_val = (rng.normal(size=(m, m)) * 0.5 + np.eye(m)).astype(config.floatX) b_val = rng.normal(size=b_shape).astype(config.floatX)
if n is None: A_func = functools.partial(
b_val = rng.normal(size=m).astype(config.floatX) self.A_func, lower=lower, unit_diagonal=unit_diagonal
else: )
b_val = rng.normal(size=(m, n)).astype(config.floatX)
eps = None eps = None
if config.floatX == "float64": if config.floatX == "float64":
eps = 2e-8 eps = 2e-8
solve_op = SolveTriangular(lower=lower, b_ndim=1 if n is None else 2) def solve_op(A, b):
return solve_triangular(
A_func(A), b, lower=lower, trans=trans, unit_diagonal=unit_diagonal
)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps) utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论