Unverified commit 981688c3 authored by Jesse Grabowski, committed by GitHub

Implement `pad` (#748)

* Add `pt.pad`
* Refactor linspace, logspace, and geomspace to match numpy implementation
* Add `pt.flip`
* Move `flip` to `tensor/subtensor.py`, add docstring
* Move `slice_at_axis` to `tensor/subtensor` and expose it in `pytensor.tensor`

Parent f489cf4b
......@@ -6,6 +6,7 @@ import pytensor.link.jax.dispatch.blas
import pytensor.link.jax.dispatch.blockwise
import pytensor.link.jax.dispatch.elemwise
import pytensor.link.jax.dispatch.extra_ops
import pytensor.link.jax.dispatch.pad
import pytensor.link.jax.dispatch.math
import pytensor.link.jax.dispatch.nlinalg
import pytensor.link.jax.dispatch.random
......
import jax.numpy as jnp
import numpy as np
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.pad import Pad
@jax_funcify.register(Pad)
def jax_funcify_pad(op, **kwargs):
    """Build the JAX implementation of a ``Pad`` Op.

    The signature of the returned callable depends on ``op.pad_mode``,
    mirroring the extra symbolic inputs that each padding mode receives.
    """
    mode = op.pad_mode
    reflect_type = op.reflect_type
    has_stat_length = op.has_stat_length

    if mode == "constant":

        def constant_pad(x, pad_width, constant_values):
            return jnp.pad(x, pad_width, mode=mode, constant_values=constant_values)

        return constant_pad

    if mode == "linear_ramp":

        def lr_pad(x, pad_width, end_values):
            # JAX does not allow a dynamic input if end_values is non-scalar
            if not isinstance(end_values, int | float):
                end_values = tuple(np.array(end_values))
            return jnp.pad(x, pad_width, mode=mode, end_values=end_values)

        return lr_pad

    if mode in ("maximum", "minimum", "mean") and has_stat_length:

        def stat_pad(x, pad_width, stat_length):
            # JAX does not allow a dynamic input here, need to cast to tuple
            return jnp.pad(
                x, pad_width, mode=mode, stat_length=tuple(np.array(stat_length))
            )

        return stat_pad

    if mode in ("reflect", "symmetric"):

        def loop_pad(x, pad_width):
            return jnp.pad(x, pad_width, mode=mode, reflect_type=reflect_type)

        return loop_pad

    # Remaining modes ("edge", "wrap", stat modes without stat_length, ...)
    # take no extra inputs.
    def pad(x, pad_width):
        return jnp.pad(x, pad_width, mode=mode)

    return pad
......@@ -130,6 +130,7 @@ from pytensor.tensor.blas import batched_dot, batched_tensordot
from pytensor.tensor.extra_ops import *
from pytensor.tensor.io import *
from pytensor.tensor.math import *
from pytensor.tensor.pad import pad
from pytensor.tensor.shape import (
reshape,
shape,
......
Diff collapsed.
......@@ -3013,8 +3013,123 @@ def _get_vector_length_Subtensor(op, var):
raise ValueError(f"Length of {var} cannot be determined")
def slice_at_axis(sl: slice, axis: int) -> tuple[slice, ...]:
    """
    Build an index tuple that applies ``sl`` along a single dimension.

    All other dimensions are left "unsliced". Mirrors
    ``numpy.lib.arraypad._slice_at_axis``:
    https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L33

    Parameters
    ----------
    sl : slice
        The slice for the given dimension.
    axis : int
        The axis to which `sl` is applied. May be negative, in which case the
        slice is anchored from the last dimension.

    Returns
    -------
    sl : tuple of slices
        An index tuple suitable for ``x[...]`` that slices only ``axis``.

    Examples
    --------
    .. testcode::

        import pytensor.tensor as pt

        s = pt.slice_at_axis(slice(None, 1), 1)
        print(s)

    .. testoutput::

        (slice(None, None, None), slice(None, 1, None), Ellipsis)

    .. testcode::

        import numpy as np
        import pytensor

        x = pt.tensor('x', shape=(None, None, None))
        x_sliced = x[s]

        f = pytensor.function([x], x_sliced)
        x = np.arange(27).reshape(3, 3, 3)
        print(f(x))

    .. testoutput::

        [[[ 0.  1.  2.]]

         [[ 9. 10. 11.]]

         [[18. 19. 20.]]]
    """
    if axis < 0:
        # For a negative axis the slice is anchored from the right: a leading
        # Ellipsis absorbs all preceding dimensions, and ``-axis - 1`` full
        # slices pad out the trailing dimensions (axis=-1 -> none).
        n_trailing = -axis - 1
        return (...,) + (sl,) + (slice(None),) * n_trailing
    return (slice(None),) * axis + (sl,) + (...,)
def flip(
    arr: TensorVariable, axis: int | tuple[int] | TensorVariable | None = None
) -> TensorVariable:
    """
    Reverse the order of elements in a tensor along the given axis.

    Parameters
    ----------
    arr: TensorVariable
        Input tensor.
    axis: int | tuple[int] | TensorVariable, optional
        Axis or axes along which to flip over. By default, all axes of the
        input tensor are flipped.

    Returns
    -------
    arr: TensorVariable
        A view of `arr` with the entries of axis reversed.

    Examples
    --------
    .. testcode::

        import pytensor
        import pytensor.tensor as pt

        x = pt.tensor('x', shape=(None, None))
        x_flipped = pt.flip(x, axis=0)

        f = pytensor.function([x], x_flipped)
        x = [[1, 2], [3, 4]]
        print(f(x))

    .. testoutput::

        [[3. 4.]
         [1. 2.]]
    """
    reversed_slice = slice(None, None, -1)
    if axis is None:
        # Flip every dimension.
        index = (reversed_slice,) * arr.ndim
    else:
        axes = (axis,) if isinstance(axis, int) else axis
        index = tuple(
            reversed_slice if dim in axes else slice(None) for dim in range(arr.ndim)
        )
    return cast(TensorVariable, arr[index])
# Names exported by a star-import of this module.
__all__ = [
    "take",
    "flip",
    "slice_at_axis",
    "inc_subtensor",
    "set_subtensor",
]
import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor.pad import PadMode
from tests.link.jax.test_basic import compare_jax_and_py
jax = pytest.importorskip("jax")
floatX = config.floatX
RTOL = ATOL = 1e-6 if floatX.endswith("64") else 1e-3
@pytest.mark.parametrize(
    "mode, kwargs",
    [
        ("constant", {"constant_values": 0}),
        ("constant", {"constant_values": (1, 2)}),
        ("edge", {}),
        ("linear_ramp", {"end_values": 0}),
        ("linear_ramp", {"end_values": (1, 2)}),
        ("reflect", {"reflect_type": "even"}),
        ("wrap", {}),
        ("symmetric", {"reflect_type": "even"}),
        ("mean", {"stat_length": None}),
        ("mean", {"stat_length": (10, 2)}),
        ("maximum", {"stat_length": None}),
        ("maximum", {"stat_length": (10, 2)}),
        ("minimum", {"stat_length": None}),
        ("minimum", {"stat_length": (10, 2)}),
    ],
    ids=[
        "constant_default",
        "constant_tuple",
        "edge",
        "linear_ramp_default",
        "linear_ramp_tuple",
        "reflect",
        "wrap",
        "symmetric",
        "mean_default",
        "mean_tuple",
        "maximum_default",
        "maximum_tuple",
        "minimum_default",
        "minimum_tuple",
    ],
)
def test_jax_pad(mode: PadMode, kwargs):
    """Compare the JAX and Python implementations of ``pt.pad`` for each mode."""
    x_pt = pt.tensor("x", shape=(3, 3))
    # Cast the test value to floatX so it matches x_pt's dtype under float32
    # configs (consistent with the other pad tests).
    x = np.random.normal(size=(3, 3)).astype(floatX)

    res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs)
    res_fg = FunctionGraph([x_pt], [res])

    compare_jax_and_py(
        res_fg,
        [x],
        assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL),
        py_mode="FAST_RUN",
    )
import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor.pad import PadMode
from tests.link.numba.test_basic import compare_numba_and_py
floatX = config.floatX
RTOL = ATOL = 1e-6 if floatX.endswith("64") else 1e-3
@pytest.mark.parametrize(
    "mode, kwargs",
    [
        ("constant", {"constant_values": 0}),
        ("constant", {"constant_values": (1, 2)}),
        pytest.param(
            "edge",
            {},
            marks=pytest.mark.skip(
                "This is causing a segfault in NUMBA mode, but I have no idea why"
            ),
        ),
        ("linear_ramp", {"end_values": 0}),
        ("linear_ramp", {"end_values": (1, 2)}),
        ("reflect", {"reflect_type": "even"}),
        ("wrap", {}),
        ("symmetric", {"reflect_type": "even"}),
        ("mean", {"stat_length": None}),
        ("mean", {"stat_length": (10, 2)}),
        ("maximum", {"stat_length": None}),
        ("maximum", {"stat_length": (10, 2)}),
        ("minimum", {"stat_length": None}),
        ("minimum", {"stat_length": (10, 2)}),
    ],
    ids=[
        "constant_default",
        "constant_tuple",
        "edge",
        "linear_ramp_default",
        "linear_ramp_tuple",
        "reflect",
        "wrap",
        "symmetric",
        "mean_default",
        "mean_tuple",
        "maximum_default",
        "maximum_tuple",
        "minimum_default",
        "minimum_tuple",
    ],
)
def test_numba_pad(mode: PadMode, kwargs):
    """Compare the Numba and Python implementations of ``pt.pad`` for each mode."""
    x_pt = pt.tensor("x", shape=(3, 3))
    # Cast the test value to floatX so it matches x_pt's dtype under float32
    # configs (consistent with the other pad tests).
    x = np.random.normal(size=(3, 3)).astype(floatX)

    res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs)
    res_fg = FunctionGraph([x_pt], [res])

    compare_numba_and_py(
        res_fg,
        [x],
        assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL),
        py_mode="FAST_RUN",
    )
......@@ -35,9 +35,6 @@ from pytensor.tensor.extra_ops import (
diff,
fill_diagonal,
fill_diagonal_offset,
geomspace,
linspace,
logspace,
ravel_multi_index,
repeat,
searchsorted,
......@@ -1281,25 +1278,37 @@ def test_broadcast_arrays():
@pytest.mark.parametrize(
"start, stop, num_samples",
"op",
["linspace", "logspace", "geomspace"],
ids=["linspace", "logspace", "geomspace"],
)
@pytest.mark.parametrize("dtype", [None, "int", "float"], ids=[None, "int", "float"])
@pytest.mark.parametrize(
"start, stop, num_samples, endpoint, axis",
[
(1, 10, 50),
(np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25),
(1, np.array([5, 6]), 30),
(1, 10, 50, True, 0),
(1, 10, 1, True, 0),
(np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25, True, 0),
(np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25, True, 1),
(np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25, False, -1),
(1, np.array([5, 6]), 30, True, 0),
(1, np.array([5, 6]), 30, False, -1),
],
)
def test_space_ops(start, stop, num_samples):
z = linspace(start, stop, num_samples)
pytensor_res = function(inputs=[], outputs=z)()
numpy_res = np.linspace(start, stop, num=num_samples)
assert np.allclose(pytensor_res, numpy_res)
z = logspace(start, stop, num_samples)
pytensor_res = function(inputs=[], outputs=z)()
numpy_res = np.logspace(start, stop, num=num_samples)
assert np.allclose(pytensor_res, numpy_res)
z = geomspace(start, stop, num_samples)
pytensor_res = function(inputs=[], outputs=z)()
numpy_res = np.geomspace(start, stop, num=num_samples)
assert np.allclose(pytensor_res, numpy_res)
def test_space_ops(op, dtype, start, stop, num_samples, endpoint, axis):
pt_func = getattr(pt, op)
np_func = getattr(np, op)
dtype = dtype + config.floatX[-2:] if dtype is not None else dtype
z = pt_func(start, stop, num_samples, endpoint=endpoint, axis=axis, dtype=dtype)
numpy_res = np_func(
start, stop, num=num_samples, endpoint=endpoint, dtype=dtype, axis=axis
)
pytensor_res = function(inputs=[], outputs=z, mode="FAST_COMPILE")()
np.testing.assert_allclose(
pytensor_res,
numpy_res,
atol=1e-6 if config.floatX.endswith("64") else 1e-4,
rtol=1e-6 if config.floatX.endswith("64") else 1e-4,
)
from typing import Literal
import numpy as np
import pytest
import pytensor
from pytensor.tensor.pad import PadMode, pad
floatX = pytensor.config.floatX
RTOL = ATOL = 1e-8 if floatX.endswith("64") else 1e-4
def test_unknown_mode_raises():
    """``pad`` must reject modes it does not implement."""
    x = np.random.normal(size=(3, 3)).astype(floatX)
    expected_message = "Invalid mode: unknown"
    with pytest.raises(ValueError, match=expected_message):
        pad(x, 1, mode="unknown")
@pytest.mark.parametrize(
    "size", [(3,), (3, 3), (3, 3, 3)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize("constant", [0, 0.0], ids=["int", "float"])
@pytest.mark.parametrize(
    "pad_width",
    [10, (10, 0), (0, 10)],
    ids=["symmetrical", "asymmetrical_left", "asymmetric_right"],
)
def test_constant_pad(
    size: tuple, constant: int | float, pad_width: int | tuple[int, ...]
):
    """``pad`` in ``constant`` mode matches ``np.pad``."""
    x = np.random.normal(size=size).astype(floatX)

    z = pad(x, pad_width, mode="constant", constant_values=constant)
    assert z.owner.op.pad_mode == "constant"

    expected = np.pad(x, pad_width, mode="constant", constant_values=constant)
    fn = pytensor.function([], z, mode="FAST_COMPILE")
    np.testing.assert_allclose(expected, fn(), atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize(
    "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize(
    "pad_width",
    [10, (10, 0), (0, 10)],
    ids=["symmetrical", "asymmetrical_left", "asymmetric_right"],
)
def test_edge_pad(size: tuple, pad_width: int | tuple[int, ...]):
    """``pad`` in ``edge`` mode matches ``np.pad``."""
    x = np.random.normal(size=size).astype(floatX)

    z = pad(x, pad_width, mode="edge")
    assert z.owner.op.pad_mode == "edge"

    expected = np.pad(x, pad_width, mode="edge")
    fn = pytensor.function([], z, mode="FAST_COMPILE")
    np.testing.assert_allclose(expected, fn(), atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize(
    "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize(
    "pad_width",
    [10, (10, 0), (0, 10)],
    ids=["symmetrical", "asymmetrical_left", "asymmetric_right"],
)
@pytest.mark.parametrize("end_values", [0, -1], ids=["0", "-1"])
def test_linear_ramp_pad(
    size: tuple,
    pad_width: int | tuple[int, ...],
    end_values: int | float | tuple[int | float, ...],
):
    """``pad`` in ``linear_ramp`` mode matches ``np.pad``."""
    x = np.random.normal(size=size).astype(floatX)

    z = pad(x, pad_width, mode="linear_ramp", end_values=end_values)
    assert z.owner.op.pad_mode == "linear_ramp"

    expected = np.pad(x, pad_width, mode="linear_ramp", end_values=end_values)
    fn = pytensor.function([], z, mode="FAST_COMPILE")
    np.testing.assert_allclose(expected, fn(), atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize(
    "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize(
    "pad_width",
    [10, (10, 0), (0, 10)],
    ids=["symmetrical", "asymmetrical_left", "asymmetric_right"],
)
@pytest.mark.parametrize("stat", ["mean", "minimum", "maximum"])
@pytest.mark.parametrize("stat_length", [None, 2])
def test_stat_pad(
    size: tuple,
    pad_width: int | tuple[int, ...],
    stat: PadMode,
    stat_length: int | None,
):
    """``pad`` in the statistic modes (mean/minimum/maximum) matches ``np.pad``."""
    x = np.random.normal(size=size).astype(floatX)

    z = pad(x, pad_width, mode=stat, stat_length=stat_length)
    assert z.owner.op.pad_mode == stat

    expected = np.pad(x, pad_width, mode=stat, stat_length=stat_length)
    fn = pytensor.function([], z, mode="FAST_COMPILE")
    np.testing.assert_allclose(expected, fn(), atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize(
    "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize(
    "pad_width",
    [10, (10, 0), (0, 10)],
    ids=["symmetrical", "asymmetrical_left", "asymmetric_right"],
)
def test_wrap_pad(size: tuple, pad_width: int | tuple[int, ...]):
    """``pad`` in ``wrap`` mode matches ``np.pad``."""
    x = np.random.normal(size=size).astype(floatX)

    z = pad(x, pad_width, mode="wrap")
    assert z.owner.op.pad_mode == "wrap"

    expected = np.pad(x, pad_width, mode="wrap")
    fn = pytensor.function([], z, mode="FAST_COMPILE")
    np.testing.assert_allclose(expected, fn(), atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize(
    "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize(
    "pad_width",
    [10, (10, 0), (0, 10)],
    ids=["symmetrical", "asymmetrical_left", "asymmetric_right"],
)
@pytest.mark.parametrize(
    "reflect_type",
    ["even", pytest.param("odd", marks=pytest.mark.xfail(raises=NotImplementedError))],
    ids=["even", "odd"],
)
def test_symmetric_pad(
    size,
    pad_width,
    reflect_type: Literal["even", "odd"],
):
    """``pad`` in ``symmetric`` mode matches ``np.pad`` (odd reflection xfails)."""
    x = np.random.normal(size=size).astype(floatX)
    expected = np.pad(x, pad_width, mode="symmetric", reflect_type=reflect_type)

    z = pad(x, pad_width, mode="symmetric", reflect_type=reflect_type)
    assert z.owner.op.pad_mode == "symmetric"

    fn = pytensor.function([], z, mode="FAST_COMPILE")
    np.testing.assert_allclose(expected, fn(), atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize(
    "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize(
    "pad_width",
    [10, (10, 0), (0, 10)],
    ids=["symmetrical", "asymmetrical_left", "asymmetric_right"],
)
@pytest.mark.parametrize(
    "reflect_type",
    ["even", pytest.param("odd", marks=pytest.mark.xfail(raises=NotImplementedError))],
    ids=["even", "odd"],
)
def test_reflect_pad(
    size,
    pad_width,
    reflect_type: Literal["even", "odd"],
):
    """``pad`` in ``reflect`` mode matches ``np.pad`` (odd reflection xfails)."""
    x = np.random.normal(size=size).astype(floatX)
    expected = np.pad(x, pad_width, mode="reflect", reflect_type=reflect_type)

    z = pad(x, pad_width, mode="reflect", reflect_type=reflect_type)
    assert z.owner.op.pad_mode == "reflect"

    fn = pytensor.function([], z, mode="FAST_COMPILE")
    np.testing.assert_allclose(expected, fn(), atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize(
    "mode",
    [
        "constant",
        "edge",
        "linear_ramp",
        "wrap",
        "symmetric",
        "reflect",
        "mean",
        "maximum",
        "minimum",
    ],
)
@pytest.mark.parametrize("padding", ["symmetric", "asymmetric"])
def test_nd_padding(mode, padding):
    """``pad`` matches ``np.pad`` on random-rank inputs with per-axis widths."""
    rng = np.random.default_rng()
    ndim = rng.integers(3, 5)

    if padding == "symmetric":
        pad_width = [(w, w) for w in rng.integers(1, 5, size=ndim)]
        stat_length = [(s, s) for s in rng.integers(1, 5, size=ndim)]
    else:
        pad_width = rng.integers(1, 5, size=(ndim, 2)).tolist()
        stat_length = rng.integers(1, 5, size=(ndim, 2)).tolist()

    # Extra keyword arguments required by each mode; modes absent from the
    # mapping take none.
    mode_kwargs = {
        "constant": {"constant_values": 0},
        "linear_ramp": {"end_values": 0},
        "maximum": {"stat_length": stat_length},
        "mean": {"stat_length": stat_length},
        "minimum": {"stat_length": stat_length},
        "reflect": {"reflect_type": "even"},
        "symmetric": {"reflect_type": "even"},
    }
    kwargs = mode_kwargs.get(mode, {})

    x = np.random.normal(size=(2,) * ndim).astype(floatX)
    expected = np.pad(x, pad_width, mode=mode, **kwargs)

    z = pad(x, pad_width, mode=mode, **kwargs)
    fn = pytensor.function([], z, mode="FAST_COMPILE")
    np.testing.assert_allclose(expected, fn(), atol=ATOL, rtol=RTOL)
......@@ -37,11 +37,13 @@ from pytensor.tensor.subtensor import (
advanced_subtensor1,
as_index_literal,
basic_shape,
flip,
get_canonical_form_slice,
inc_subtensor,
index_vars_to_types,
indexed_result_shape,
set_subtensor,
slice_at_axis,
take,
)
from pytensor.tensor.type import (
......@@ -2902,3 +2904,39 @@ def test_vectorize_adv_subtensor(
vectorize_pt(x_test, idx_test),
vectorize_np(x_test, idx_test),
)
def test_slice_at_axis():
    """``slice_at_axis`` produces index tuples that slice exactly one dimension."""
    x = ptb.tensor("x", shape=(3, 4, 5))

    # Positive axis
    sliced = x[slice_at_axis(slice(None, 1), axis=0)]
    assert sliced.type.shape == (1, 4, 5)

    # Negative axis
    sliced = x[slice_at_axis(slice(None, 1), axis=-2)]
    assert sliced.type.shape == (3, 1, 5)
@pytest.mark.parametrize(
    "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
def test_flip(size: tuple[int]):
    """``flip`` matches ``np.flip`` for all axes and every axis combination."""
    from itertools import combinations

    ATOL = RTOL = 1e-8 if config.floatX == "float64" else 1e-4
    x = np.random.normal(size=size).astype(config.floatX)
    x_pt = pytensor.tensor.tensor(shape=size, name="x")

    def check(axis):
        expected = np.flip(x, axis=axis)
        fn = pytensor.function([x_pt], flip(x_pt, axis=axis), mode="FAST_COMPILE")
        np.testing.assert_allclose(expected, fn(x), atol=ATOL, rtol=RTOL)

    # Flip over all axes at once, then every non-empty combination of axes.
    check(None)
    for r in range(1, x.ndim + 1):
        for axes in combinations(range(x.ndim), r):
            check(list(axes))
Markdown formatting
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to comment