Unverified 提交 792bd049 authored 作者: Jesse Grabowski 提交者: GitHub

Refactor `Expm` and `Eig`, add jax dispatch for `expm` (#1668)

* `linalg.eig` always returns complex dtype
* Update Eig dispatch for Numba, JAX, and PyTorch backends
* Clean up `pytensor.linalg.expm` and related tests
* Add JAX dispatch for expm
* Implement L_op instead of grad in `Eigh`

---------

Co-authored-by: Jesse Grabowski <jesse.grabowski@readyx.com>
上级 27c21cd6
...@@ -10,6 +10,7 @@ from pytensor.tensor.slinalg import ( ...@@ -10,6 +10,7 @@ from pytensor.tensor.slinalg import (
Cholesky, Cholesky,
CholeskySolve, CholeskySolve,
Eigvalsh, Eigvalsh,
Expm,
LUFactor, LUFactor,
PivotToPermutations, PivotToPermutations,
Solve, Solve,
...@@ -179,3 +180,11 @@ def jax_funcify_QR(op, **kwargs): ...@@ -179,3 +180,11 @@ def jax_funcify_QR(op, **kwargs):
return jax.scipy.linalg.qr(x, mode=mode) return jax.scipy.linalg.qr(x, mode=mode)
return qr return qr
@jax_funcify.register(Expm)
def jax_funcify_Expm(op, **kwargs):
def expm(x):
return jax.scipy.linalg.expm(x)
return expm
...@@ -76,15 +76,12 @@ def numba_funcify_SLogDet(op, node, **kwargs): ...@@ -76,15 +76,12 @@ def numba_funcify_SLogDet(op, node, **kwargs):
@numba_funcify.register(Eig) @numba_funcify.register(Eig)
def numba_funcify_Eig(op, node, **kwargs): def numba_funcify_Eig(op, node, **kwargs):
out_dtype_1 = node.outputs[0].type.numpy_dtype w_dtype = node.outputs[0].type.numpy_dtype
out_dtype_2 = node.outputs[1].type.numpy_dtype inputs_cast = int_to_float_fn(node.inputs, w_dtype)
inputs_cast = int_to_float_fn(node.inputs, out_dtype_1)
@numba_basic.numba_njit @numba_basic.numba_njit
def eig(x): def eig(x):
out = np.linalg.eig(inputs_cast(x)) return np.linalg.eig(inputs_cast(x))
return (out[0].astype(out_dtype_1), out[1].astype(out_dtype_2))
return eig return eig
......
...@@ -16,7 +16,15 @@ from pytensor.tensor import basic as ptb ...@@ -16,7 +16,15 @@ from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm from pytensor.tensor import math as ptm
from pytensor.tensor.basic import as_tensor_variable, diagonal from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector from pytensor.tensor.type import (
Variable,
dvector,
lscalar,
matrix,
scalar,
tensor,
vector,
)
class MatrixPinv(Op): class MatrixPinv(Op):
...@@ -297,37 +305,78 @@ def slogdet(x: TensorLike) -> tuple[ptb.TensorVariable, ptb.TensorVariable]: ...@@ -297,37 +305,78 @@ def slogdet(x: TensorLike) -> tuple[ptb.TensorVariable, ptb.TensorVariable]:
class Eig(Op): class Eig(Op):
""" """
Compute the eigenvalues and right eigenvectors of a square array. Compute the eigenvalues and right eigenvectors of a square array.
""" """
__props__: tuple[str, ...] = () __props__: tuple[str, ...] = ()
gufunc_signature = "(m,m)->(m),(m,m)"
gufunc_spec = ("numpy.linalg.eig", 1, 2) gufunc_spec = ("numpy.linalg.eig", 1, 2)
gufunc_signature = "(m,m)->(m),(m,m)"
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
assert x.ndim == 2 assert x.ndim == 2
w = vector(dtype=x.dtype)
v = matrix(dtype=x.dtype) M, N = x.type.shape
if M is not None and N is not None and M != N:
raise ValueError(
f"Input to Eig must be a square matrix, got static shape: ({M}, {N})"
)
dtype = np.promote_types(x.dtype, np.complex64)
w = tensor(dtype=dtype, shape=(M,))
v = tensor(dtype=dtype, shape=(M, N))
return Apply(self, [x], [w, v]) return Apply(self, [x], [w, v])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(x,) = inputs (x,) = inputs
(w, v) = outputs dtype = np.promote_types(x.dtype, np.complex64)
w[0], v[0] = (z.astype(x.dtype) for z in np.linalg.eig(x))
w, v = np.linalg.eig(x)
# If the imaginary part of the eigenvalues is zero, numpy automatically casts them to real. We require
# a statically known return dtype, so we have to cast back to complex to avoid dtype mismatch.
outputs[0][0] = w.astype(dtype, copy=False)
outputs[1][0] = v.astype(dtype, copy=False)
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
n = shapes[0][0] (x_shapes,) = shapes
n, _ = x_shapes
return [(n,), (n, n)] return [(n,), (n, n)]
def L_op(self, inputs, outputs, output_grads):
raise NotImplementedError(
"Gradients for Eig is not implemented because it always returns complex values, "
"for which autodiff is not yet supported in PyTensor (PRs welcome :) ).\n"
"If you know that your input has strictly real-valued eigenvalues (e.g. it is a "
"symmetric matrix), use pt.linalg.eigh instead."
)
eig = Blockwise(Eig()) def eig(x: TensorLike):
"""
Return the eigenvalues and right eigenvectors of a square array.
Note that regardless of the input dtype, the eigenvalues and eigenvectors are returned as complex numbers. As a
result, the gradient of this operation is not implemented (because PyTensor does not support autodiff for complex
values yet).
If you know that your input has strictly real-valued eigenvalues (e.g. it is a symmetric matrix), use
`pytensor.tensor.linalg.eigh` instead.
Parameters
----------
x: TensorLike
Square matrix, or array of such matrices
"""
return Blockwise(Eig())(x)
class Eigh(Eig): class Eigh(Eig):
""" """
Return the eigenvalues and eigenvectors of a Hermitian or symmetric matrix. Return the eigenvalues and eigenvectors of a Hermitian or symmetric matrix.
""" """
__props__ = ("UPLO",) __props__ = ("UPLO",)
...@@ -354,7 +403,7 @@ class Eigh(Eig): ...@@ -354,7 +403,7 @@ class Eigh(Eig):
(w, v) = outputs (w, v) = outputs
w[0], v[0] = np.linalg.eigh(x, self.UPLO) w[0], v[0] = np.linalg.eigh(x, self.UPLO)
def grad(self, inputs, g_outputs): def L_op(self, inputs, outputs, output_grads):
r"""The gradient function should return r"""The gradient function should return
.. math:: \sum_n\left(W_n\frac{\partial\,w_n} .. math:: \sum_n\left(W_n\frac{\partial\,w_n}
...@@ -378,10 +427,9 @@ class Eigh(Eig): ...@@ -378,10 +427,9 @@ class Eigh(Eig):
""" """
(x,) = inputs (x,) = inputs
w, v = self(x) w, v = outputs
# Replace gradients wrt disconnected variables with gw, gv = _zero_disconnected([w, v], output_grads)
# zeros. This is a work-around for issue #1063.
gw, gv = _zero_disconnected([w, v], g_outputs)
return [EighGrad(self.UPLO)(x, w, v, gw, gv)] return [EighGrad(self.UPLO)(x, w, v, gw, gv)]
......
...@@ -6,7 +6,6 @@ from typing import Literal, cast ...@@ -6,7 +6,6 @@ from typing import Literal, cast
import numpy as np import numpy as np
import scipy.linalg as scipy_linalg import scipy.linalg as scipy_linalg
from numpy.exceptions import ComplexWarning
from scipy.linalg import LinAlgError, LinAlgWarning, get_lapack_funcs from scipy.linalg import LinAlgError, LinAlgWarning, get_lapack_funcs
import pytensor import pytensor
...@@ -1304,82 +1303,60 @@ def eigvalsh(a, b, lower=True): ...@@ -1304,82 +1303,60 @@ def eigvalsh(a, b, lower=True):
class Expm(Op): class Expm(Op):
""" """
Compute the matrix exponential of a square array. Compute the matrix exponential of a square array.
""" """
__props__ = () __props__ = ()
gufunc_signature = "(m,m)->(m,m)"
def make_node(self, A): def make_node(self, A):
A = as_tensor_variable(A) A = as_tensor_variable(A)
assert A.ndim == 2 assert A.ndim == 2
expm = matrix(dtype=A.dtype)
return Apply( expm = matrix(dtype=A.dtype, shape=A.type.shape)
self,
[ return Apply(self, [A], [expm])
A,
],
[
expm,
],
)
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(A,) = inputs (A,) = inputs
(expm,) = outputs (expm,) = outputs
expm[0] = scipy_linalg.expm(A) expm[0] = scipy_linalg.expm(A)
def grad(self, inputs, outputs): def L_op(self, inputs, outputs, output_grads):
# Kalbfleisch and Lawless, J. Am. Stat. Assoc. 80 (1985) Equation 3.4
# Kind of... You need to do some algebra from there to arrive at
# this expression.
(A,) = inputs (A,) = inputs
(g_out,) = outputs (_,) = outputs # Outputs not used; included for signature consistency only
return [ExpmGrad()(A, g_out)] (A_bar,) = output_grads
def infer_shape(self, fgraph, node, shapes): w, V = pt.linalg.eig(A)
return [shapes[0]]
exp_w = pt.exp(w)
numer = pt.sub.outer(exp_w, exp_w)
denom = pt.sub.outer(w, w)
class ExpmGrad(Op): # When w_i ≈ w_j, we have a removable singularity in the expression for X, because
""" # lim b->a (e^a - e^b) / (a - b) = e^a (derivation left for the motivated reader)
Gradient of the matrix exponential of a square array. X = pt.where(pt.abs(denom) < 1e-8, exp_w, numer / denom)
""" diag_idx = pt.arange(w.shape[0])
X = X[..., diag_idx, diag_idx].set(exp_w)
__props__ = () inner = solve(V, A_bar.T @ V).T
result = solve(V.T, inner * X) @ V.T
def make_node(self, A, gw): # At this point, result is always a complex dtype. If the input was real, the output should be
A = as_tensor_variable(A) # real as well (and all the imaginary parts are numerical noise)
assert A.ndim == 2 if A.dtype not in ("complex64", "complex128"):
out = matrix(dtype=A.dtype) return [result.real]
return Apply(
self, return [result]
[A, gw],
[
out,
],
)
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
return [shapes[0]] return [shapes[0]]
def perform(self, node, inputs, outputs):
# Kalbfleisch and Lawless, J. Am. Stat. Assoc. 80 (1985) Equation 3.4
# Kind of... You need to do some algebra from there to arrive at
# this expression.
(A, gA) = inputs
(out,) = outputs
w, V = scipy_linalg.eig(A, right=True)
U = scipy_linalg.inv(V).T
exp_w = np.exp(w)
X = np.subtract.outer(exp_w, exp_w) / np.subtract.outer(w, w)
np.fill_diagonal(X, exp_w)
Y = U.dot(V.T.dot(gA).dot(U) * X).dot(V.T)
with warnings.catch_warnings():
warnings.simplefilter("ignore", ComplexWarning)
out[0] = Y.astype(A.dtype)
expm = Expm() expm = Blockwise(Expm())
class SolveContinuousLyapunov(Op): class SolveContinuousLyapunov(Op):
......
...@@ -361,3 +361,12 @@ def test_jax_cho_solve(b_shape, lower): ...@@ -361,3 +361,12 @@ def test_jax_cho_solve(b_shape, lower):
out = pt_slinalg.cho_solve((c, lower), b, b_ndim=len(b_shape)) out = pt_slinalg.cho_solve((c, lower), b, b_ndim=len(b_shape))
compare_jax_and_py([A, b], [out], [A_val, b_val]) compare_jax_and_py([A, b], [out], [A_val, b_val])
def test_jax_expm():
rng = np.random.default_rng(utt.fetch_seed())
A = pt.tensor(name="A", shape=(5, 5))
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
out = pt_slinalg.expm(A)
compare_jax_and_py([A], [out], [A_val])
...@@ -4,6 +4,7 @@ import numpy as np ...@@ -4,6 +4,7 @@ import numpy as np
import pytest import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config
from pytensor.tensor import nlinalg from pytensor.tensor import nlinalg
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
...@@ -51,45 +52,24 @@ y = np.array( ...@@ -51,45 +52,24 @@ y = np.array(
) )
@pytest.mark.parametrize( @pytest.mark.parametrize("input_dtype", ["float", "int"])
"x, exc", @pytest.mark.parametrize("symmetric", [True, False], ids=["symmetric", "general"])
[ def test_Eig(input_dtype, symmetric):
( x = pt.dmatrix("x")
( if input_dtype == "float":
pt.dmatrix(), x_val = rng.normal(size=(3, 3)).astype(config.floatX)
(lambda x: x.T.dot(x))(x), else:
), x_val = rng.integers(1, 10, size=(3, 3)).astype("int64")
None,
), if symmetric:
( x_val = x_val + x_val.T
(
pt.dmatrix(), g = nlinalg.eig(x)
(lambda x: x.T.dot(x))(y), compare_numba_and_py(
), graph_inputs=[x],
None, graph_outputs=g,
), test_inputs=[x_val],
( )
(
pt.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
),
),
None,
),
],
)
def test_Eig(x, exc):
x, test_x = x
g = nlinalg.Eig()(x)
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
[x],
g,
[test_x],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -8,21 +8,29 @@ from pytensor.tensor.type import matrix ...@@ -8,21 +8,29 @@ from pytensor.tensor.type import matrix
from tests.link.pytorch.test_basic import compare_pytorch_and_py from tests.link.pytorch.test_basic import compare_pytorch_and_py
def assert_fn(x, y):
np.testing.assert_allclose(x, y, rtol=1e-3)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"func", "func",
(pt_nla.eig, pt_nla.eigh, pt_nla.SLogDet(), pt_nla.inv, pt_nla.det), (pt_nla.eigh, pt_nla.SLogDet(), pt_nla.inv, pt_nla.det),
) )
def test_lin_alg_no_params(func, matrix_test): def test_lin_alg_no_params(func, matrix_test):
x, test_value = matrix_test x, test_value = matrix_test
outs = func(x) outs = func(x)
def assert_fn(x, y):
np.testing.assert_allclose(x, y, rtol=1e-3)
compare_pytorch_and_py([x], outs, [test_value], assert_fn=assert_fn) compare_pytorch_and_py([x], outs, [test_value], assert_fn=assert_fn)
def test_eig(matrix_test):
x, test_value = matrix_test
out = pt_nla.eig(x)
compare_pytorch_and_py([x], out, [test_value], assert_fn=assert_fn)
@pytest.mark.parametrize("compute_uv", [True, False]) @pytest.mark.parametrize("compute_uv", [True, False])
@pytest.mark.parametrize("full_matrices", [True, False]) @pytest.mark.parametrize("full_matrices", [True, False])
def test_svd(compute_uv, full_matrices, matrix_test): def test_svd(compute_uv, full_matrices, matrix_test):
......
...@@ -394,11 +394,12 @@ def test_trace(): ...@@ -394,11 +394,12 @@ def test_trace():
class TestEig(utt.InferShapeTester): class TestEig(utt.InferShapeTester):
op_class = Eig op_class = Eig
op = eig
dtype = "float64" dtype = "float64"
op = staticmethod(eig)
def setup_method(self): def setup_method(self):
super().setup_method() super().setup_method()
self.rng = np.random.default_rng(utt.fetch_seed()) self.rng = np.random.default_rng(utt.fetch_seed())
self.A = matrix(dtype=self.dtype) self.A = matrix(dtype=self.dtype)
self.X = np.asarray(self.rng.random((5, 5)), dtype=self.dtype) self.X = np.asarray(self.rng.random((5, 5)), dtype=self.dtype)
...@@ -418,15 +419,54 @@ class TestEig(utt.InferShapeTester): ...@@ -418,15 +419,54 @@ class TestEig(utt.InferShapeTester):
def test_eval(self): def test_eval(self):
A = matrix(dtype=self.dtype) A = matrix(dtype=self.dtype)
assert [e.eval({A: [[1]]}) for e in self.op(A)] == [[1.0], [[1.0]]] fn = function([A], self.op(A))
x = [[0, 1], [1, 0]]
w, v = (e.eval({A: x}) for e in self.op(A)) # Symmetric input (real eigenvalues)
assert_array_almost_equal(np.dot(x, v), w * v) A_val = self.rng.normal(size=(5, 5)).astype(self.dtype)
A_val = A_val + A_val.T
w, v = fn(A_val)
w_np, v_np = np.linalg.eig(A_val)
np.testing.assert_allclose(w, w_np)
np.testing.assert_allclose(v, v_np)
assert_array_almost_equal(np.dot(A_val, v), w * v)
# Asymmetric input (real eigenvalues)
z = self.rng.normal(size=(5,)) ** 2
A_val = (np.diag(z**0.5)).dot(A_val).dot(np.diag(z ** (-0.5)))
w, v = fn(A_val)
w_np, v_np = np.linalg.eig(A_val)
np.testing.assert_allclose(w, w_np)
np.testing.assert_allclose(v, v_np)
assert_array_almost_equal(np.dot(A_val, v), w * v)
# Asymmetric input (complex eigenvalues)
A_val = self.rng.normal(size=(5, 5))
w, v = fn(A_val)
w_np, v_np = np.linalg.eig(A_val)
np.testing.assert_allclose(w, w_np)
np.testing.assert_allclose(v, v_np)
assert_array_almost_equal(np.dot(A_val, v), w * v)
class TestEigh(TestEig): class TestEigh(TestEig):
op = staticmethod(eigh) op = staticmethod(eigh)
def test_eval(self):
A = matrix(dtype=self.dtype)
fn = function([A], self.op(A))
# Symmetric input (real eigenvalues)
A_val = self.rng.normal(size=(5, 5)).astype(self.dtype)
A_val = A_val + A_val.T
w, v = fn(A_val)
w_np, v_np = np.linalg.eigh(A_val)
np.testing.assert_allclose(w, w_np)
np.testing.assert_allclose(v, v_np)
assert_array_almost_equal(np.dot(A_val, v), w * v)
def test_uplo(self): def test_uplo(self):
S = self.S S = self.S
a = matrix(dtype=self.dtype) a = matrix(dtype=self.dtype)
......
...@@ -880,35 +880,26 @@ def test_expm(): ...@@ -880,35 +880,26 @@ def test_expm():
np.testing.assert_array_almost_equal(val, ref) np.testing.assert_array_almost_equal(val, ref)
def test_expm_grad_1(): @pytest.mark.parametrize(
# with symmetric matrix (real eigenvectors) "mode", ["symmetric", "nonsymmetric_real_eig", "nonsymmetric_complex_eig"][-1:]
rng = np.random.default_rng(utt.fetch_seed()) )
# Always test in float64 for better numerical stability. def test_expm_grad(mode):
A = rng.standard_normal((5, 5)) rng = np.random.default_rng()
A = A + A.T
match mode:
utt.verify_grad(expm, [A], rng=rng) case "symmetric":
A = rng.standard_normal((5, 5))
A = A + A.T
def test_expm_grad_2(): case "nonsymmetric_real_eig":
# with non-symmetric matrix with real eigenspecta A = rng.standard_normal((5, 5))
rng = np.random.default_rng(utt.fetch_seed()) w = rng.standard_normal(5) ** 2
# Always test in float64 for better numerical stability. A = (np.diag(w**0.5)).dot(A + A.T).dot(np.diag(w ** (-0.5)))
A = rng.standard_normal((5, 5)) case "nonsymmetric_complex_eig":
w = rng.standard_normal(5) ** 2 A = rng.standard_normal((5, 5))
A = (np.diag(w**0.5)).dot(A + A.T).dot(np.diag(w ** (-0.5))) case _:
assert not np.allclose(A, A.T) raise ValueError(f"Invalid mode: {mode}")
utt.verify_grad(expm, [A], rng=rng) utt.verify_grad(expm, [A], rng=rng, abs_tol=1e-5, rel_tol=1e-5)
def test_expm_grad_3():
# with non-symmetric matrix (complex eigenvectors)
rng = np.random.default_rng(utt.fetch_seed())
# Always test in float64 for better numerical stability.
A = rng.standard_normal((5, 5))
utt.verify_grad(expm, [A], rng=rng)
def recover_Q(A, X, continuous=True): def recover_Q(A, X, continuous=True):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论