提交 31d593d8 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Support all gradient cases for ExtractDiag

Also fixes wrong gradient for negative offsets
上级 c011572c
......@@ -6,7 +6,6 @@ manipulation of tensors.
"""
import builtins
import warnings
from functools import partial
from numbers import Number
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union
......@@ -20,7 +19,7 @@ import pytensor
import pytensor.scalar.sharedvar
from pytensor import compile, config, printing
from pytensor import scalar as aes
from pytensor.gradient import DisconnectedType, grad_not_implemented, grad_undefined
from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
......@@ -3407,15 +3406,18 @@ class ExtractDiag(Op):
self.view = view
if self.view:
self.view_map = {0: [0]}
self.offset = offset
if axis1 < 0 or axis2 < 0:
raise NotImplementedError(
"ExtractDiag does not support negative axis. Use pytensor.tensor.diagonal instead."
)
if axis1 == axis2:
raise ValueError("axis1 and axis2 cannot be the same")
# Sort axis
if axis1 > axis2:
axis1, axis2, offset = axis2, axis1, -offset
self.axis1 = axis1
self.axis2 = axis2
self.offset = offset
def make_node(self, x):
x = as_tensor_variable(x)
......@@ -3436,20 +3438,29 @@ class ExtractDiag(Op):
z[0] = z[0].copy()
def grad(self, inputs, gout):
# Avoid circular import
from pytensor.tensor.subtensor import set_subtensor
(x,) = inputs
(gz,) = gout
if x.ndim == 2:
x = zeros_like(x)
xdiag = AllocDiag(offset=self.offset)(gz)
return [
pytensor.tensor.subtensor.set_subtensor(
x[: xdiag.shape[0], : xdiag.shape[1]], xdiag
)
]
axis1, axis2, offset = self.axis1, self.axis2, self.offset
# Start with zeros (and axes in the front)
x_grad = zeros_like(moveaxis(x, (axis1, axis2), (0, 1)))
# Fill zeros with output diagonal
xdiag = AllocDiag(offset=0, axis1=0, axis2=1)(gz)
z_len = xdiag.shape[0]
if offset >= 0:
diag_slices = (slice(None, z_len), slice(offset, offset + z_len))
else:
warnings.warn("Gradient of ExtractDiag only works for matrices.")
return [grad_not_implemented(self, 0, x)]
diag_slices = (slice(abs(offset), abs(offset) + z_len), slice(None, z_len))
x_grad = set_subtensor(x_grad[diag_slices], xdiag)
# Put axes back in their original positions
x_grad = moveaxis(x_grad, (0, 1), (axis1, axis2))
return [x_grad]
def infer_shape(self, fgraph, node, shapes):
from pytensor.tensor.math import clip, minimum
......@@ -3514,10 +3525,7 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
class AllocDiag(Op):
"""An `Op` that copies a vector to the diagonal of an empty matrix.
It does the inverse of `ExtractDiag`.
"""
"""An `Op` that copies a vector to the diagonal of a zero-ed matrix."""
__props__ = ("offset", "axis1", "axis2")
......
......@@ -3552,16 +3552,10 @@ class TestDiag:
"""
Test that linalg.diag has the same behavior as numpy.diag.
numpy.diag has two behaviors:
(1) when given a vector, it returns a matrix with that vector as the
diagonal.
(2) when given a matrix, returns a vector which is the diagonal of the
matrix.
(1) when given a vector, it returns a matrix with that vector as the diagonal.
(2) when given a matrix, returns a vector which is the diagonal of the matrix.
(1) and (2) are tested by test_alloc_diag and test_extract_diag
respectively.
test_diag test makes sure that linalg.diag instantiates
the right op based on the dimension of the input.
(1) and (2) are further tested by TestAllocDiag and TestExtractDiag, respectively.
"""
def setup_method(self):
......@@ -3571,6 +3565,7 @@ class TestDiag:
self.type = TensorType
def test_diag(self):
"""Makes sure that diag instantiates the right op based on the dimension of the input."""
rng = np.random.default_rng(utt.fetch_seed())
# test vector input
......@@ -3609,38 +3604,55 @@ class TestDiag:
f = function([], g)
assert np.array_equal(f(), np.diag(xx))
def test_infer_shape(self):
class TestExtractDiag:
@pytest.mark.parametrize("axis1, axis2", [(0, 1), (1, 0)])
@pytest.mark.parametrize("offset", (-1, 0, 2))
def test_infer_shape(self, offset, axis1, axis2):
rng = np.random.default_rng(utt.fetch_seed())
x = vector()
g = diag(x)
f = pytensor.function([x], g.shape)
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert sum(isinstance(node.op, AllocDiag) for node in topo) == 0
for shp in [5, 0, 1]:
m = rng.random(shp).astype(self.floatX)
assert (f(m) == np.diag(m).shape).all()
x = matrix()
g = diag(x)
x = matrix("x")
g = ExtractDiag(offset=offset, axis1=axis1, axis2=axis2)(x)
f = pytensor.function([x], g.shape)
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert sum(isinstance(node.op, ExtractDiag) for node in topo) == 0
for shp in [(5, 3), (3, 5), (5, 1), (1, 5), (5, 0), (0, 5), (1, 0), (0, 1)]:
m = rng.random(shp).astype(self.floatX)
assert (f(m) == np.diag(m).shape).all()
m = rng.random(shp).astype(config.floatX)
assert (
f(m) == np.diagonal(m, offset=offset, axis1=axis1, axis2=axis2).shape
).all()
def test_diag_grad(self):
@pytest.mark.parametrize("axis1, axis2", [(0, 1), (1, 0)])
@pytest.mark.parametrize("offset", (0, 1, -1))
def test_grad_2d(self, offset, axis1, axis2):
diag_fn = ExtractDiag(offset=offset, axis1=axis1, axis2=axis2)
rng = np.random.default_rng(utt.fetch_seed())
x = rng.random(5)
utt.verify_grad(diag, [x], rng=rng)
x = rng.random((5, 3))
utt.verify_grad(diag, [x], rng=rng)
utt.verify_grad(diag_fn, [x], rng=rng)
@pytest.mark.parametrize(
"axis1, axis2",
[
(0, 1),
(1, 0),
(1, 2),
(2, 1),
(0, 2),
(2, 0),
],
)
@pytest.mark.parametrize("offset", (0, 1, -1))
def test_grad_3d(self, offset, axis1, axis2):
diag_fn = ExtractDiag(offset=offset, axis1=axis1, axis2=axis2)
rng = np.random.default_rng(utt.fetch_seed())
x = rng.random((5, 4, 3))
utt.verify_grad(diag_fn, [x], rng=rng)
class TestAllocDiag:
# TODO: Separate perform, grad and infer_shape tests
def setup_method(self):
self.alloc_diag = AllocDiag
self.mode = pytensor.compile.mode.get_default_mode()
......@@ -3674,7 +3686,7 @@ class TestAllocDiag:
(-2, 0, 1),
(-1, 1, 2),
]:
# Test AllocDiag values
# Test perform
if np.maximum(axis1, axis2) > len(test_val.shape):
continue
adiag_op = self.alloc_diag(offset=offset, axis1=axis1, axis2=axis2)
......@@ -3688,7 +3700,6 @@ class TestAllocDiag:
# Test infer_shape
f_shape = pytensor.function([x], adiag_op(x).shape, mode="FAST_RUN")
# pytensor.printing.debugprint(f_shape.maker.fgraph.outputs[0])
output_shape = f_shape(test_val)
assert not any(
isinstance(node.op, self.alloc_diag)
......@@ -3699,6 +3710,7 @@ class TestAllocDiag:
).shape
assert np.all(rediag_shape == test_val.shape)
# Test grad
diag_x = adiag_op(x)
sum_diag_x = at_sum(diag_x)
grad_x = pytensor.grad(sum_diag_x, x)
......@@ -3710,7 +3722,6 @@ class TestAllocDiag:
true_grad_input = np.diagonal(
grad_diag_input, offset=offset, axis1=axis1, axis2=axis2
)
assert np.all(true_grad_input == grad_input)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论