提交 a6b9bb91 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Deprecate extract_diag and linalg.trace in favor of numpy look-alikes

上级 d6113953
...@@ -3376,6 +3376,7 @@ def inverse_permutation(perm): ...@@ -3376,6 +3376,7 @@ def inverse_permutation(perm):
) )
# TODO: optimization to insert ExtractDiag with view=True
class ExtractDiag(Op): class ExtractDiag(Op):
""" """
Return specified diagonals. Return specified diagonals.
...@@ -3526,8 +3527,12 @@ class ExtractDiag(Op): ...@@ -3526,8 +3527,12 @@ class ExtractDiag(Op):
self.axis2 = 1 self.axis2 = 1
extract_diag = ExtractDiag() def extract_diag(x):
# TODO: optimization to insert ExtractDiag with view=True warnings.warn(
"pytensor.tensor.extract_diag is deprecated. Use pytensor.tensor.diagonal instead.",
FutureWarning,
)
return diagonal(x)
def diagonal(a, offset=0, axis1=0, axis2=1): def diagonal(a, offset=0, axis1=0, axis2=1):
...@@ -3554,6 +3559,15 @@ def diagonal(a, offset=0, axis1=0, axis2=1): ...@@ -3554,6 +3559,15 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
return ExtractDiag(offset, axis1, axis2)(a) return ExtractDiag(offset, axis1, axis2)(a)
def trace(a, offset=0, axis1=0, axis2=1):
"""
Returns the sum along diagonals of the array.
Equivalent to `numpy.trace`
"""
return diagonal(a, offset=offset, axis1=axis1, axis2=axis2).sum(-1)
class AllocDiag(Op): class AllocDiag(Op):
"""An `Op` that copies a vector to the diagonal of a zero-ed matrix.""" """An `Op` that copies a vector to the diagonal of a zero-ed matrix."""
...@@ -4254,6 +4268,7 @@ __all__ = [ ...@@ -4254,6 +4268,7 @@ __all__ = [
"full_like", "full_like",
"empty", "empty",
"empty_like", "empty_like",
"trace",
"tril_indices", "tril_indices",
"tril_indices_from", "tril_indices_from",
"triu_indices", "triu_indices",
......
import warnings
from functools import partial from functools import partial
from typing import Tuple from typing import Tuple
...@@ -9,7 +10,7 @@ from pytensor.graph.basic import Apply ...@@ -9,7 +10,7 @@ from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.tensor import basic as at from pytensor.tensor import basic as at
from pytensor.tensor import math as tm from pytensor.tensor import math as tm
from pytensor.tensor.basic import as_tensor_variable, extract_diag from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
...@@ -175,7 +176,11 @@ def trace(X): ...@@ -175,7 +176,11 @@ def trace(X):
""" """
Returns the sum of diagonal elements of matrix X. Returns the sum of diagonal elements of matrix X.
""" """
return extract_diag(X).sum() warnings.warn(
"pytensor.tensor.linalg.trace is deprecated. Use pytensor.tensor.trace instead.",
FutureWarning,
)
return diagonal(X).sum()
class Det(Op): class Det(Op):
......
...@@ -77,6 +77,7 @@ from pytensor.tensor.basic import ( ...@@ -77,6 +77,7 @@ from pytensor.tensor.basic import (
tensor_copy, tensor_copy,
tensor_from_scalar, tensor_from_scalar,
tile, tile,
trace,
tri, tri,
tril, tril,
tril_indices, tril_indices,
...@@ -4489,3 +4490,23 @@ def test_oriented_stack_functions(func): ...@@ -4489,3 +4490,23 @@ def test_oriented_stack_functions(func):
with pytest.raises(ValueError): with pytest.raises(ValueError):
func(a, a) func(a, a)
def test_trace():
x_val = np.ones((5, 4, 2))
x = at.as_tensor(x_val)
np.testing.assert_allclose(
trace(x).eval(),
np.trace(x_val),
)
np.testing.assert_allclose(
trace(x, offset=1, axis1=1, axis2=2).eval(),
np.trace(x_val, offset=1, axis1=1, axis2=2),
)
np.testing.assert_allclose(
trace(x, offset=-1, axis1=0, axis2=-1).eval(),
np.trace(x_val, offset=-1, axis1=0, axis2=-1),
)
...@@ -291,7 +291,8 @@ def test_slogdet(): ...@@ -291,7 +291,8 @@ def test_slogdet():
def test_trace(): def test_trace():
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
x = matrix() x = matrix()
g = trace(x) with pytest.warns(FutureWarning):
g = trace(x)
f = pytensor.function([x], g) f = pytensor.function([x], g)
for shp in [(2, 3), (3, 2), (3, 3)]: for shp in [(2, 3), (3, 2), (3, 3)]:
...@@ -302,7 +303,8 @@ def test_trace(): ...@@ -302,7 +303,8 @@ def test_trace():
xx = vector() xx = vector()
ok = False ok = False
try: try:
trace(xx) with pytest.warns(FutureWarning):
trace(xx)
except TypeError: except TypeError:
ok = True ok = True
except ValueError: except ValueError:
......
...@@ -351,7 +351,8 @@ class TestTensorInstanceMethods: ...@@ -351,7 +351,8 @@ class TestTensorInstanceMethods:
def test_trace(self): def test_trace(self):
X, _ = self.vars X, _ = self.vars
x, _ = self.vals x, _ = self.vals
assert_array_equal(X.trace().eval({X: x}), x.trace()) with pytest.warns(FutureWarning):
assert_array_equal(X.trace().eval({X: x}), x.trace())
def test_ravel(self): def test_ravel(self):
X, _ = self.vars X, _ = self.vars
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论