Unverified commit 792bd049, authored by Jesse Grabowski, committed by GitHub

Refactor `Expm` and `Eig`, add jax dispatch for `expm` (#1668)

* `linalg.eig` always returns complex dtype * Update Eig dispatch for Numba, Jax, and Pytorch backends * Clean up `pytensor.linalg.expm` and related tests * Add JAX dispatch for expm * Implement L_op instead of grad in `Eigh` --------- Co-authored-by: Jesse Grabowski <jesse.grabowski@readyx.com>
上级 27c21cd6
......@@ -10,6 +10,7 @@ from pytensor.tensor.slinalg import (
Cholesky,
CholeskySolve,
Eigvalsh,
Expm,
LUFactor,
PivotToPermutations,
Solve,
......@@ -179,3 +180,11 @@ def jax_funcify_QR(op, **kwargs):
return jax.scipy.linalg.qr(x, mode=mode)
return qr
@jax_funcify.register(Expm)
def jax_funcify_Expm(op, **kwargs):
    # JAX dispatch for the matrix-exponential Op. `Expm` carries no
    # parameters (empty __props__), so we simply forward the input to
    # JAX's scipy-compatible implementation.
    def expm(x):
        return jax.scipy.linalg.expm(x)

    return expm
......@@ -76,15 +76,12 @@ def numba_funcify_SLogDet(op, node, **kwargs):
@numba_funcify.register(Eig)
def numba_funcify_Eig(op, node, **kwargs):
    """Numba dispatch for Eig: cast inputs and call numpy's eig directly.

    The eigenvalue output dtype (now always complex — see the Eig Op) drives
    the input cast; the outputs of ``np.linalg.eig`` are returned as-is, so no
    per-output astype round-trip is needed inside the jitted function.
    """
    # NOTE(review): assumes int_to_float_fn accepts a complex target dtype
    # and picks the matching float precision for the input cast — confirm.
    w_dtype = node.outputs[0].type.numpy_dtype
    inputs_cast = int_to_float_fn(node.inputs, w_dtype)

    @numba_basic.numba_njit
    def eig(x):
        return np.linalg.eig(inputs_cast(x))

    return eig
......
......@@ -16,7 +16,15 @@ from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector
from pytensor.tensor.type import (
Variable,
dvector,
lscalar,
matrix,
scalar,
tensor,
vector,
)
class MatrixPinv(Op):
......@@ -297,37 +305,78 @@ def slogdet(x: TensorLike) -> tuple[ptb.TensorVariable, ptb.TensorVariable]:
class Eig(Op):
    """
    Compute the eigenvalues and right eigenvectors of a square array.

    Both outputs always have a complex dtype: ``np.linalg.eig`` returns real
    arrays when every eigenvalue happens to be real, but PyTensor requires a
    statically known output dtype, so the result is promoted to complex.
    """

    __props__: tuple[str, ...] = ()
    gufunc_spec = ("numpy.linalg.eig", 1, 2)
    gufunc_signature = "(m,m)->(m),(m,m)"

    def make_node(self, x):
        x = as_tensor_variable(x)
        assert x.ndim == 2

        # Reject statically known non-square inputs early.
        M, N = x.type.shape
        if M is not None and N is not None and M != N:
            raise ValueError(
                f"Input to Eig must be a square matrix, got static shape: ({M}, {N})"
            )

        # Promote to at least complex64 so the output dtype is static.
        dtype = np.promote_types(x.dtype, np.complex64)
        w = tensor(dtype=dtype, shape=(M,))
        v = tensor(dtype=dtype, shape=(M, N))
        return Apply(self, [x], [w, v])

    def perform(self, node, inputs, outputs):
        (x,) = inputs
        dtype = np.promote_types(x.dtype, np.complex64)

        w, v = np.linalg.eig(x)

        # If the imaginary part of the eigenvalues is zero, numpy automatically casts them to real. We require
        # a statically known return dtype, so we have to cast back to complex to avoid dtype mismatch.
        outputs[0][0] = w.astype(dtype, copy=False)
        outputs[1][0] = v.astype(dtype, copy=False)

    def infer_shape(self, fgraph, node, shapes):
        # Eigenvalues have shape (n,); eigenvectors (n, n), n = input rows.
        (x_shapes,) = shapes
        n, _ = x_shapes
        return [(n,), (n, n)]

    def L_op(self, inputs, outputs, output_grads):
        raise NotImplementedError(
            "Gradients for Eig is not implemented because it always returns complex values, "
            "for which autodiff is not yet supported in PyTensor (PRs welcome :) ).\n"
            "If you know that your input has strictly real-valued eigenvalues (e.g. it is a "
            "symmetric matrix), use pt.linalg.eigh instead."
        )
def eig(x: TensorLike):
    """
    Return the eigenvalues and right eigenvectors of a square array.

    Note that regardless of the input dtype, the eigenvalues and eigenvectors are returned as complex numbers. As a
    result, the gradient of this operation is not implemented (because PyTensor does not support autodiff for complex
    values yet).

    If you know that your input has strictly real-valued eigenvalues (e.g. it is a symmetric matrix), use
    `pytensor.tensor.linalg.eigh` instead.

    Parameters
    ----------
    x: TensorLike
        Square matrix, or array of such matrices

    Returns
    -------
    w, v
        Complex eigenvalues and right eigenvectors of ``x`` (batched when
        ``x`` has leading batch dimensions, via Blockwise).
    """
    return Blockwise(Eig())(x)
class Eigh(Eig):
"""
Return the eigenvalues and eigenvectors of a Hermitian or symmetric matrix.
"""
__props__ = ("UPLO",)
......@@ -354,7 +403,7 @@ class Eigh(Eig):
(w, v) = outputs
w[0], v[0] = np.linalg.eigh(x, self.UPLO)
def grad(self, inputs, g_outputs):
def L_op(self, inputs, outputs, output_grads):
r"""The gradient function should return
.. math:: \sum_n\left(W_n\frac{\partial\,w_n}
......@@ -378,10 +427,9 @@ class Eigh(Eig):
"""
(x,) = inputs
w, v = self(x)
# Replace gradients wrt disconnected variables with
# zeros. This is a work-around for issue #1063.
gw, gv = _zero_disconnected([w, v], g_outputs)
w, v = outputs
gw, gv = _zero_disconnected([w, v], output_grads)
return [EighGrad(self.UPLO)(x, w, v, gw, gv)]
......
......@@ -6,7 +6,6 @@ from typing import Literal, cast
import numpy as np
import scipy.linalg as scipy_linalg
from numpy.exceptions import ComplexWarning
from scipy.linalg import LinAlgError, LinAlgWarning, get_lapack_funcs
import pytensor
......@@ -1304,82 +1303,60 @@ def eigvalsh(a, b, lower=True):
class Expm(Op):
    """
    Compute the matrix exponential of a square array.
    """

    __props__ = ()
    gufunc_signature = "(m,m)->(m,m)"

    def make_node(self, A):
        A = as_tensor_variable(A)
        assert A.ndim == 2
        # Output has the same dtype and (static) shape as the input.
        expm = matrix(dtype=A.dtype, shape=A.type.shape)
        return Apply(self, [A], [expm])

    def perform(self, node, inputs, outputs):
        (A,) = inputs
        (expm,) = outputs
        expm[0] = scipy_linalg.expm(A)

    def infer_shape(self, fgraph, node, shapes):
        # expm(A) has exactly the shape of A.
        return [shapes[0]]

    def L_op(self, inputs, outputs, output_grads):
        # Kalbfleisch and Lawless, J. Am. Stat. Assoc. 80 (1985) Equation 3.4
        # Kind of... You need to do some algebra from there to arrive at
        # this expression.
        (A,) = inputs
        (_,) = outputs  # Outputs not used; included for signature consistency only
        (A_bar,) = output_grads

        w, V = pt.linalg.eig(A)
        exp_w = pt.exp(w)
        numer = pt.sub.outer(exp_w, exp_w)
        denom = pt.sub.outer(w, w)

        # When w_i ≈ w_j, we have a removable singularity in the expression for X, because
        # lim b->a (e^a - e^b) / (a - b) = e^a (derivation left for the motivated reader)
        X = pt.where(pt.abs(denom) < 1e-8, exp_w, numer / denom)

        # The diagonal is exactly the degenerate case; set it explicitly.
        diag_idx = pt.arange(w.shape[0])
        X = X[..., diag_idx, diag_idx].set(exp_w)

        inner = solve(V, A_bar.T @ V).T
        result = solve(V.T, inner * X) @ V.T

        # At this point, result is always a complex dtype. If the input was real, the output should be
        # real as well (and all the imaginary parts are numerical noise)
        if A.dtype not in ("complex64", "complex128"):
            return [result.real]
        return [result]
# Blockwise wrapper so expm broadcasts over leading batch dimensions.
expm = Blockwise(Expm())
class SolveContinuousLyapunov(Op):
......
......@@ -361,3 +361,12 @@ def test_jax_cho_solve(b_shape, lower):
out = pt_slinalg.cho_solve((c, lower), b, b_ndim=len(b_shape))
compare_jax_and_py([A, b], [out], [A_val, b_val])
def test_jax_expm():
    """The JAX dispatch of `expm` should agree with the Python backend."""
    rng = np.random.default_rng(utt.fetch_seed())
    A = pt.tensor(name="A", shape=(5, 5))
    test_value = rng.normal(size=(5, 5)).astype(config.floatX)
    compare_jax_and_py([A], [pt_slinalg.expm(A)], [test_value])
......@@ -4,6 +4,7 @@ import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor import config
from pytensor.tensor import nlinalg
from tests.link.numba.test_basic import compare_numba_and_py
......@@ -51,44 +52,23 @@ y = np.array(
)
@pytest.mark.parametrize("input_dtype", ["float", "int"])
@pytest.mark.parametrize("symmetric", [True, False], ids=["symmetric", "general"])
def test_Eig(input_dtype, symmetric):
    """Numba `eig` should match the Python backend for float/int and
    symmetric/general inputs (symmetric inputs have strictly real spectra)."""
    x = pt.dmatrix("x")
    if input_dtype == "float":
        x_val = rng.normal(size=(3, 3)).astype(config.floatX)
    else:
        x_val = rng.integers(1, 10, size=(3, 3)).astype("int64")

    if symmetric:
        # Symmetrize so the eigenvalues are real.
        x_val = x_val + x_val.T

    g = nlinalg.eig(x)
    compare_numba_and_py(
        graph_inputs=[x],
        graph_outputs=g,
        test_inputs=[x_val],
    )
......
......@@ -8,21 +8,29 @@ from pytensor.tensor.type import matrix
from tests.link.pytorch.test_basic import compare_pytorch_and_py
def assert_fn(x, y):
    """Shared comparison helper: allclose with a loose relative tolerance."""
    np.testing.assert_allclose(actual=x, desired=y, rtol=1e-3)
@pytest.mark.parametrize(
    "func",
    (pt_nla.eigh, pt_nla.SLogDet(), pt_nla.inv, pt_nla.det),
)
def test_lin_alg_no_params(func, matrix_test):
    # `eig` is exercised separately (test_eig) because it now always
    # returns complex outputs.
    x, test_value = matrix_test

    outs = func(x)
    compare_pytorch_and_py([x], outs, [test_value], assert_fn=assert_fn)
def test_eig(matrix_test):
    """PyTorch `eig` dispatch should match the Python backend."""
    symbolic, value = matrix_test
    eig_outputs = pt_nla.eig(symbolic)
    compare_pytorch_and_py([symbolic], eig_outputs, [value], assert_fn=assert_fn)
@pytest.mark.parametrize("compute_uv", [True, False])
@pytest.mark.parametrize("full_matrices", [True, False])
def test_svd(compute_uv, full_matrices, matrix_test):
......
......@@ -394,11 +394,12 @@ def test_trace():
class TestEig(utt.InferShapeTester):
    # Op/function under test; subclasses (e.g. TestEigh) override `op`.
    # `staticmethod` stops Python from binding the bare function as a method.
    op_class = Eig
    dtype = "float64"
    op = staticmethod(eig)

    def setup_method(self):
        super().setup_method()
        self.rng = np.random.default_rng(utt.fetch_seed())
        # Shared symbolic matrix and a concrete 5x5 test value.
        self.A = matrix(dtype=self.dtype)
        self.X = np.asarray(self.rng.random((5, 5)), dtype=self.dtype)
......@@ -418,15 +419,54 @@ class TestEig(utt.InferShapeTester):
def test_eval(self):
A = matrix(dtype=self.dtype)
assert [e.eval({A: [[1]]}) for e in self.op(A)] == [[1.0], [[1.0]]]
x = [[0, 1], [1, 0]]
w, v = (e.eval({A: x}) for e in self.op(A))
assert_array_almost_equal(np.dot(x, v), w * v)
fn = function([A], self.op(A))
# Symmetric input (real eigenvalues)
A_val = self.rng.normal(size=(5, 5)).astype(self.dtype)
A_val = A_val + A_val.T
w, v = fn(A_val)
w_np, v_np = np.linalg.eig(A_val)
np.testing.assert_allclose(w, w_np)
np.testing.assert_allclose(v, v_np)
assert_array_almost_equal(np.dot(A_val, v), w * v)
# Asymmetric input (real eigenvalues)
z = self.rng.normal(size=(5,)) ** 2
A_val = (np.diag(z**0.5)).dot(A_val).dot(np.diag(z ** (-0.5)))
w, v = fn(A_val)
w_np, v_np = np.linalg.eig(A_val)
np.testing.assert_allclose(w, w_np)
np.testing.assert_allclose(v, v_np)
assert_array_almost_equal(np.dot(A_val, v), w * v)
# Asymmetric input (complex eigenvalues)
A_val = self.rng.normal(size=(5, 5))
w, v = fn(A_val)
w_np, v_np = np.linalg.eig(A_val)
np.testing.assert_allclose(w, w_np)
np.testing.assert_allclose(v, v_np)
assert_array_almost_equal(np.dot(A_val, v), w * v)
class TestEigh(TestEig):
    # Reuses the TestEig harness but targets the symmetric/Hermitian `eigh`.
    op = staticmethod(eigh)

    def test_eval(self):
        # Only the symmetric case applies: eigh assumes a Hermitian input.
        A = matrix(dtype=self.dtype)
        fn = function([A], self.op(A))

        # Symmetric input (real eigenvalues)
        A_val = self.rng.normal(size=(5, 5)).astype(self.dtype)
        A_val = A_val + A_val.T
        w, v = fn(A_val)
        w_np, v_np = np.linalg.eigh(A_val)
        np.testing.assert_allclose(w, w_np)
        np.testing.assert_allclose(v, v_np)
        assert_array_almost_equal(np.dot(A_val, v), w * v)
def test_uplo(self):
S = self.S
a = matrix(dtype=self.dtype)
......
......@@ -880,35 +880,26 @@ def test_expm():
np.testing.assert_array_almost_equal(val, ref)
def test_expm_grad_1():
# with symmetric matrix (real eigenvectors)
rng = np.random.default_rng(utt.fetch_seed())
# Always test in float64 for better numerical stability.
@pytest.mark.parametrize(
"mode", ["symmetric", "nonsymmetric_real_eig", "nonsymmetric_complex_eig"][-1:]
)
def test_expm_grad(mode):
rng = np.random.default_rng()
match mode:
case "symmetric":
A = rng.standard_normal((5, 5))
A = A + A.T
utt.verify_grad(expm, [A], rng=rng)
def test_expm_grad_2():
# with non-symmetric matrix with real eigenspecta
rng = np.random.default_rng(utt.fetch_seed())
# Always test in float64 for better numerical stability.
case "nonsymmetric_real_eig":
A = rng.standard_normal((5, 5))
w = rng.standard_normal(5) ** 2
A = (np.diag(w**0.5)).dot(A + A.T).dot(np.diag(w ** (-0.5)))
assert not np.allclose(A, A.T)
utt.verify_grad(expm, [A], rng=rng)
def test_expm_grad_3():
# with non-symmetric matrix (complex eigenvectors)
rng = np.random.default_rng(utt.fetch_seed())
# Always test in float64 for better numerical stability.
case "nonsymmetric_complex_eig":
A = rng.standard_normal((5, 5))
case _:
raise ValueError(f"Invalid mode: {mode}")
utt.verify_grad(expm, [A], rng=rng)
utt.verify_grad(expm, [A], rng=rng, abs_tol=1e-5, rel_tol=1e-5)
def recover_Q(A, X, continuous=True):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论