Commit 31d593d8 authored by Ricardo Vieira, committed by Ricardo Vieira

Support all gradient cases for ExtractDiag

Also fixes wrong gradient for negative offsets
Parent: c011572c
...@@ -6,7 +6,6 @@ manipulation of tensors. ...@@ -6,7 +6,6 @@ manipulation of tensors.
""" """
import builtins import builtins
import warnings
from functools import partial from functools import partial
from numbers import Number from numbers import Number
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union
...@@ -20,7 +19,7 @@ import pytensor ...@@ -20,7 +19,7 @@ import pytensor
import pytensor.scalar.sharedvar import pytensor.scalar.sharedvar
from pytensor import compile, config, printing from pytensor import compile, config, printing
from pytensor import scalar as aes from pytensor import scalar as aes
from pytensor.gradient import DisconnectedType, grad_not_implemented, grad_undefined from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op from pytensor.graph.op import Op
...@@ -3407,15 +3406,18 @@ class ExtractDiag(Op): ...@@ -3407,15 +3406,18 @@ class ExtractDiag(Op):
self.view = view self.view = view
if self.view: if self.view:
self.view_map = {0: [0]} self.view_map = {0: [0]}
self.offset = offset
if axis1 < 0 or axis2 < 0: if axis1 < 0 or axis2 < 0:
raise NotImplementedError( raise NotImplementedError(
"ExtractDiag does not support negative axis. Use pytensor.tensor.diagonal instead." "ExtractDiag does not support negative axis. Use pytensor.tensor.diagonal instead."
) )
if axis1 == axis2: if axis1 == axis2:
raise ValueError("axis1 and axis2 cannot be the same") raise ValueError("axis1 and axis2 cannot be the same")
# Sort axis
if axis1 > axis2:
axis1, axis2, offset = axis2, axis1, -offset
self.axis1 = axis1 self.axis1 = axis1
self.axis2 = axis2 self.axis2 = axis2
self.offset = offset
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
...@@ -3436,20 +3438,29 @@ class ExtractDiag(Op): ...@@ -3436,20 +3438,29 @@ class ExtractDiag(Op):
z[0] = z[0].copy() z[0] = z[0].copy()
def grad(self, inputs, gout):
    """Gradient of diagonal extraction w.r.t. the input tensor.

    Builds a zero tensor shaped like ``x`` and scatters the output
    gradient ``gz`` back onto the positions of the extracted diagonal.
    Handles inputs of any rank and both positive and negative offsets
    (the ``else`` branch shifts along ``axis1`` for negative offsets).
    """
    # Avoid circular import
    from pytensor.tensor.subtensor import set_subtensor

    (x,) = inputs
    (gz,) = gout

    axis1, axis2, offset = self.axis1, self.axis2, self.offset

    # Start with zeros (and axes in the front)
    x_grad = zeros_like(moveaxis(x, (axis1, axis2), (0, 1)))

    # Fill zeros with output diagonal
    xdiag = AllocDiag(offset=0, axis1=0, axis2=1)(gz)
    z_len = xdiag.shape[0]
    if offset >= 0:
        diag_slices = (slice(None, z_len), slice(offset, offset + z_len))
    else:
        diag_slices = (
            slice(abs(offset), abs(offset) + z_len),
            slice(None, z_len),
        )
    x_grad = set_subtensor(x_grad[diag_slices], xdiag)

    # Put axes back in their original positions
    x_grad = moveaxis(x_grad, (0, 1), (axis1, axis2))
    return [x_grad]
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
from pytensor.tensor.math import clip, minimum from pytensor.tensor.math import clip, minimum
...@@ -3514,10 +3525,7 @@ def diagonal(a, offset=0, axis1=0, axis2=1): ...@@ -3514,10 +3525,7 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
class AllocDiag(Op): class AllocDiag(Op):
"""An `Op` that copies a vector to the diagonal of an empty matrix. """An `Op` that copies a vector to the diagonal of a zero-ed matrix."""
It does the inverse of `ExtractDiag`.
"""
__props__ = ("offset", "axis1", "axis2") __props__ = ("offset", "axis1", "axis2")
......
...@@ -3552,16 +3552,10 @@ class TestDiag: ...@@ -3552,16 +3552,10 @@ class TestDiag:
""" """
Test that linalg.diag has the same behavior as numpy.diag. Test that linalg.diag has the same behavior as numpy.diag.
numpy.diag has two behaviors: numpy.diag has two behaviors:
(1) when given a vector, it returns a matrix with that vector as the (1) when given a vector, it returns a matrix with that vector as the diagonal.
diagonal. (2) when given a matrix, returns a vector which is the diagonal of the matrix.
(2) when given a matrix, returns a vector which is the diagonal of the
matrix.
(1) and (2) are tested by test_alloc_diag and test_extract_diag (1) and (2) are further tested by TestAllocDiag and TestExtractDiag, respectively.
respectively.
test_diag test makes sure that linalg.diag instantiates
the right op based on the dimension of the input.
""" """
def setup_method(self): def setup_method(self):
...@@ -3571,6 +3565,7 @@ class TestDiag: ...@@ -3571,6 +3565,7 @@ class TestDiag:
self.type = TensorType self.type = TensorType
def test_diag(self): def test_diag(self):
"""Makes sure that diag instantiates the right op based on the dimension of the input."""
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
# test vector input # test vector input
...@@ -3609,38 +3604,55 @@ class TestDiag: ...@@ -3609,38 +3604,55 @@ class TestDiag:
f = function([], g) f = function([], g)
assert np.array_equal(f(), np.diag(xx)) assert np.array_equal(f(), np.diag(xx))
class TestExtractDiag:
    """Tests for the matrix/tensor -> diagonal-vector direction of `diag`."""

    @pytest.mark.parametrize("axis1, axis2", [(0, 1), (1, 0)])
    @pytest.mark.parametrize("offset", (-1, 0, 2))
    def test_infer_shape(self, offset, axis1, axis2):
        rng = np.random.default_rng(utt.fetch_seed())

        x = matrix("x")
        g = ExtractDiag(offset=offset, axis1=axis1, axis2=axis2)(x)
        f = pytensor.function([x], g.shape)
        topo = f.maker.fgraph.toposort()
        if config.mode != "FAST_COMPILE":
            # Shape graph should not need to evaluate the Op itself
            assert sum(isinstance(node.op, ExtractDiag) for node in topo) == 0
        # Include empty and degenerate shapes to exercise the clipping logic
        for shp in [(5, 3), (3, 5), (5, 1), (1, 5), (5, 0), (0, 5), (1, 0), (0, 1)]:
            m = rng.random(shp).astype(config.floatX)
            assert (
                f(m) == np.diagonal(m, offset=offset, axis1=axis1, axis2=axis2).shape
            ).all()

    @pytest.mark.parametrize("axis1, axis2", [(0, 1), (1, 0)])
    @pytest.mark.parametrize("offset", (0, 1, -1))
    def test_grad_2d(self, offset, axis1, axis2):
        diag_fn = ExtractDiag(offset=offset, axis1=axis1, axis2=axis2)
        rng = np.random.default_rng(utt.fetch_seed())
        x = rng.random((5, 3))
        utt.verify_grad(diag_fn, [x], rng=rng)

    @pytest.mark.parametrize(
        "axis1, axis2",
        [
            (0, 1),
            (1, 0),
            (1, 2),
            (2, 1),
            (0, 2),
            (2, 0),
        ],
    )
    @pytest.mark.parametrize("offset", (0, 1, -1))
    def test_grad_3d(self, offset, axis1, axis2):
        diag_fn = ExtractDiag(offset=offset, axis1=axis1, axis2=axis2)
        rng = np.random.default_rng(utt.fetch_seed())
        x = rng.random((5, 4, 3))
        utt.verify_grad(diag_fn, [x], rng=rng)
class TestAllocDiag: class TestAllocDiag:
# TODO: Separate perform, grad and infer_shape tests
def setup_method(self): def setup_method(self):
self.alloc_diag = AllocDiag self.alloc_diag = AllocDiag
self.mode = pytensor.compile.mode.get_default_mode() self.mode = pytensor.compile.mode.get_default_mode()
...@@ -3674,7 +3686,7 @@ class TestAllocDiag: ...@@ -3674,7 +3686,7 @@ class TestAllocDiag:
(-2, 0, 1), (-2, 0, 1),
(-1, 1, 2), (-1, 1, 2),
]: ]:
# Test AllocDiag values # Test perform
if np.maximum(axis1, axis2) > len(test_val.shape): if np.maximum(axis1, axis2) > len(test_val.shape):
continue continue
adiag_op = self.alloc_diag(offset=offset, axis1=axis1, axis2=axis2) adiag_op = self.alloc_diag(offset=offset, axis1=axis1, axis2=axis2)
...@@ -3688,7 +3700,6 @@ class TestAllocDiag: ...@@ -3688,7 +3700,6 @@ class TestAllocDiag:
# Test infer_shape # Test infer_shape
f_shape = pytensor.function([x], adiag_op(x).shape, mode="FAST_RUN") f_shape = pytensor.function([x], adiag_op(x).shape, mode="FAST_RUN")
# pytensor.printing.debugprint(f_shape.maker.fgraph.outputs[0])
output_shape = f_shape(test_val) output_shape = f_shape(test_val)
assert not any( assert not any(
isinstance(node.op, self.alloc_diag) isinstance(node.op, self.alloc_diag)
...@@ -3699,6 +3710,7 @@ class TestAllocDiag: ...@@ -3699,6 +3710,7 @@ class TestAllocDiag:
).shape ).shape
assert np.all(rediag_shape == test_val.shape) assert np.all(rediag_shape == test_val.shape)
# Test grad
diag_x = adiag_op(x) diag_x = adiag_op(x)
sum_diag_x = at_sum(diag_x) sum_diag_x = at_sum(diag_x)
grad_x = pytensor.grad(sum_diag_x, x) grad_x = pytensor.grad(sum_diag_x, x)
...@@ -3710,7 +3722,6 @@ class TestAllocDiag: ...@@ -3710,7 +3722,6 @@ class TestAllocDiag:
true_grad_input = np.diagonal( true_grad_input = np.diagonal(
grad_diag_input, offset=offset, axis1=axis1, axis2=axis2 grad_diag_input, offset=offset, axis1=axis1, axis2=axis2
) )
assert np.all(true_grad_input == grad_input) assert np.all(true_grad_input == grad_input)
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Sign in or register to post a comment