提交 a6b9bb91 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Deprecate extract_diag and linalg.trace in favor of numpy look-alikes

上级 d6113953
......@@ -3376,6 +3376,7 @@ def inverse_permutation(perm):
)
# TODO: optimization to insert ExtractDiag with view=True
class ExtractDiag(Op):
"""
Return specified diagonals.
......@@ -3526,8 +3527,12 @@ class ExtractDiag(Op):
self.axis2 = 1
extract_diag = ExtractDiag()
# TODO: optimization to insert ExtractDiag with view=True
def extract_diag(x):
warnings.warn(
"pytensor.tensor.extract_diag is deprecated. Use pytensor.tensor.diagonal instead.",
FutureWarning,
)
return diagonal(x)
def diagonal(a, offset=0, axis1=0, axis2=1):
......@@ -3554,6 +3559,15 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
return ExtractDiag(offset, axis1, axis2)(a)
def trace(a, offset=0, axis1=0, axis2=1):
"""
Returns the sum along diagonals of the array.
Equivalent to `numpy.trace`
"""
return diagonal(a, offset=offset, axis1=axis1, axis2=axis2).sum(-1)
class AllocDiag(Op):
"""An `Op` that copies a vector to the diagonal of a zero-ed matrix."""
......@@ -4254,6 +4268,7 @@ __all__ = [
"full_like",
"empty",
"empty_like",
"trace",
"tril_indices",
"tril_indices_from",
"triu_indices",
......
import warnings
from functools import partial
from typing import Tuple
......@@ -9,7 +10,7 @@ from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import basic as at
from pytensor.tensor import math as tm
from pytensor.tensor.basic import as_tensor_variable, extract_diag
from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
......@@ -175,7 +176,11 @@ def trace(X):
"""
Returns the sum of diagonal elements of matrix X.
"""
return extract_diag(X).sum()
warnings.warn(
"pytensor.tensor.linalg.trace is deprecated. Use pytensor.tensor.trace instead.",
FutureWarning,
)
return diagonal(X).sum()
class Det(Op):
......
......@@ -77,6 +77,7 @@ from pytensor.tensor.basic import (
tensor_copy,
tensor_from_scalar,
tile,
trace,
tri,
tril,
tril_indices,
......@@ -4489,3 +4490,23 @@ def test_oriented_stack_functions(func):
with pytest.raises(ValueError):
func(a, a)
def test_trace():
x_val = np.ones((5, 4, 2))
x = at.as_tensor(x_val)
np.testing.assert_allclose(
trace(x).eval(),
np.trace(x_val),
)
np.testing.assert_allclose(
trace(x, offset=1, axis1=1, axis2=2).eval(),
np.trace(x_val, offset=1, axis1=1, axis2=2),
)
np.testing.assert_allclose(
trace(x, offset=-1, axis1=0, axis2=-1).eval(),
np.trace(x_val, offset=-1, axis1=0, axis2=-1),
)
......@@ -291,6 +291,7 @@ def test_slogdet():
def test_trace():
rng = np.random.default_rng(utt.fetch_seed())
x = matrix()
with pytest.warns(FutureWarning):
g = trace(x)
f = pytensor.function([x], g)
......@@ -302,6 +303,7 @@ def test_trace():
xx = vector()
ok = False
try:
with pytest.warns(FutureWarning):
trace(xx)
except TypeError:
ok = True
......
......@@ -351,6 +351,7 @@ class TestTensorInstanceMethods:
def test_trace(self):
X, _ = self.vars
x, _ = self.vals
with pytest.warns(FutureWarning):
assert_array_equal(X.trace().eval({X: x}), x.trace())
def test_ravel(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论