Commit 71c58f39, authored by Ricardo Vieira, committed by Ricardo Vieira

Deprecate AllocDiag Op in favor of equivalent PyTensor graph

Parent deea8dd3
......@@ -8,7 +8,6 @@ from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import (
Alloc,
AllocDiag,
AllocEmpty,
ARange,
ExtractDiag,
......@@ -32,16 +31,6 @@ by JAX. An example of a graph that can be compiled to JAX:
"""
@jax_funcify.register(AllocDiag)
def jax_funcify_AllocDiag(op, **kwargs):
    """Build the JAX implementation of an ``AllocDiag`` op.

    The op's ``offset`` attribute is captured once, then used as the
    ``k`` argument of ``jnp.diag`` inside the returned thunk.
    """
    diag_offset = op.offset

    def allocdiag(v, offset=diag_offset):
        # Embed ``v`` on the ``offset``-th diagonal of a zeroed square matrix.
        return jnp.diag(v, k=offset)

    return allocdiag
@jax_funcify.register(AllocEmpty)
def jax_funcify_AllocEmpty(op, **kwargs):
def allocempty(*shape):
......
......@@ -7,7 +7,6 @@ from pytensor.link.numba.dispatch.basic import create_tuple_string, numba_funcif
from pytensor.link.utils import compile_function_src, unique_name_generator
from pytensor.tensor.basic import (
Alloc,
AllocDiag,
AllocEmpty,
ARange,
ExtractDiag,
......@@ -93,17 +92,6 @@ def alloc(val, {", ".join(shape_var_names)}):
return numba_basic.numba_njit(alloc_fn)
@numba_funcify.register(AllocDiag)
def numba_funcify_AllocDiag(op, **kwargs):
    """Build the Numba implementation of an ``AllocDiag`` op.

    Captures the op's ``offset`` attribute in the closure so the jitted
    function needs no attribute lookups at run time.
    """
    diag_offset = op.offset

    @numba_basic.numba_njit(inline="always")
    def allocdiag(v):
        # Embed ``v`` on the ``diag_offset``-th diagonal of a zeroed matrix.
        return np.diag(v, k=diag_offset)

    return allocdiag
@numba_funcify.register(ARange)
def numba_funcify_ARange(op, **kwargs):
dtype = np.dtype(op.dtype)
......
......@@ -6,6 +6,7 @@ manipulation of tensors.
"""
import builtins
import warnings
from functools import partial
from numbers import Number
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union
......@@ -3450,7 +3451,7 @@ class ExtractDiag(Op):
x_grad = zeros_like(moveaxis(x, (axis1, axis2), (0, 1)))
# Fill zeros with output diagonal
xdiag = AllocDiag(offset=0, axis1=0, axis2=1)(gz)
xdiag = alloc_diag(gz, offset=0, axis1=0, axis2=1)
z_len = xdiag.shape[0]
if offset >= 0:
diag_slices = (slice(None, z_len), slice(offset, offset + z_len))
......@@ -3544,6 +3545,10 @@ class AllocDiag(Op):
Axis to be used as the second axis of the 2-D sub-arrays to which
the diagonals will be allocated. Defaults to second axis (i.e. 1).
"""
warnings.warn(
"AllocDiag is deprecated. Use `alloc_diag` instead",
FutureWarning,
)
self.offset = offset
if axis1 < 0 or axis2 < 0:
raise NotImplementedError("AllocDiag does not support negative axis")
......@@ -3625,6 +3630,43 @@ class AllocDiag(Op):
self.axis2 = 1
def alloc_diag(diag, offset=0, axis1=0, axis2=1):
    """Insert a vector on the diagonal of a zero-ed matrix.

    diagonal(alloc_diag(x)) == x

    Parameters
    ----------
    diag
        Values to place on the diagonal. The last axis of ``diag`` becomes
        the diagonal; any leading axes are treated as batch dimensions.
    offset
        Offset of the diagonal to fill, as in ``np.diag`` (``k``).
    axis1, axis2
        Axes of the output holding the 2-D sub-arrays whose diagonals are
        filled. Defaults to the last two axes of the result.

    Returns
    -------
    TensorVariable
        Tensor with one more dimension than ``diag``, zero everywhere except
        on the requested diagonal.
    """
    from pytensor.tensor import set_subtensor

    diag = as_tensor_variable(diag)
    # Normalize/validate axes against the OUTPUT rank (input rank + 1).
    axis1, axis2 = normalize_axis_tuple((axis1, axis2), ndim=diag.type.ndim + 1)
    if axis1 > axis2:
        axis1, axis2 = axis2, axis1

    # Create array with one extra dimension for resulting matrix; the two
    # trailing axes are square, padded by |offset| to fit the shifted diagonal.
    result_shape = tuple(diag.shape)[:-1] + (diag.shape[-1] + abs(offset),) * 2
    result = zeros(result_shape, dtype=diag.dtype)

    # Create slice for diagonal in final 2 axes: a positive offset shifts the
    # column index, a negative offset shifts the row index.
    idxs = arange(diag.shape[-1])
    diagonal_slice = (slice(None),) * (len(result_shape) - 2) + (
        idxs + np.maximum(0, -offset),
        idxs + np.maximum(0, offset),
    )

    # Fill in final 2 axes with diag
    result = set_subtensor(result[diagonal_slice], diag)

    if diag.type.ndim > 1:
        # Re-order axes so they correspond to diagonals at axis1, axis2
        axes = list(range(diag.type.ndim - 1))
        last_idx = axes[-1]
        axes = axes[:axis1] + [last_idx + 1] + axes[axis1:]
        axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
        result = result.transpose(axes)

    return result
def diag(v, k=0):
"""
A helper function for two ops: `ExtractDiag` and
......@@ -3650,7 +3692,7 @@ def diag(v, k=0):
_v = as_tensor_variable(v)
if _v.ndim == 1:
return AllocDiag(k)(_v)
return alloc_diag(_v, offset=k)
elif _v.ndim == 2:
return diagonal(_v, offset=k)
else:
......
......@@ -85,7 +85,7 @@ def test_jax_basic():
],
)
out = at.diag(b)
out = at.diag(at.specify_shape(b, shape=(10,)))
out_fg = FunctionGraph([b], [out])
compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)])
......
......@@ -57,28 +57,6 @@ def test_AllocEmpty():
compare_numba_and_py(x_fg, [], assert_fn=compare_shape_dtype)
@pytest.mark.parametrize(
    "v, offset",
    [
        (set_test_value(at.vector(), np.arange(10, dtype=config.floatX)), 0),
        (set_test_value(at.vector(), np.arange(10, dtype=config.floatX)), 1),
        (set_test_value(at.vector(), np.arange(10, dtype=config.floatX)), -1),
    ],
)
def test_AllocDiag(v, offset):
    """Exercise ``AllocDiag`` on a length-10 vector for offsets 0, 1, -1,
    comparing the Numba-compiled graph against the Python reference via
    ``compare_numba_and_py``."""
    g = atb.AllocDiag(offset=offset)(v)
    g_fg = FunctionGraph(outputs=[g])
    compare_numba_and_py(
        g_fg,
        [
            # Feed only non-shared, non-constant inputs their test values.
            i.tag.test_value
            for i in g_fg.inputs
            if not isinstance(i, (SharedVariable, Constant))
        ],
    )
@pytest.mark.parametrize(
"v", [set_test_value(aes.float64(), np.array(1.0, dtype="float64"))]
)
......
......@@ -23,7 +23,6 @@ from pytensor.scalar import autocast_float, autocast_float_as
from pytensor.tensor import NoneConst
from pytensor.tensor.basic import (
Alloc,
AllocDiag,
AllocEmpty,
ARange,
Choose,
......@@ -92,7 +91,7 @@ from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import dense_dot
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.shape import Reshape, Shape, Shape_i, shape_padright, specify_shape
from pytensor.tensor.shape import Reshape, Shape_i, shape_padright, specify_shape
from pytensor.tensor.type import (
TensorType,
bscalar,
......@@ -3571,7 +3570,6 @@ class TestDiag:
# test vector input
x = vector()
g = diag(x)
assert isinstance(g.owner.op, AllocDiag)
f = pytensor.function([x], g)
for shp in [5, 0, 1]:
m = rng.random(shp).astype(self.floatX)
......@@ -3654,10 +3652,6 @@ class TestExtractDiag:
class TestAllocDiag:
# TODO: Separate perform, grad and infer_shape tests
def setup_method(self):
self.alloc_diag = AllocDiag
self.mode = pytensor.compile.mode.get_default_mode()
def _generator(self):
dims = 4
shape = (5,) * dims
......@@ -3690,34 +3684,28 @@ class TestAllocDiag:
# Test perform
if np.maximum(axis1, axis2) > len(test_val.shape):
continue
adiag_op = self.alloc_diag(offset=offset, axis1=axis1, axis2=axis2)
f = pytensor.function([x], adiag_op(x))
# AllocDiag and extract the diagonal again
# to check
diag_x = at.alloc_diag(x, offset=offset, axis1=axis1, axis2=axis2)
f = pytensor.function([x], diag_x)
# alloc_diag and extract the diagonal again to check for correctness
diag_arr = f(test_val)
rediag = np.diagonal(diag_arr, offset=offset, axis1=axis1, axis2=axis2)
assert np.all(rediag == test_val)
# Test infer_shape
f_shape = pytensor.function([x], adiag_op(x).shape, mode="FAST_RUN")
f_shape = pytensor.function([x], diag_x.shape, mode="FAST_RUN")
output_shape = f_shape(test_val)
assert not any(
isinstance(node.op, self.alloc_diag)
for node in f_shape.maker.fgraph.toposort()
)
rediag_shape = np.diagonal(
np.ones(output_shape), offset=offset, axis1=axis1, axis2=axis2
).shape
assert np.all(rediag_shape == test_val.shape)
# Test grad
diag_x = adiag_op(x)
sum_diag_x = at_sum(diag_x)
grad_x = pytensor.grad(sum_diag_x, x)
grad_diag_x = pytensor.grad(sum_diag_x, diag_x)
f_grad_x = pytensor.function([x], grad_x, mode=self.mode)
f_grad_diag_x = pytensor.function([x], grad_diag_x, mode=self.mode)
f_grad_x = pytensor.function([x], grad_x)
f_grad_diag_x = pytensor.function([x], grad_diag_x)
grad_input = f_grad_x(test_val)
grad_diag_input = f_grad_diag_x(test_val)
true_grad_input = np.diagonal(
......@@ -3894,20 +3882,6 @@ class TestInferShape(utt.InferShapeTester):
atens3_diag = ExtractDiag(1, 2, 0)(atens3)
self._compile_and_check([atens3], [atens3_diag], [atens3_val], ExtractDiag)
def test_AllocDiag(self):
    """Shape-inference test: ``AllocDiag`` on a vector, and ``Shape`` on a
    3-D tensor (checked together with ``MakeVector``)."""
    advec = dvector()
    advec_val = random(4)
    self._compile_and_check([advec], [AllocDiag()(advec)], [advec_val], AllocDiag)

    # Shape
    # 'opt.Makevector' precludes optimizer from disentangling
    # elements of shape
    adtens = tensor3()
    adtens_val = random(4, 5, 3)
    self._compile_and_check(
        [adtens], [Shape()(adtens)], [adtens_val], (MakeVector, Shape)
    )
def test_Split(self):
aiscal = iscalar()
aivec = ivector()
......
Markdown formatting
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment