Commit 28f26483, authored by jessegrabowski, committed by Jesse Grabowski

Refactor `nlinalg.norm` to match `np.linalg.norm`

Expand TestNorm test coverage

Parent commit: 1a0d12d8
import warnings
from functools import partial
from typing import Callable, Literal, Optional, Union

import numpy as np
from numpy.core.numeric import normalize_axis_tuple  # type: ignore

from pytensor import scalar as ps
from pytensor.gradient import DisconnectedType
...@@ -688,41 +690,204 @@ def matrix_power(M, n): ...@@ -688,41 +690,204 @@ def matrix_power(M, n):
return result return result
def _multi_svd_norm(
    x: ptb.TensorVariable, row_axis: int, col_axis: int, reduce_op: Callable
):
    """Compute a function of the singular values of the 2-D matrices in `x`.

    This is a private utility function used by `pytensor.tensor.nlinalg.norm()`.

    Copied from `np.linalg._multi_svd_norm`.

    Parameters
    ----------
    x : TensorVariable
        Input tensor.
    row_axis, col_axis : int
        The axes of `x` that hold the 2-D matrices.
    reduce_op : callable
        Reduction op. Should be one of `pt.min`, `pt.max`, or `pt.sum`.

    Returns
    -------
    result : TensorVariable
        If `x` is 2-D, the result is a scalar.
        Otherwise, it is an array with ``x.ndim - 2`` dimensions.
        The values are the minimum, maximum, or sum of the singular values of
        the matrices, depending on whether `reduce_op` is `pt.min`, `pt.max`,
        or `pt.sum`.
    """
    # Move the matrix axes to the end so svd treats leading axes as batch dims.
    y = ptb.moveaxis(x, (row_axis, col_axis), (-2, -1))
    # compute_uv=False returns only the singular values; reduce over them.
    result = reduce_op(svd(y, compute_uv=False), axis=-1)
    return result
VALID_ORD = Literal["fro", "f", "nuc", "inf", "-inf", 0, 1, -1, 2, -2]


def norm(
    x: ptb.TensorVariable,
    ord: Optional[Union[float, VALID_ORD]] = None,
    axis: Optional[Union[int, tuple[int, ...]]] = None,
    keepdims: bool = False,
):
    """
    Matrix or vector norm.

    Parameters
    ----------
    x: TensorVariable
        Tensor to take norm of.
    ord: float, str or int, optional
        Order of norm. If `ord` is a str, it must be one of the following:
            - 'fro' or 'f' : Frobenius norm
            - 'nuc' : nuclear norm
            - 'inf' : Infinity norm
            - '-inf' : Negative infinity norm
        If an integer, order can be one of -2, -1, 0, 1, or 2.
        Otherwise `ord` must be a float.

        Default is the Frobenius (L2) norm.
    axis: tuple of int, optional
        Axes over which to compute the norm. If None, norm of entire matrix (or vector) is computed. Row or column
        norms can be computed by passing a single integer; this will treat a matrix like a batch of vectors.
    keepdims: bool
        If True, dummy axes will be inserted into the output so that norm.ndim == x.ndim. Default is False.

    Returns
    -------
    TensorVariable
        Norm of `x` along axes specified by `axis`.

    Notes
    -----
    Batched dimensions are supported to the left of the core dimensions. For example, if `x` is a 3D tensor with
    shape (2, 3, 4), then `norm(x)` will compute the norm of each 3x4 matrix in the batch.

    If the input is a 2D tensor and should be treated as a batch of vectors, the `axis` argument must be specified.
    """
    x = ptb.as_tensor_variable(x)

    ndim = x.ndim
    # At most the last two dimensions are "core"; everything to the left is batch.
    core_ndim = min(2, ndim)
    batch_ndim = ndim - core_ndim

    if axis is None:
        # Handle some common cases first. These can be computed more quickly than the default SVD way, so we always
        # want to check for them.
        if (
            (ord is None)
            or (ord in ("f", "fro") and core_ndim == 2)
            or (ord == 2 and core_ndim == 1)
        ):
            # Flatten the core dims so the norm reduces to a single inner product x^T @ x.
            x = x.reshape(tuple(x.shape[:-2]) + (-1,) + (1,) * (core_ndim - 1))
            # Transpose only the core dims, leaving batch dims in place.
            batch_T_dim_order = tuple(range(batch_ndim)) + tuple(
                range(batch_ndim + core_ndim - 1, batch_ndim - 1, -1)
            )

            if x.dtype.startswith("complex"):
                x_real = x.real  # type: ignore
                x_imag = x.imag  # type: ignore
                # |z|^2 = Re(z)^2 + Im(z)^2, computed without forming conj(z)*z.
                sqnorm = (
                    ptb.transpose(x_real, batch_T_dim_order) @ x_real
                    + ptb.transpose(x_imag, batch_T_dim_order) @ x_imag
                )
            else:
                sqnorm = ptb.transpose(x, batch_T_dim_order) @ x
            ret = ptm.sqrt(sqnorm).squeeze()
            if keepdims:
                ret = ptb.shape_padright(ret, core_ndim)
            return ret

        # No special computation to exploit -- set default axis before continuing
        axis = tuple(range(core_ndim))

    elif not isinstance(axis, tuple):
        try:
            axis = int(axis)
        except Exception as e:
            raise TypeError(
                "'axis' must be None, an integer, or a tuple of integers"
            ) from e
        axis = (axis,)

    if len(axis) == 1:
        # Vector norms
        if ord in [None, "fro", "f"] and (core_ndim == 2):
            # This is here to catch the case where X is a 2D tensor but the user wants to treat it as a batch of
            # vectors. Other vector norms will work fine in this case.
            ret = ptm.sqrt(ptm.sum((x.conj() * x).real, axis=axis, keepdims=keepdims))

        elif (ord == "inf") or (ord == np.inf):
            ret = ptm.max(ptm.abs(x), axis=axis, keepdims=keepdims)

        elif (ord == "-inf") or (ord == -np.inf):
            ret = ptm.min(ptm.abs(x), axis=axis, keepdims=keepdims)

        elif ord == 0:
            # L0 "norm": count of nonzero entries.
            ret = ptm.neq(x, 0).sum(axis=axis, keepdims=keepdims)

        elif ord == 1:
            ret = ptm.sum(ptm.abs(x), axis=axis, keepdims=keepdims)

        elif isinstance(ord, str):
            raise ValueError(f"Invalid norm order '{ord}' for vectors")

        else:
            # General p-norm: (sum |x|^p) ** (1/p)
            ret = ptm.sum(ptm.abs(x) ** ord, axis=axis, keepdims=keepdims)
            ret **= ptm.reciprocal(ord)

        return ret

    elif len(axis) == 2:
        # Matrix norms. Normalize (possibly negative) core axes, then shift them
        # past the batch dimensions. Note: `ax` deliberately avoids shadowing `x`.
        row_axis, col_axis = (
            batch_ndim + ax for ax in normalize_axis_tuple(axis, core_ndim)
        )
        axis = (row_axis, col_axis)

        if ord in [None, "fro", "f"]:
            ret = ptm.sqrt(ptm.sum((x.conj() * x).real, axis=axis))

        elif (ord == "inf") or (ord == np.inf):
            # Max absolute row sum. Reducing col_axis first shifts row_axis left
            # when it comes after col_axis.
            if row_axis > col_axis:
                row_axis -= 1
            ret = ptm.max(ptm.sum(ptm.abs(x), axis=col_axis), axis=row_axis)

        elif (ord == "-inf") or (ord == -np.inf):
            # Min absolute row sum.
            if row_axis > col_axis:
                row_axis -= 1
            ret = ptm.min(ptm.sum(ptm.abs(x), axis=col_axis), axis=row_axis)

        elif ord == 1:
            # Max absolute column sum.
            if col_axis > row_axis:
                col_axis -= 1
            ret = ptm.max(ptm.sum(ptm.abs(x), axis=row_axis), axis=col_axis)

        elif ord == -1:
            # Min absolute column sum.
            if col_axis > row_axis:
                col_axis -= 1
            ret = ptm.min(ptm.sum(ptm.abs(x), axis=row_axis), axis=col_axis)

        elif ord == 2:
            # Largest singular value.
            ret = _multi_svd_norm(x, row_axis, col_axis, ptm.max)

        elif ord == -2:
            # Smallest singular value.
            ret = _multi_svd_norm(x, row_axis, col_axis, ptm.min)

        elif ord == "nuc":
            # Nuclear norm: sum of singular values.
            ret = _multi_svd_norm(x, row_axis, col_axis, ptm.sum)

        else:
            raise ValueError(f"Invalid norm order for matrices: {ord}")

        if keepdims:
            ret = ptb.expand_dims(ret, axis)

        return ret
    else:
        raise ValueError(
            f"Cannot compute norm when core_dims < 1 or core_dims > 3, found: core_dims = {core_ndim}"
        )
class TensorInv(Op): class TensorInv(Op):
......
...@@ -3,7 +3,6 @@ from functools import partial ...@@ -3,7 +3,6 @@ from functools import partial
import numpy as np import numpy as np
import numpy.linalg import numpy.linalg
import pytest import pytest
from numpy import inf
from numpy.testing import assert_array_almost_equal from numpy.testing import assert_array_almost_equal
import pytensor import pytensor
...@@ -463,44 +462,82 @@ class TestMatrixPower: ...@@ -463,44 +462,82 @@ class TestMatrixPower:
f(a) f(a)
class TestNorm:
    def test_wrong_type_of_ord_for_vector(self):
        # String matrix orders are rejected on vector inputs.
        with pytest.raises(ValueError, match="Invalid norm order 'fro' for vectors"):
            norm([2, 1], "fro")

    def test_wrong_type_of_ord_for_matrix(self):
        # ord=0 is only defined for vectors.
        ord = 0
        with pytest.raises(ValueError, match=f"Invalid norm order for matrices: {ord}"):
            norm([[2, 1], [3, 4]], ord)

    def test_non_tensorial_input(self):
        # A scalar has zero core dimensions, so no norm is defined.
        with pytest.raises(
            ValueError,
            match="Cannot compute norm when core_dims < 1 or core_dims > 3, found: core_dims = 0",
        ):
            norm(3, ord=2)

    def test_invalid_axis_input(self):
        # A symbolic scalar cannot be coerced to int(axis).
        axis = scalar("i", dtype="int")
        with pytest.raises(
            TypeError, match="'axis' must be None, an integer, or a tuple of integers"
        ):
            norm([[1, 2], [3, 4]], axis=axis)

    @pytest.mark.parametrize(
        "ord",
        [None, np.inf, -np.inf, 1, -1, 2, -2],
        ids=["None", "inf", "-inf", "1", "-1", "2", "-2"],
    )
    @pytest.mark.parametrize("core_dims", [(4,), (4, 3)], ids=["vector", "matrix"])
    @pytest.mark.parametrize("batch_dims", [(), (2,)], ids=["no_batch", "batch"])
    @pytest.mark.parametrize("test_imag", [True, False], ids=["complex", "real"])
    @pytest.mark.parametrize(
        "keepdims", [True, False], ids=["keep_dims=True", "keep_dims=False"]
    )
    def test_numpy_compare(
        self,
        ord: float,
        core_dims: tuple[int, ...],
        batch_dims: tuple[int, ...],
        test_imag: bool,
        keepdims: bool,
        axis=None,
    ):
        is_matrix = len(core_dims) == 2
        has_batch = len(batch_dims) > 0
        if ord in [np.inf, -np.inf] and not is_matrix:
            pytest.skip("Infinity norm not defined for vectors")
        if test_imag and is_matrix and ord == -2:
            pytest.skip("Complex matrices not supported")
        if has_batch and not is_matrix:
            # Handle batched vectors by row-normalizing a matrix
            axis = (-1,)

        rng = np.random.default_rng(utt.fetch_seed())
        if test_imag:
            # Draw real and imaginary parts together, then combine into a
            # complex array matching floatX's precision.
            x_real, x_imag = rng.standard_normal((2, *batch_dims, *core_dims)).astype(
                config.floatX
            )
            dtype = "complex128" if config.floatX.endswith("64") else "complex64"
            X = (x_real + 1j * x_imag).astype(dtype)
        else:
            X = rng.standard_normal(batch_dims + core_dims).astype(config.floatX)

        if batch_dims == ():
            np_norm = np.linalg.norm(X, ord=ord, axis=axis, keepdims=keepdims)
        else:
            # np.linalg.norm does not batch matrix norms, so compute per-item
            # reference values and stack them.
            np_norm = np.stack(
                [np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) for x in X]
            )

        pt_norm = norm(X, ord=ord, axis=axis, keepdims=keepdims)
        f = function([], pt_norm, mode="FAST_COMPILE")

        utt.assert_allclose(np_norm, f())
class TestTensorInv(utt.InferShapeTester): class TestTensorInv(utt.InferShapeTester):
......
Markdown format supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment