Commit 71c58f39 authored by Ricardo Vieira, committed by Ricardo Vieira

Deprecate AllocDiag Op in favor of equivalent PyTensor graph

Parent deea8dd3
...@@ -8,7 +8,6 @@ from pytensor.link.jax.dispatch.basic import jax_funcify ...@@ -8,7 +8,6 @@ from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
AllocDiag,
AllocEmpty, AllocEmpty,
ARange, ARange,
ExtractDiag, ExtractDiag,
...@@ -32,16 +31,6 @@ by JAX. An example of a graph that can be compiled to JAX: ...@@ -32,16 +31,6 @@ by JAX. An example of a graph that can be compiled to JAX:
""" """
@jax_funcify.register(AllocDiag)
def jax_funcify_AllocDiag(op, **kwargs):
    """Build the JAX implementation of the ``AllocDiag`` op.

    The returned callable places the vector ``v`` on the ``op.offset``-th
    diagonal of an otherwise-zero matrix via ``jnp.diag``.
    """
    diag_offset = op.offset

    def allocdiag(v, offset=diag_offset):
        # k selects the diagonal: positive is above the main one,
        # negative is below it.
        return jnp.diag(v, k=offset)

    return allocdiag
@jax_funcify.register(AllocEmpty) @jax_funcify.register(AllocEmpty)
def jax_funcify_AllocEmpty(op, **kwargs): def jax_funcify_AllocEmpty(op, **kwargs):
def allocempty(*shape): def allocempty(*shape):
......
...@@ -7,7 +7,6 @@ from pytensor.link.numba.dispatch.basic import create_tuple_string, numba_funcif ...@@ -7,7 +7,6 @@ from pytensor.link.numba.dispatch.basic import create_tuple_string, numba_funcif
from pytensor.link.utils import compile_function_src, unique_name_generator from pytensor.link.utils import compile_function_src, unique_name_generator
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
AllocDiag,
AllocEmpty, AllocEmpty,
ARange, ARange,
ExtractDiag, ExtractDiag,
...@@ -93,17 +92,6 @@ def alloc(val, {", ".join(shape_var_names)}): ...@@ -93,17 +92,6 @@ def alloc(val, {", ".join(shape_var_names)}):
return numba_basic.numba_njit(alloc_fn) return numba_basic.numba_njit(alloc_fn)
@numba_funcify.register(AllocDiag)
def numba_funcify_AllocDiag(op, **kwargs):
    """Build the Numba implementation of the ``AllocDiag`` op."""
    diag_offset = op.offset

    @numba_basic.numba_njit(inline="always")
    def allocdiag(v):
        # np.diag allocates the zero matrix and fills its
        # diag_offset-th diagonal with v.
        return np.diag(v, k=diag_offset)

    return allocdiag
@numba_funcify.register(ARange) @numba_funcify.register(ARange)
def numba_funcify_ARange(op, **kwargs): def numba_funcify_ARange(op, **kwargs):
dtype = np.dtype(op.dtype) dtype = np.dtype(op.dtype)
......
...@@ -6,6 +6,7 @@ manipulation of tensors. ...@@ -6,6 +6,7 @@ manipulation of tensors.
""" """
import builtins import builtins
import warnings
from functools import partial from functools import partial
from numbers import Number from numbers import Number
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union
...@@ -3450,7 +3451,7 @@ class ExtractDiag(Op): ...@@ -3450,7 +3451,7 @@ class ExtractDiag(Op):
x_grad = zeros_like(moveaxis(x, (axis1, axis2), (0, 1))) x_grad = zeros_like(moveaxis(x, (axis1, axis2), (0, 1)))
# Fill zeros with output diagonal # Fill zeros with output diagonal
xdiag = AllocDiag(offset=0, axis1=0, axis2=1)(gz) xdiag = alloc_diag(gz, offset=0, axis1=0, axis2=1)
z_len = xdiag.shape[0] z_len = xdiag.shape[0]
if offset >= 0: if offset >= 0:
diag_slices = (slice(None, z_len), slice(offset, offset + z_len)) diag_slices = (slice(None, z_len), slice(offset, offset + z_len))
...@@ -3544,6 +3545,10 @@ class AllocDiag(Op): ...@@ -3544,6 +3545,10 @@ class AllocDiag(Op):
Axis to be used as the second axis of the 2-D sub-arrays to which Axis to be used as the second axis of the 2-D sub-arrays to which
the diagonals will be allocated. Defaults to second axis (i.e. 1). the diagonals will be allocated. Defaults to second axis (i.e. 1).
""" """
warnings.warn(
"AllocDiag is deprecated. Use `alloc_diag` instead",
FutureWarning,
)
self.offset = offset self.offset = offset
if axis1 < 0 or axis2 < 0: if axis1 < 0 or axis2 < 0:
raise NotImplementedError("AllocDiag does not support negative axis") raise NotImplementedError("AllocDiag does not support negative axis")
...@@ -3625,6 +3630,43 @@ class AllocDiag(Op): ...@@ -3625,6 +3630,43 @@ class AllocDiag(Op):
self.axis2 = 1 self.axis2 = 1
def alloc_diag(diag, offset=0, axis1=0, axis2=1):
    """Insert a vector on the diagonal of a zero-ed matrix.

    Graph-based replacement for the deprecated ``AllocDiag`` Op,
    generalized to batched inputs: the last axis of `diag` is placed on
    the (`axis1`, `axis2`) diagonal plane of the result, so that
    ``diagonal(alloc_diag(x)) == x``.

    Parameters
    ----------
    diag
        Tensor whose last axis holds the diagonal entries; any leading
        axes are treated as batch dimensions.
    offset : int
        Diagonal offset: positive places the entries above the main
        diagonal, negative below it.
    axis1, axis2 : int
        Axes of the output that form the 2-D planes into which the
        diagonals are written. Negative values are normalized against
        the output rank (``diag.ndim + 1``).

    Returns
    -------
    TensorVariable
        Tensor with one more dimension than `diag`, zero everywhere
        except on the requested diagonal.
    """
    from pytensor.tensor import set_subtensor

    diag = as_tensor_variable(diag)

    # Normalize/validate axes against the *output* rank (input ndim + 1),
    # and order them so axis1 < axis2 for the transpose bookkeeping below.
    axis1, axis2 = normalize_axis_tuple((axis1, axis2), ndim=diag.type.ndim + 1)
    if axis1 > axis2:
        axis1, axis2 = axis2, axis1

    # Create array with one extra dimension for resulting matrix; the two
    # trailing axes are square with side len(diag) + |offset|.
    result_shape = tuple(diag.shape)[:-1] + (diag.shape[-1] + abs(offset),) * 2
    result = zeros(result_shape, dtype=diag.dtype)

    # Create slice for diagonal in final 2 axes: shifting the row indices
    # for a negative offset and the column indices for a positive one.
    idxs = arange(diag.shape[-1])
    diagonal_slice = (slice(None),) * (len(result_shape) - 2) + (
        idxs + np.maximum(0, -offset),
        idxs + np.maximum(0, offset),
    )

    # Fill in final 2 axes with diag
    result = set_subtensor(result[diagonal_slice], diag)

    if diag.type.ndim > 1:
        # Re-order axes so they correspond to diagonals at axis1, axis2:
        # the two freshly written trailing axes are moved into the axis1
        # and axis2 positions, batch axes keep their relative order.
        axes = list(range(diag.type.ndim - 1))
        last_idx = axes[-1]
        axes = axes[:axis1] + [last_idx + 1] + axes[axis1:]
        axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
        result = result.transpose(axes)

    return result
def diag(v, k=0): def diag(v, k=0):
""" """
A helper function for two ops: `ExtractDiag` and A helper function for two ops: `ExtractDiag` and
...@@ -3650,7 +3692,7 @@ def diag(v, k=0): ...@@ -3650,7 +3692,7 @@ def diag(v, k=0):
_v = as_tensor_variable(v) _v = as_tensor_variable(v)
if _v.ndim == 1: if _v.ndim == 1:
return AllocDiag(k)(_v) return alloc_diag(_v, offset=k)
elif _v.ndim == 2: elif _v.ndim == 2:
return diagonal(_v, offset=k) return diagonal(_v, offset=k)
else: else:
......
...@@ -85,7 +85,7 @@ def test_jax_basic(): ...@@ -85,7 +85,7 @@ def test_jax_basic():
], ],
) )
out = at.diag(b) out = at.diag(at.specify_shape(b, shape=(10,)))
out_fg = FunctionGraph([b], [out]) out_fg = FunctionGraph([b], [out])
compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)]) compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)])
......
...@@ -57,28 +57,6 @@ def test_AllocEmpty(): ...@@ -57,28 +57,6 @@ def test_AllocEmpty():
compare_numba_and_py(x_fg, [], assert_fn=compare_shape_dtype) compare_numba_and_py(x_fg, [], assert_fn=compare_shape_dtype)
@pytest.mark.parametrize(
    "v, offset",
    [
        (set_test_value(at.vector(), np.arange(10, dtype=config.floatX)), k)
        for k in (0, 1, -1)
    ],
)
def test_AllocDiag(v, offset):
    """Check the Numba backend of ``AllocDiag`` against the Python one."""
    g = atb.AllocDiag(offset=offset)(v)
    g_fg = FunctionGraph(outputs=[g])

    # Feed the test values of every non-constant, non-shared input.
    test_inputs = [
        inp.tag.test_value
        for inp in g_fg.inputs
        if not isinstance(inp, (SharedVariable, Constant))
    ]
    compare_numba_and_py(g_fg, test_inputs)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"v", [set_test_value(aes.float64(), np.array(1.0, dtype="float64"))] "v", [set_test_value(aes.float64(), np.array(1.0, dtype="float64"))]
) )
......
...@@ -23,7 +23,6 @@ from pytensor.scalar import autocast_float, autocast_float_as ...@@ -23,7 +23,6 @@ from pytensor.scalar import autocast_float, autocast_float_as
from pytensor.tensor import NoneConst from pytensor.tensor import NoneConst
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
AllocDiag,
AllocEmpty, AllocEmpty,
ARange, ARange,
Choose, Choose,
...@@ -92,7 +91,7 @@ from pytensor.tensor.elemwise import DimShuffle ...@@ -92,7 +91,7 @@ from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import dense_dot from pytensor.tensor.math import dense_dot
from pytensor.tensor.math import sum as at_sum from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.shape import Reshape, Shape, Shape_i, shape_padright, specify_shape from pytensor.tensor.shape import Reshape, Shape_i, shape_padright, specify_shape
from pytensor.tensor.type import ( from pytensor.tensor.type import (
TensorType, TensorType,
bscalar, bscalar,
...@@ -3571,7 +3570,6 @@ class TestDiag: ...@@ -3571,7 +3570,6 @@ class TestDiag:
# test vector input # test vector input
x = vector() x = vector()
g = diag(x) g = diag(x)
assert isinstance(g.owner.op, AllocDiag)
f = pytensor.function([x], g) f = pytensor.function([x], g)
for shp in [5, 0, 1]: for shp in [5, 0, 1]:
m = rng.random(shp).astype(self.floatX) m = rng.random(shp).astype(self.floatX)
...@@ -3654,10 +3652,6 @@ class TestExtractDiag: ...@@ -3654,10 +3652,6 @@ class TestExtractDiag:
class TestAllocDiag: class TestAllocDiag:
# TODO: Separate perform, grad and infer_shape tests # TODO: Separate perform, grad and infer_shape tests
def setup_method(self):
self.alloc_diag = AllocDiag
self.mode = pytensor.compile.mode.get_default_mode()
def _generator(self): def _generator(self):
dims = 4 dims = 4
shape = (5,) * dims shape = (5,) * dims
...@@ -3690,34 +3684,28 @@ class TestAllocDiag: ...@@ -3690,34 +3684,28 @@ class TestAllocDiag:
# Test perform # Test perform
if np.maximum(axis1, axis2) > len(test_val.shape): if np.maximum(axis1, axis2) > len(test_val.shape):
continue continue
adiag_op = self.alloc_diag(offset=offset, axis1=axis1, axis2=axis2) diag_x = at.alloc_diag(x, offset=offset, axis1=axis1, axis2=axis2)
f = pytensor.function([x], adiag_op(x)) f = pytensor.function([x], diag_x)
# AllocDiag and extract the diagonal again # alloc_diag and extract the diagonal again to check for correctness
# to check
diag_arr = f(test_val) diag_arr = f(test_val)
rediag = np.diagonal(diag_arr, offset=offset, axis1=axis1, axis2=axis2) rediag = np.diagonal(diag_arr, offset=offset, axis1=axis1, axis2=axis2)
assert np.all(rediag == test_val) assert np.all(rediag == test_val)
# Test infer_shape # Test infer_shape
f_shape = pytensor.function([x], adiag_op(x).shape, mode="FAST_RUN") f_shape = pytensor.function([x], diag_x.shape, mode="FAST_RUN")
output_shape = f_shape(test_val) output_shape = f_shape(test_val)
assert not any(
isinstance(node.op, self.alloc_diag)
for node in f_shape.maker.fgraph.toposort()
)
rediag_shape = np.diagonal( rediag_shape = np.diagonal(
np.ones(output_shape), offset=offset, axis1=axis1, axis2=axis2 np.ones(output_shape), offset=offset, axis1=axis1, axis2=axis2
).shape ).shape
assert np.all(rediag_shape == test_val.shape) assert np.all(rediag_shape == test_val.shape)
# Test grad # Test grad
diag_x = adiag_op(x)
sum_diag_x = at_sum(diag_x) sum_diag_x = at_sum(diag_x)
grad_x = pytensor.grad(sum_diag_x, x) grad_x = pytensor.grad(sum_diag_x, x)
grad_diag_x = pytensor.grad(sum_diag_x, diag_x) grad_diag_x = pytensor.grad(sum_diag_x, diag_x)
f_grad_x = pytensor.function([x], grad_x, mode=self.mode) f_grad_x = pytensor.function([x], grad_x)
f_grad_diag_x = pytensor.function([x], grad_diag_x, mode=self.mode) f_grad_diag_x = pytensor.function([x], grad_diag_x)
grad_input = f_grad_x(test_val) grad_input = f_grad_x(test_val)
grad_diag_input = f_grad_diag_x(test_val) grad_diag_input = f_grad_diag_x(test_val)
true_grad_input = np.diagonal( true_grad_input = np.diagonal(
...@@ -3894,20 +3882,6 @@ class TestInferShape(utt.InferShapeTester): ...@@ -3894,20 +3882,6 @@ class TestInferShape(utt.InferShapeTester):
atens3_diag = ExtractDiag(1, 2, 0)(atens3) atens3_diag = ExtractDiag(1, 2, 0)(atens3)
self._compile_and_check([atens3], [atens3_diag], [atens3_val], ExtractDiag) self._compile_and_check([atens3], [atens3_diag], [atens3_val], ExtractDiag)
def test_AllocDiag(self):
    """Check infer_shape for ``AllocDiag`` (and ``Shape`` on a tensor3)."""
    vec = dvector()
    vec_val = random(4)
    self._compile_and_check([vec], [AllocDiag()(vec)], [vec_val], AllocDiag)

    # Shape: 'opt.MakeVector' precludes the optimizer from disentangling
    # the individual elements of the shape, so check against both ops.
    tens = tensor3()
    tens_val = random(4, 5, 3)
    self._compile_and_check(
        [tens], [Shape()(tens)], [tens_val], (MakeVector, Shape)
    )
def test_Split(self): def test_Split(self):
aiscal = iscalar() aiscal = iscalar()
aivec = ivector() aivec = ivector()
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment