Unverified 提交 5e612abb authored 作者: Dhruvanshu-Joshi's avatar Dhruvanshu-Joshi 提交者: GitHub

Add `matrix_transpose` and `.mT` property helpers (#702)

上级 0a13fbd6
...@@ -1982,6 +1982,62 @@ def transpose(x, axes=None): ...@@ -1982,6 +1982,62 @@ def transpose(x, axes=None):
return ret return ret
def matrix_transpose(x: "TensorLike") -> TensorVariable:
    """
    Transpose the last two dimensions of a tensor.

    Every 2-dimensional matrix stored along the last two axes of `x` is
    transposed; all leading (batch) dimensions keep their position.

    Parameters
    ----------
    x : array_like
        Input tensor with shape (..., M, N), where `M` and `N` represent the dimensions
        of the matrices. Each matrix is of shape (M, N).

    Returns
    -------
    out : tensor
        Transposed tensor with the shape (..., N, M), where each 2-dimensional matrix
        in the input tensor has been transposed along the last two dimensions.

    Raises
    ------
    ValueError
        If `x` has fewer than two dimensions.

    Examples
    --------
    >>> import numpy as np
    >>> import pytensor.tensor as pt
    >>> x = np.arange(24).reshape((2, 3, 4))
    >>> print(x)
    [[[ 0  1  2  3]
      [ 4  5  6  7]
      [ 8  9 10 11]]
    <BLANKLINE>
     [[12 13 14 15]
      [16 17 18 19]
      [20 21 22 23]]]
    >>> print(pt.matrix_transpose(x).eval())
    [[[ 0  4  8]
      [ 1  5  9]
      [ 2  6 10]
      [ 3  7 11]]
    <BLANKLINE>
     [[12 16 20]
      [13 17 21]
      [14 18 22]
      [15 19 23]]]

    Notes
    -----
    This function transposes each 2-dimensional matrix within the input tensor along
    the last two dimensions. If the input tensor has more than two dimensions, it
    transposes each 2-dimensional matrix independently while preserving other dimensions.
    """
    # Convert first so that plain array-likes expose a symbolic ``ndim``.
    x = as_tensor_variable(x)
    if x.ndim < 2:
        raise ValueError(
            f"Input array must be at least 2-dimensional, but it is {x.ndim}"
        )
    return swapaxes(x, -1, -2)
def split(x, splits_size, n_splits, axis=0): def split(x, splits_size, n_splits, axis=0):
the_split = Split(n_splits) the_split = Split(n_splits)
return the_split(x, axis, splits_size) return the_split(x, axis, splits_size)
...@@ -4302,6 +4358,7 @@ __all__ = [ ...@@ -4302,6 +4358,7 @@ __all__ = [
"join", "join",
"split", "split",
"transpose", "transpose",
"matrix_transpose",
"extract_constant", "extract_constant",
"default", "default",
"tensor_copy", "tensor_copy",
......
...@@ -2,7 +2,7 @@ import logging ...@@ -2,7 +2,7 @@ import logging
from typing import cast from typing import cast
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes from pytensor.tensor.basic import TensorVariable, diagonal
from pytensor.tensor.blas import Dot22 from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
...@@ -43,11 +43,6 @@ def is_matrix_transpose(x: TensorVariable) -> bool: ...@@ -43,11 +43,6 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
return False return False
def _T(x: TensorVariable) -> TensorVariable:
    """Return `x` with its last two axes exchanged.

    This is a matrix transpose that also accepts tensors with more than two
    dimensions: only the trailing pair of axes is swapped.
    """
    transposed = swapaxes(x, -1, -2)
    return transposed
@register_canonicalize @register_canonicalize
@node_rewriter([DimShuffle]) @node_rewriter([DimShuffle])
def transinv_to_invtrans(fgraph, node): def transinv_to_invtrans(fgraph, node):
...@@ -83,9 +78,9 @@ def inv_as_solve(fgraph, node): ...@@ -83,9 +78,9 @@ def inv_as_solve(fgraph, node):
): ):
x = r.owner.inputs[0] x = r.owner.inputs[0]
if getattr(x.tag, "symmetric", None) is True: if getattr(x.tag, "symmetric", None) is True:
return [_T(solve(x, _T(l)))] return [solve(x, (l.mT)).mT]
else: else:
return [_T(solve(_T(x), _T(l)))] return [solve((x.mT), (l.mT)).mT]
@register_stabilize @register_stabilize
...@@ -216,7 +211,7 @@ def psd_solve_with_chol(fgraph, node): ...@@ -216,7 +211,7 @@ def psd_solve_with_chol(fgraph, node):
# __if__ no other Op makes use of the L matrix during the # __if__ no other Op makes use of the L matrix during the
# stabilization # stabilization
Li_b = solve_triangular(L, b, lower=True, b_ndim=2) Li_b = solve_triangular(L, b, lower=True, b_ndim=2)
x = solve_triangular(_T(L), Li_b, lower=False, b_ndim=2) x = solve_triangular((L.mT), Li_b, lower=False, b_ndim=2)
return [x] return [x]
......
...@@ -232,6 +232,10 @@ class _tensor_py_operators: ...@@ -232,6 +232,10 @@ class _tensor_py_operators:
def T(self): def T(self):
return pt.basic.transpose(self) return pt.basic.transpose(self)
@property
def mT(self):
    """Matrix transpose of this tensor, computed by `pt.basic.matrix_transpose`."""
    transposed = pt.basic.matrix_transpose(self)
    return transposed
def transpose(self, *axes): def transpose(self, *axes):
"""Transpose this array. """Transpose this array.
......
...@@ -3813,6 +3813,7 @@ def test_transpose(): ...@@ -3813,6 +3813,7 @@ def test_transpose():
) )
t1, t2, t3, t1b, t2b, t3b, t2c, t3c, t2d, t3d = f(x1v, x2v, x3v) t1, t2, t3, t1b, t2b, t3b, t2c, t3c, t2d, t3d = f(x1v, x2v, x3v)
assert t1.shape == np.transpose(x1v).shape assert t1.shape == np.transpose(x1v).shape
assert t2.shape == np.transpose(x2v).shape assert t2.shape == np.transpose(x2v).shape
assert t3.shape == np.transpose(x3v).shape assert t3.shape == np.transpose(x3v).shape
...@@ -3838,6 +3839,23 @@ def test_transpose(): ...@@ -3838,6 +3839,23 @@ def test_transpose():
assert ptb.transpose(dmatrix()).name is None assert ptb.transpose(dmatrix()).name is None
def test_matrix_transpose():
    """Check `matrix_transpose` and the `.mT` property.

    A vector input must be rejected with a ``ValueError``; for matrix and
    3-tensor inputs the resulting graph must be equivalent to swapping the
    last two axes with `swapaxes`.
    """
    vector_in = dvector("x1")
    with pytest.raises(ValueError, match="Input array must be at least 2-dimensional"):
        ptb.matrix_transpose(vector_in)

    matrix_in = dmatrix("x2")
    tensor3_in = dtensor3("x3")

    # (actual graph, reference graph) pairs: the function form on a matrix,
    # then the property form on a 3-dimensional tensor.
    cases = (
        (ptb.matrix_transpose(matrix_in), swapaxes(matrix_in, -1, -2)),
        (tensor3_in.mT, swapaxes(tensor3_in, -1, -2)),
    )
    for actual, reference in cases:
        assert equal_computations([actual], [reference])
def test_stacklists(): def test_stacklists():
a, b, c, d = map(scalar, "abcd") a, b, c, d = map(scalar, "abcd")
X = stacklists([[a, b], [c, d]]) X = stacklists([[a, b], [c, d]])
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论