提交 22c8c46d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove deprecated aesara.tensor.nlinalg.AllocDiag

上级 16f98e4f
...@@ -16,8 +16,10 @@ from aesara.scan.op import Scan ...@@ -16,8 +16,10 @@ from aesara.scan.op import Scan
from aesara.scan.utils import scan_args as ScanArgs from aesara.scan.utils import scan_args as ScanArgs
from aesara.tensor.basic import ( from aesara.tensor.basic import (
Alloc, Alloc,
AllocDiag,
AllocEmpty, AllocEmpty,
ARange, ARange,
ExtractDiag,
Eye, Eye,
Join, Join,
MakeVector, MakeVector,
...@@ -41,11 +43,9 @@ from aesara.tensor.extra_ops import ( ...@@ -41,11 +43,9 @@ from aesara.tensor.extra_ops import (
from aesara.tensor.math import Dot, MaxAndArgmax from aesara.tensor.math import Dot, MaxAndArgmax
from aesara.tensor.nlinalg import ( from aesara.tensor.nlinalg import (
SVD, SVD,
AllocDiag,
Det, Det,
Eig, Eig,
Eigh, Eigh,
ExtractDiag,
MatrixInverse, MatrixInverse,
QRFull, QRFull,
QRIncomplete, QRIncomplete,
...@@ -267,6 +267,16 @@ def jax_funcify_Second(op): ...@@ -267,6 +267,16 @@ def jax_funcify_Second(op):
return second return second
@jax_funcify.register(AllocDiag)
def jax_funcify_AllocDiag(op):
offset = op.offset
def allocdiag(v, offset=offset):
return jnp.diag(v, k=offset)
return allocdiag
@jax_funcify.register(AllocEmpty) @jax_funcify.register(AllocEmpty)
def jax_funcify_AllocEmpty(op): def jax_funcify_AllocEmpty(op):
def allocempty(*shape): def allocempty(*shape):
...@@ -835,14 +845,6 @@ def jax_funcify_Cholesky(op): ...@@ -835,14 +845,6 @@ def jax_funcify_Cholesky(op):
return cholesky return cholesky
@jax_funcify.register(AllocDiag)
def jax_funcify_AllocDiag(op):
def alloc_diag(x):
return jnp.diag(x)
return alloc_diag
@jax_funcify.register(Solve) @jax_funcify.register(Solve)
def jax_funcify_Solve(op): def jax_funcify_Solve(op):
......
from aesara.sandbox.linalg.ops import psd, spectral_radius_bound from aesara.sandbox.linalg.ops import psd, spectral_radius_bound
from aesara.tensor.nlinalg import ( from aesara.tensor.nlinalg import det, eig, eigh, matrix_inverse, trace
alloc_diag,
det,
diag,
eig,
eigh,
extract_diag,
matrix_inverse,
trace,
)
from aesara.tensor.slinalg import cholesky, eigvalsh, solve from aesara.tensor.slinalg import cholesky, eigvalsh, solve
...@@ -16,13 +16,7 @@ from aesara.tensor.exceptions import NotScalarConstantError ...@@ -16,13 +16,7 @@ from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, Prod, dot, log from aesara.tensor.math import Dot, Prod, dot, log
from aesara.tensor.math import pow as aet_pow from aesara.tensor.math import pow as aet_pow
from aesara.tensor.math import prod from aesara.tensor.math import prod
from aesara.tensor.nlinalg import ( from aesara.tensor.nlinalg import MatrixInverse, det, matrix_inverse, trace
MatrixInverse,
det,
extract_diag,
matrix_inverse,
trace,
)
from aesara.tensor.slinalg import Cholesky, Solve, cholesky, imported_scipy, solve from aesara.tensor.slinalg import Cholesky, Solve, cholesky, imported_scipy, solve
...@@ -320,7 +314,7 @@ def local_det_chol(fgraph, node): ...@@ -320,7 +314,7 @@ def local_det_chol(fgraph, node):
for (cl, xpos) in fgraph.clients[x]: for (cl, xpos) in fgraph.clients[x]:
if isinstance(cl.op, Cholesky): if isinstance(cl.op, Cholesky):
L = cl.outputs[0] L = cl.outputs[0]
return [prod(extract_diag(L) ** 2)] return [prod(aet.extract_diag(L) ** 2)]
@register_canonicalize @register_canonicalize
......
...@@ -3811,6 +3811,10 @@ class ExtractDiag(Op): ...@@ -3811,6 +3811,10 @@ class ExtractDiag(Op):
self.axis2 = 1 self.axis2 = 1
extract_diag = ExtractDiag()
# TODO: optimization to insert ExtractDiag with view=True
def diagonal(a, offset=0, axis1=0, axis2=1): def diagonal(a, offset=0, axis1=0, axis2=1):
""" """
A helper function for `ExtractDiag`. It accepts tensor with A helper function for `ExtractDiag`. It accepts tensor with
...@@ -4298,4 +4302,5 @@ __all__ = [ ...@@ -4298,4 +4302,5 @@ __all__ = [
"constant", "constant",
"as_tensor_variable", "as_tensor_variable",
"as_tensor", "as_tensor",
"extract_diag",
] ]
...@@ -18,7 +18,6 @@ from aesara.misc.safe_asarray import _asarray ...@@ -18,7 +18,6 @@ from aesara.misc.safe_asarray import _asarray
from aesara.scalar import int32 as int_t from aesara.scalar import int32 as int_t
from aesara.scalar import upcast from aesara.scalar import upcast
from aesara.tensor import basic as aet from aesara.tensor import basic as aet
from aesara.tensor import nlinalg
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import abs_ from aesara.tensor.math import abs_
from aesara.tensor.math import all as aet_all from aesara.tensor.math import all as aet_all
...@@ -961,7 +960,7 @@ class FillDiagonal(Op): ...@@ -961,7 +960,7 @@ class FillDiagonal(Op):
) )
wr_a = fill_diagonal(grad, 0) # valid for any number of dimensions wr_a = fill_diagonal(grad, 0) # valid for any number of dimensions
# diag is only valid for matrices # diag is only valid for matrices
wr_val = nlinalg.diag(grad).sum() wr_val = aet.diag(grad).sum()
return [wr_a, wr_val] return [wr_a, wr_val]
......
import logging import logging
import warnings
from functools import partial from functools import partial
import numpy as np import numpy as np
...@@ -10,7 +9,7 @@ from aesara.graph.basic import Apply ...@@ -10,7 +9,7 @@ from aesara.graph.basic import Apply
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.tensor import basic as aet from aesara.tensor import basic as aet
from aesara.tensor import math as tm from aesara.tensor import math as tm
from aesara.tensor.basic import ExtractDiag, as_tensor_variable from aesara.tensor.basic import as_tensor_variable, extract_diag
from aesara.tensor.type import dvector, lscalar, matrix, scalar, vector from aesara.tensor.type import dvector, lscalar, matrix, scalar, vector
...@@ -168,62 +167,6 @@ def matrix_dot(*args): ...@@ -168,62 +167,6 @@ def matrix_dot(*args):
return rval return rval
class AllocDiag(Op):
"""
Allocates a square matrix with the given vector as its diagonal.
"""
__props__ = ()
def make_node(self, _x):
warnings.warn(
"DeprecationWarning: aesara.tensor.nlinalg.AllocDiag"
"is deprecated, please use aesara.tensor.basic.AllocDiag"
"instead.",
category=DeprecationWarning,
)
x = as_tensor_variable(_x)
if x.type.ndim != 1:
raise TypeError("AllocDiag only works on vectors", _x)
return Apply(self, [x], [matrix(dtype=x.type.dtype)])
def grad(self, inputs, g_outputs):
return [extract_diag(g_outputs[0])]
def perform(self, node, inputs, outputs):
(x,) = inputs
(z,) = outputs
if x.ndim != 1:
raise TypeError(x)
z[0] = np.diag(x)
def infer_shape(self, fgraph, node, shapes):
(x_s,) = shapes
return [(x_s[0], x_s[0])]
alloc_diag = AllocDiag()
extract_diag = ExtractDiag()
# TODO: optimization to insert ExtractDiag with view=True
def diag(x):
"""
Numpy-compatibility method
If `x` is a matrix, return its diagonal.
If `x` is a vector return a matrix with it as its diagonal.
* This method does not support the `k` argument that numpy supports.
"""
xx = as_tensor_variable(x)
if xx.type.ndim == 1:
return alloc_diag(xx)
elif xx.type.ndim == 2:
return extract_diag(xx)
else:
raise TypeError("diag requires vector or matrix argument", x)
def trace(X): def trace(X):
""" """
Returns the sum of diagonal elements of matrix X. Returns the sum of diagonal elements of matrix X.
......
...@@ -265,7 +265,7 @@ def test_jax_basic(): ...@@ -265,7 +265,7 @@ def test_jax_basic():
], ],
) )
out = aet_nlinalg.alloc_diag(b) out = aet.diag(b)
out_fg = FunctionGraph([b], [out]) out_fg = FunctionGraph([b], [out])
compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)]) compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)])
......
...@@ -3225,19 +3225,21 @@ class TestSize: ...@@ -3225,19 +3225,21 @@ class TestSize:
class TestDiag: class TestDiag:
# Test that aet.diag has the same behavior as np.diag. """
# np.diag has two behaviors: Test that linalg.diag has the same behavior as numpy.diag.
# numpy.diag has two behaviors:
# (1) when given a vector, it returns a matrix with that vector as the (1) when given a vector, it returns a matrix with that vector as the
# diagonal. diagonal.
# (2) when given a matrix, returns a vector which is the diagonal of the (2) when given a matrix, returns a vector which is the diagonal of the
# matrix. matrix.
#
# (1) and (2) are tested by test_alloc_diag and test_extract_diag (1) and (2) are tested by test_alloc_diag and test_extract_diag
# respectively. respectively.
#
# test_diag test makes sure that linalg.diag instantiates test_diag test makes sure that linalg.diag instantiates
# the right op based on the dimension of the input. the right op based on the dimension of the input.
"""
def setup_method(self): def setup_method(self):
self.mode = None self.mode = None
self.shared = shared self.shared = shared
......
...@@ -10,17 +10,12 @@ from aesara.configdefaults import config ...@@ -10,17 +10,12 @@ from aesara.configdefaults import config
from aesara.tensor.math import _allclose from aesara.tensor.math import _allclose
from aesara.tensor.nlinalg import ( from aesara.tensor.nlinalg import (
SVD, SVD,
AllocDiag,
Eig, Eig,
ExtractDiag,
MatrixInverse, MatrixInverse,
TensorInv, TensorInv,
alloc_diag,
det, det,
diag,
eig, eig,
eigh, eigh,
extract_diag,
matrix_dot, matrix_dot,
matrix_inverse, matrix_inverse,
matrix_power, matrix_power,
...@@ -33,7 +28,6 @@ from aesara.tensor.nlinalg import ( ...@@ -33,7 +28,6 @@ from aesara.tensor.nlinalg import (
trace, trace,
) )
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType,
lmatrix, lmatrix,
lscalar, lscalar,
matrix, matrix,
...@@ -287,136 +281,6 @@ def test_det_shape(): ...@@ -287,136 +281,6 @@ def test_det_shape():
assert np.all(f(r).shape == f_shape(r)) assert np.all(f(r).shape == f_shape(r))
class TestDiag:
"""
Test that linalg.diag has the same behavior as numpy.diag.
numpy.diag has two behaviors:
(1) when given a vector, it returns a matrix with that vector as the
diagonal.
(2) when given a matrix, returns a vector which is the diagonal of the
matrix.
(1) and (2) are tested by test_alloc_diag and test_extract_diag
respectively.
test_diag test makes sure that linalg.diag instantiates
the right op based on the dimension of the input.
"""
def setup_method(self):
self.mode = None
self.shared = aesara.shared
self.floatX = config.floatX
self.type = TensorType
def test_alloc_diag(self):
rng = np.random.RandomState(utt.fetch_seed())
x = vector()
g = alloc_diag(x)
f = aesara.function([x], g)
# test "normal" scenario (5x5 matrix) and special cases of 0x0 and 1x1
for shp in [5, 0, 1]:
m = rng.rand(shp).astype(self.floatX)
v = np.diag(m)
r = f(m)
# The right matrix is created
assert (r == v).all()
# Test we accept only vectors
xx = matrix()
ok = False
try:
alloc_diag(xx)
except TypeError:
ok = True
assert ok
# Test infer_shape
f = aesara.function([x], g.shape)
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert sum([node.op.__class__ == AllocDiag for node in topo]) == 0
for shp in [5, 0, 1]:
m = rng.rand(shp).astype(self.floatX)
assert (f(m) == m.shape).all()
def test_alloc_diag_grad(self):
rng = np.random.RandomState(utt.fetch_seed())
x = rng.rand(5)
utt.verify_grad(alloc_diag, [x], rng=rng)
def test_diag(self):
# test that it builds a matrix with given diagonal when using
# vector inputs
x = vector()
y = diag(x)
assert y.owner.op.__class__ == AllocDiag
# test that it extracts the diagonal when using matrix input
x = matrix()
y = extract_diag(x)
assert y.owner.op.__class__ == ExtractDiag
# not testing the view=True case since it is not used anywhere.
def test_extract_diag(self):
rng = np.random.RandomState(utt.fetch_seed())
m = rng.rand(2, 3).astype(self.floatX)
x = self.shared(m)
g = extract_diag(x)
f = aesara.function([], g)
assert [
isinstance(node.inputs[0].type, self.type)
for node in f.maker.fgraph.toposort()
if isinstance(node.op, ExtractDiag)
] == [True]
for shp in [(2, 3), (3, 2), (3, 3), (1, 1), (0, 0)]:
m = rng.rand(*shp).astype(self.floatX)
x.set_value(m)
v = np.diag(m)
r = f()
# The right diagonal is extracted
assert (r == v).all()
# Test we accept only matrix
xx = vector()
ok = False
try:
extract_diag(xx)
except TypeError:
ok = True
except ValueError:
ok = True
assert ok
# Test infer_shape
f = aesara.function([], g.shape)
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert sum([node.op.__class__ == ExtractDiag for node in topo]) == 0
for shp in [(2, 3), (3, 2), (3, 3)]:
m = rng.rand(*shp).astype(self.floatX)
x.set_value(m)
assert f() == min(shp)
def test_extract_diag_grad(self):
rng = np.random.RandomState(utt.fetch_seed())
x = rng.rand(5, 4).astype(self.floatX)
utt.verify_grad(extract_diag, [x], rng=rng)
@pytest.mark.slow
def test_extract_diag_empty(self):
c = self.shared(np.array([[], []], self.floatX))
f = aesara.function([], extract_diag(c), mode=self.mode)
assert [
isinstance(node.inputs[0].type, self.type)
for node in f.maker.fgraph.toposort()
if isinstance(node.op, ExtractDiag)
] == [True]
def test_trace(): def test_trace():
rng = np.random.RandomState(utt.fetch_seed()) rng = np.random.RandomState(utt.fetch_seed())
x = matrix() x = matrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论