Commit 28f26483 authored by jessegrabowski, committed by Jesse Grabowski

Refactor `nlinalg.norm` to match `np.linalg.norm`

Expand TestNorm test coverage
Parent 1a0d12d8
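For context, a minimal usage sketch of the refactored interface (an illustration based on the new docstring, not part of the committed diff; the import paths are the usual pytensor ones and are assumed here):

# Illustration only: how the refactored norm is expected to be called.
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import norm

v = pt.vector("v")   # 1-D input: default ord is the L2 norm
m = pt.matrix("m")   # 2-D input: default ord is the Frobenius norm
x = pt.tensor3("x")  # extra leading dimensions are treated as batch dimensions

v_norm = norm(v)                        # sqrt(sum(v**2))
row_norms = norm(m, ord=2, axis=-1)     # treat the matrix as a batch of row vectors
batched_nuc = norm(x, ord="nuc")        # nuclear norm of each matrix in the batch
kept = norm(m, ord=1, keepdims=True)    # reduced axes kept, so the result broadcasts with m

print(v_norm.eval({v: np.arange(3).astype(v.dtype)}))  # 2.2360679... (sqrt(5))

The diff below implements this signature in nlinalg and reworks TestNorm to compare against np.linalg.norm across orders, batch shapes, complex inputs, and keepdims.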
import warnings
from functools import partial
from typing import Callable, Literal, Optional, Union
import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore
from pytensor import scalar as ps
from pytensor.gradient import DisconnectedType
@@ -688,41 +690,204 @@ def matrix_power(M, n):
return result
def norm(x, ord):
x = as_tensor_variable(x)
def _multi_svd_norm(
x: ptb.TensorVariable, row_axis: int, col_axis: int, reduce_op: Callable
):
"""Compute a function of the singular values of the 2-D matrices in `x`.
This is a private utility function used by `pytensor.tensor.nlinalg.norm()`.
Copied from `np.linalg._multi_svd_norm`.
Parameters
----------
x : TensorVariable
Input tensor.
row_axis, col_axis : int
The axes of `x` that hold the 2-D matrices.
reduce_op : callable
Reduction op. Should be one of `pt.min`, `pt.max`, or `pt.sum`
Returns
-------
result : float or ndarray
If `x` is 2-D, the return value is a float.
Otherwise, it is an array with ``x.ndim - 2`` dimensions.
The return values are the minimum, maximum, or sum of the
singular values of the matrices, depending on whether `reduce_op`
is `pt.min`, `pt.max`, or `pt.sum`.
"""
y = ptb.moveaxis(x, (row_axis, col_axis), (-2, -1))
result = reduce_op(svd(y, compute_uv=False), axis=-1)
return result
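To illustrate the helper above (a sketch, not part of the diff): with `reduce_op=pt.max` it yields the largest singular value, which is what `norm(..., ord=2)` returns for a matrix.

# Illustration only: the ord=2 matrix norm is the largest singular value,
# i.e. the quantity _multi_svd_norm computes with a max reduction.
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import norm, svd

A = pt.matrix("A")
spectral = norm(A, ord=2)                    # routed through _multi_svd_norm with pt.max
largest_sv = svd(A, compute_uv=False).max()  # same quantity computed directly

a_val = np.random.default_rng(0).normal(size=(3, 4)).astype(A.dtype)
np.testing.assert_allclose(
    spectral.eval({A: a_val}), largest_sv.eval({A: a_val}), rtol=1e-5
)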
VALID_ORD = Literal["fro", "f", "nuc", "inf", "-inf", 0, 1, -1, 2, -2]
def norm(
x: ptb.TensorVariable,
ord: Optional[Union[float, VALID_ORD]] = None,
axis: Optional[Union[int, tuple[int, ...]]] = None,
keepdims: bool = False,
):
"""
Matrix or vector norm.
Parameters
----------
x: TensorVariable
Tensor to take norm of.
ord: float, str or int, optional
Order of norm. If `ord` is a str, it must be one of the following:
- 'fro' or 'f' : Frobenius norm
- 'nuc' : nuclear norm
- 'inf' : Infinity norm
- '-inf' : Negative infinity norm
If an integer, order can be one of -2, -1, 0, 1, or 2.
Otherwise `ord` must be a float.
Default is the Frobenius (L2) norm.
axis: tuple of int, optional
Axes over which to compute the norm. If None, the norm of the entire matrix (or vector) is computed. Row or column
norms can be computed by passing a single integer; this will treat a matrix like a batch of vectors.
keepdims: bool
If True, dummy axes will be inserted into the output so that norm.ndim == x.ndim. Default is False.
Returns
-------
TensorVariable
Norm of `x` along axes specified by `axis`.
Notes
-----
Batched dimensions are supported to the left of the core dimensions. For example, if `x` is a 3D tensor with
shape (2, 3, 4), then `norm(x)` will compute the norm of each 3x4 matrix in the batch.
If the input is a 2D tensor and should be treated as a batch of vectors, the `axis` argument must be specified.
"""
x = ptb.as_tensor_variable(x)
ndim = x.ndim
if ndim == 0:
raise ValueError("'axis' entry is out of bounds.")
elif ndim == 1:
if ord is None:
return ptm.sum(x**2) ** 0.5
elif ord == "inf":
return ptm.max(abs(x))
elif ord == "-inf":
return ptm.min(abs(x))
core_ndim = min(2, ndim)
batch_ndim = ndim - core_ndim
if axis is None:
# Handle some common cases first. These can be computed more quickly than the default SVD way, so we always
# want to check for them.
if (
(ord is None)
or (ord in ("f", "fro") and core_ndim == 2)
or (ord == 2 and core_ndim == 1)
):
x = x.reshape(tuple(x.shape[:-2]) + (-1,) + (1,) * (core_ndim - 1))
batch_T_dim_order = tuple(range(batch_ndim)) + tuple(
range(batch_ndim + core_ndim - 1, batch_ndim - 1, -1)
)
if x.dtype.startswith("complex"):
x_real = x.real # type: ignore
x_imag = x.imag # type: ignore
sqnorm = (
ptb.transpose(x_real, batch_T_dim_order) @ x_real
+ ptb.transpose(x_imag, batch_T_dim_order) @ x_imag
)
else:
sqnorm = ptb.transpose(x, batch_T_dim_order) @ x
ret = ptm.sqrt(sqnorm).squeeze()
if keepdims:
ret = ptb.shape_padright(ret, core_ndim)
return ret
# No special computation to exploit -- set default axis before continuing
axis = tuple(range(core_ndim))
elif not isinstance(axis, tuple):
try:
axis = int(axis)
except Exception as e:
raise TypeError(
"'axis' must be None, an integer, or a tuple of integers"
) from e
axis = (axis,)
if len(axis) == 1:
# Vector norms
if ord in [None, "fro", "f"] and (core_ndim == 2):
# This is here to catch the case where X is a 2D tensor but the user wants to treat it as a batch of
# vectors. Other vector norms will work fine in this case.
ret = ptm.sqrt(ptm.sum((x.conj() * x).real, axis=axis, keepdims=keepdims))
elif (ord == "inf") or (ord == np.inf):
ret = ptm.max(ptm.abs(x), axis=axis, keepdims=keepdims)
elif (ord == "-inf") or (ord == -np.inf):
ret = ptm.min(ptm.abs(x), axis=axis, keepdims=keepdims)
elif ord == 0:
return x[x.nonzero()].shape[0]
ret = ptm.neq(x, 0).sum(axis=axis, keepdims=keepdims)
elif ord == 1:
ret = ptm.sum(ptm.abs(x), axis=axis, keepdims=keepdims)
elif isinstance(ord, str):
raise ValueError(f"Invalid norm order '{ord}' for vectors")
else:
try:
z = ptm.sum(abs(x**ord)) ** (1.0 / ord)
except TypeError:
raise ValueError("Invalid norm order for vectors.")
return z
elif ndim == 2:
if ord is None or ord == "fro":
return ptm.sum(abs(x**2)) ** (0.5)
elif ord == "inf":
return ptm.max(ptm.sum(abs(x), 1))
elif ord == "-inf":
return ptm.min(ptm.sum(abs(x), 1))
ret = ptm.sum(ptm.abs(x) ** ord, axis=axis, keepdims=keepdims)
ret **= ptm.reciprocal(ord)
return ret
elif len(axis) == 2:
# Matrix norms
row_axis, col_axis = (
batch_ndim + x for x in normalize_axis_tuple(axis, core_ndim)
)
axis = (row_axis, col_axis)
if ord in [None, "fro", "f"]:
ret = ptm.sqrt(ptm.sum((x.conj() * x).real, axis=axis))
elif (ord == "inf") or (ord == np.inf):
if row_axis > col_axis:
row_axis -= 1
ret = ptm.max(ptm.sum(ptm.abs(x), axis=col_axis), axis=row_axis)
elif (ord == "-inf") or (ord == -np.inf):
if row_axis > col_axis:
row_axis -= 1
ret = ptm.min(ptm.sum(ptm.abs(x), axis=col_axis), axis=row_axis)
elif ord == 1:
return ptm.max(ptm.sum(abs(x), 0))
if col_axis > row_axis:
col_axis -= 1
ret = ptm.max(ptm.sum(ptm.abs(x), axis=row_axis), axis=col_axis)
elif ord == -1:
return ptm.min(ptm.sum(abs(x), 0))
if col_axis > row_axis:
col_axis -= 1
ret = ptm.min(ptm.sum(ptm.abs(x), axis=row_axis), axis=col_axis)
elif ord == 2:
ret = _multi_svd_norm(x, row_axis, col_axis, ptm.max)
elif ord == -2:
ret = _multi_svd_norm(x, row_axis, col_axis, ptm.min)
elif ord == "nuc":
ret = _multi_svd_norm(x, row_axis, col_axis, ptm.sum)
else:
raise ValueError(0)
elif ndim > 2:
raise NotImplementedError("We don't support norm with ndim > 2")
raise ValueError(f"Invalid norm order for matrices: {ord}")
if keepdims:
ret = ptb.expand_dims(ret, axis)
return ret
else:
raise ValueError(
f"Cannot compute norm when core_dims < 1 or core_dims > 3, found: core_dims = {core_ndim}"
)
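As the Notes in the docstring describe, leading dimensions are treated as batch dimensions, and a 2-D input can be treated as a batch of vectors by passing `axis`. A small sketch of both cases (illustrative, not part of the diff):

# Illustration only: batched matrix norms vs. a matrix treated as a batch of vectors.
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import norm

x = pt.tensor3("x")                 # e.g. shape (2, 3, 4): a batch of two 3x4 matrices
per_matrix = norm(x)                # Frobenius norm of each matrix -> shape (2,)

m = pt.matrix("m")                  # shape (3, 4)
per_row = norm(m, axis=-1)          # L2 norm of each row -> shape (3,)
per_row_kd = norm(m, axis=-1, keepdims=True)  # shape (3, 1)

x_val = np.random.default_rng(0).normal(size=(2, 3, 4)).astype(x.dtype)
np.testing.assert_allclose(
    per_matrix.eval({x: x_val}),
    np.stack([np.linalg.norm(xi) for xi in x_val]),
    rtol=1e-5,
)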
class TensorInv(Op):
......
@@ -3,7 +3,6 @@ from functools import partial
import numpy as np
import numpy.linalg
import pytest
from numpy import inf
from numpy.testing import assert_array_almost_equal
import pytensor
@@ -463,44 +462,82 @@ class TestMatrixPower:
f(a)
class TestNormTests:
class TestNorm:
def test_wrong_type_of_ord_for_vector(self):
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Invalid norm order 'fro' for vectors"):
norm([2, 1], "fro")
def test_wrong_type_of_ord_for_matrix(self):
with pytest.raises(ValueError):
norm([[2, 1], [3, 4]], 0)
ord = 0
with pytest.raises(ValueError, match=f"Invalid norm order for matrices: {ord}"):
norm([[2, 1], [3, 4]], ord)
def test_non_tensorial_input(self):
with pytest.raises(ValueError):
norm(3, None)
with pytest.raises(
ValueError,
match="Cannot compute norm when core_dims < 1 or core_dims > 3, found: core_dims = 0",
):
norm(3, ord=2)
def test_invalid_axis_input(self):
axis = scalar("i", dtype="int")
with pytest.raises(
TypeError, match="'axis' must be None, an integer, or a tuple of integers"
):
norm([[1, 2], [3, 4]], axis=axis)
def test_tensor_input(self):
res = norm(np.random.random((3, 4, 5)), None)
assert res.shape.eval() == (3,)
@pytest.mark.parametrize(
"ord",
[None, np.inf, -np.inf, 1, -1, 2, -2],
ids=["None", "inf", "-inf", "1", "-1", "2", "-2"],
)
@pytest.mark.parametrize("core_dims", [(4,), (4, 3)], ids=["vector", "matrix"])
@pytest.mark.parametrize("batch_dims", [(), (2,)], ids=["no_batch", "batch"])
@pytest.mark.parametrize("test_imag", [True, False], ids=["complex", "real"])
@pytest.mark.parametrize(
"keepdims", [True, False], ids=["keep_dims=True", "keep_dims=False"]
)
def test_numpy_compare(
self,
ord: float,
core_dims: tuple[int, ...],
batch_dims: tuple[int, ...],
test_imag: bool,
keepdims: bool,
axis=None,
):
is_matrix = len(core_dims) == 2
has_batch = len(batch_dims) > 0
if ord in [np.inf, -np.inf] and not is_matrix:
pytest.skip("Infinity norm not defined for vectors")
if test_imag and is_matrix and ord == -2:
pytest.skip("Complex matrices not supported")
if has_batch and not is_matrix:
# Handle batched vectors by row-normalizing a matrix
axis = (-1,)
def test_numpy_compare(self):
rng = np.random.default_rng(utt.fetch_seed())
M = matrix("A", dtype=config.floatX)
V = vector("V", dtype=config.floatX)
if test_imag:
x_real, x_imag = rng.standard_normal((2, *batch_dims, *core_dims)).astype(
config.floatX
)
dtype = "complex128" if config.floatX.endswith("64") else "complex64"
X = (x_real + 1j * x_imag).astype(dtype)
else:
X = rng.standard_normal(batch_dims + core_dims).astype(config.floatX)
a = rng.random((4, 4)).astype(config.floatX)
b = rng.random(4).astype(config.floatX)
if batch_dims == ():
np_norm = np.linalg.norm(X, ord=ord, axis=axis, keepdims=keepdims)
else:
np_norm = np.stack(
[np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) for x in X]
)
A = (
[None, "fro", "inf", "-inf", 1, -1, None, "inf", "-inf", 0, 1, -1, 2, -2],
[M, M, M, M, M, M, V, V, V, V, V, V, V, V],
[a, a, a, a, a, a, b, b, b, b, b, b, b, b],
[None, "fro", inf, -inf, 1, -1, None, inf, -inf, 0, 1, -1, 2, -2],
)
pt_norm = norm(X, ord=ord, axis=axis, keepdims=keepdims)
f = function([], pt_norm, mode="FAST_COMPILE")
for i in range(0, 14):
f = function([A[1][i]], norm(A[1][i], A[0][i]))
t_n = f(A[2][i])
n_n = np.linalg.norm(A[2][i], A[3][i])
assert _allclose(n_n, t_n)
utt.assert_allclose(np_norm, f())
class TestTensorInv(utt.InferShapeTester):
......