Unverified 提交 5e612abb authored 作者: Dhruvanshu-Joshi's avatar Dhruvanshu-Joshi 提交者: GitHub

Add `matrix_transpose` and `.mT` property helpers (#702)

上级 0a13fbd6
......@@ -1982,6 +1982,62 @@ def transpose(x, axes=None):
return ret
def matrix_transpose(x: "TensorLike") -> TensorVariable:
"""
Transposes each 2-dimensional matrix tensor along the last two dimensions of a higher-dimensional tensor.
Parameters
----------
x : array_like
Input tensor with shape (..., M, N), where `M` and `N` represent the dimensions
of the matrices. Each matrix is of shape (M, N).
Returns
-------
out : tensor
Transposed tensor with the shape (..., N, M), where each 2-dimensional matrix
in the input tensor has been transposed along the last two dimensions.
Examples
--------
>>> import pytensor as pt
>>> import numpy as np
>>> x = np.arange(24).reshape((2, 3, 4))
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]
>>> pt.matrix_transpose(x).eval()
[[[ 0 4 8]
[ 1 5 9]
[ 2 6 10]
[ 3 7 11]]
[[12 16 20]
[13 17 21]
[14 18 22]
[15 19 23]]]
Notes
-----
This function transposes each 2-dimensional matrix within the input tensor along
the last two dimensions. If the input tensor has more than two dimensions, it
transposes each 2-dimensional matrix independently while preserving other dimensions.
"""
x = as_tensor_variable(x)
if x.ndim < 2:
raise ValueError(
f"Input array must be at least 2-dimensional, but it is {x.ndim}"
)
return swapaxes(x, -1, -2)
def split(x, splits_size, n_splits, axis=0):
the_split = Split(n_splits)
return the_split(x, axis, splits_size)
......@@ -4302,6 +4358,7 @@ __all__ = [
"join",
"split",
"transpose",
"matrix_transpose",
"extract_constant",
"default",
"tensor_copy",
......
......@@ -2,7 +2,7 @@ import logging
from typing import cast
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
from pytensor.tensor.basic import TensorVariable, diagonal
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
......@@ -43,11 +43,6 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
return False
def _T(x: TensorVariable) -> TensorVariable:
"""Matrix transpose for potentially higher dimensionality tensors"""
return swapaxes(x, -1, -2)
@register_canonicalize
@node_rewriter([DimShuffle])
def transinv_to_invtrans(fgraph, node):
......@@ -83,9 +78,9 @@ def inv_as_solve(fgraph, node):
):
x = r.owner.inputs[0]
if getattr(x.tag, "symmetric", None) is True:
return [_T(solve(x, _T(l)))]
return [solve(x, (l.mT)).mT]
else:
return [_T(solve(_T(x), _T(l)))]
return [solve((x.mT), (l.mT)).mT]
@register_stabilize
......@@ -216,7 +211,7 @@ def psd_solve_with_chol(fgraph, node):
# __if__ no other Op makes use of the L matrix during the
# stabilization
Li_b = solve_triangular(L, b, lower=True, b_ndim=2)
x = solve_triangular(_T(L), Li_b, lower=False, b_ndim=2)
x = solve_triangular((L.mT), Li_b, lower=False, b_ndim=2)
return [x]
......
......@@ -232,6 +232,10 @@ class _tensor_py_operators:
def T(self):
return pt.basic.transpose(self)
@property
def mT(self):
return pt.basic.matrix_transpose(self)
def transpose(self, *axes):
"""Transpose this array.
......
......@@ -3813,6 +3813,7 @@ def test_transpose():
)
t1, t2, t3, t1b, t2b, t3b, t2c, t3c, t2d, t3d = f(x1v, x2v, x3v)
assert t1.shape == np.transpose(x1v).shape
assert t2.shape == np.transpose(x2v).shape
assert t3.shape == np.transpose(x3v).shape
......@@ -3838,6 +3839,23 @@ def test_transpose():
assert ptb.transpose(dmatrix()).name is None
def test_matrix_transpose():
with pytest.raises(ValueError, match="Input array must be at least 2-dimensional"):
ptb.matrix_transpose(dvector("x1"))
x2 = dmatrix("x2")
x3 = dtensor3("x3")
var1 = ptb.matrix_transpose(x2)
expected_var1 = swapaxes(x2, -1, -2)
var2 = x3.mT
expected_var2 = swapaxes(x3, -1, -2)
assert equal_computations([var1], [expected_var1])
assert equal_computations([var2], [expected_var2])
def test_stacklists():
a, b, c, d = map(scalar, "abcd")
X = stacklists([[a, b], [c, d]])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论