Unverified commit 981688c3 authored by Jesse Grabowski, committed by GitHub

Implement `pad` (#748)

* Add `pt.pad`
* Refactor linspace, logspace, and geomspace to match numpy implementation
* Add `pt.flip`
* Move `flip` to `tensor/subtensor.py`, add docstring
* Move `slice_at_axis` to `tensor/subtensor` and expose it in `pytensor.tensor`
Parent f489cf4b
@@ -6,6 +6,7 @@ import pytensor.link.jax.dispatch.blas
import pytensor.link.jax.dispatch.blockwise
import pytensor.link.jax.dispatch.elemwise
import pytensor.link.jax.dispatch.extra_ops
import pytensor.link.jax.dispatch.math
import pytensor.link.jax.dispatch.nlinalg
import pytensor.link.jax.dispatch.pad
import pytensor.link.jax.dispatch.random
......
import jax.numpy as jnp
import numpy as np
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.pad import Pad
@jax_funcify.register(Pad)
def jax_funcify_pad(op, **kwargs):
pad_mode = op.pad_mode
reflect_type = op.reflect_type
has_stat_length = op.has_stat_length
if pad_mode == "constant":
def constant_pad(x, pad_width, constant_values):
return jnp.pad(x, pad_width, mode=pad_mode, constant_values=constant_values)
return constant_pad
elif pad_mode == "linear_ramp":
def lr_pad(x, pad_width, end_values):
# JAX does not allow a dynamic input if end_values is non-scalar
if not isinstance(end_values, int | float):
end_values = tuple(np.array(end_values))
return jnp.pad(x, pad_width, mode=pad_mode, end_values=end_values)
return lr_pad
elif pad_mode in ["maximum", "minimum", "mean"] and has_stat_length:
def stat_pad(x, pad_width, stat_length):
# JAX does not allow a dynamic input here, need to cast to tuple
return jnp.pad(
x, pad_width, mode=pad_mode, stat_length=tuple(np.array(stat_length))
)
return stat_pad
elif pad_mode in ["reflect", "symmetric"]:
def loop_pad(x, pad_width):
return jnp.pad(x, pad_width, mode=pad_mode, reflect_type=reflect_type)
return loop_pad
else:
def pad(x, pad_width):
return jnp.pad(x, pad_width, mode=pad_mode)
return pad
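# --- Editor's sketch (illustrative, not part of the diff) ---
# Why the casts to static tuples above are needed: under jax.jit, jnp.pad
# requires stat_length (and non-scalar end_values) to be concrete values
# rather than traced arrays.
import jax
import jax.numpy as jnp

@jax.jit
def _demo_stat_pad(x):
    # A concrete tuple works; a traced stat_length would raise at trace time.
    return jnp.pad(x, 2, mode="maximum", stat_length=(1, 1))

print(_demo_stat_pad(jnp.arange(5.0)))  # [0. 0. 0. 1. 2. 3. 4. 4. 4.]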
@@ -130,6 +130,7 @@ from pytensor.tensor.blas import batched_dot, batched_tensordot
from pytensor.tensor.extra_ops import *
from pytensor.tensor.io import *
from pytensor.tensor.math import *
from pytensor.tensor.pad import pad
from pytensor.tensor.shape import (
reshape,
shape,
......
import warnings
from collections.abc import Collection, Iterable
import numpy as np
@@ -20,14 +21,24 @@ from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import Assert
from pytensor.scalar import int32 as int_t
from pytensor.scalar import upcast
from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor import basic as ptb
from pytensor.tensor.basic import alloc, second
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as pt_abs
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import eq as pt_eq
from pytensor.tensor.math import (
ge,
gt,
log,
lt,
maximum,
minimum,
prod,
sign,
switch,
)
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import specify_broadcastable
@@ -1584,27 +1595,346 @@ def broadcast_shape_iter(
return tuple(result_dims)
def _check_deprecated_inputs(stop, end, num, steps):
if end is not None:
warnings.warn(
"The 'end' parameter is deprecated and will be removed in a future version. Use 'stop' instead.",
DeprecationWarning,
)
stop = end
if steps is not None:
warnings.warn(
"The 'steps' parameter is deprecated and will be removed in a future version. Use 'num' instead.",
DeprecationWarning,
)
num = steps
return stop, num
def _linspace_core(
start: TensorVariable,
stop: TensorVariable,
num: int,
endpoint=True,
retstep=False,
axis=0,
) -> TensorVariable | tuple[TensorVariable, TensorVariable]:
div = (num - 1) if endpoint else num
delta = stop - start
samples = ptb.shape_padright(ptb.arange(0, num), delta.ndim)
step = delta / div
samples = switch(gt(div, 0), samples * delta / div + start, samples * delta + start)
if endpoint:
samples = switch(gt(num, 1), set_subtensor(samples[-1, ...], stop), samples)
if axis != 0:
samples = ptb.moveaxis(samples, 0, axis)
if retstep:
return samples, step
return samples
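# Editor's note (illustrative snippet, not part of the diff): the recurrence
# above matches numpy's core linspace computation. A quick numeric check:
import numpy as np

start, stop, num = 0.0, 2.0, 5
k = np.arange(num)
print(k * (stop - start) / (num - 1) + start)  # [0.  0.5 1.  1.5 2. ]
print(np.linspace(start, stop, num))           # identical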
def _broadcast_base_with_inputs(start, stop, base, axis):
"""
Broadcast the base tensor with the start and stop tensors if base is not a scalar. This is important because it
may change how the axis argument is interpreted in the final output.
Parameters
----------
start: TensorVariable
The start value(s) of the sequence(s).
stop: TensorVariable
The end value(s) of the sequence(s)
base: TensorVariable
The log base value(s) of the sequence(s)
axis: int
The axis along which to generate samples.
Returns
-------
start: TensorVariable
The start value(s) of the sequence(s), broadcast with the base tensor if necessary.
stop: TensorVariable
The end value(s) of the sequence(s), broadcast with the base tensor if necessary.
base: TensorVariable
The log base value(s) of the sequence(s), broadcast with the start and stop tensors if necessary.
"""
base = ptb.as_tensor_variable(base)
if base.ndim > 0:
ndmax = len(broadcast_shape(start, stop, base))
start, stop, base = (
ptb.shape_padleft(a, ndmax - a.ndim) for a in (start, stop, base)
)
base = ptb.expand_dims(base, axis=(axis,))
return start, stop, base
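# Editor's note (illustrative snippet, not part of the diff): the helper above
# mirrors how numpy handles a non-scalar base (supported in recent numpy
# releases): base is broadcast against start/stop before the sample axis is
# inserted, which changes how `axis` lands in the output.
import numpy as np

out = np.logspace(0, 1, num=3, base=np.array([2.0, 10.0]), axis=-1)
print(np.round(out, 3))
# [[ 1.     1.414  2.   ]
#  [ 1.     3.162 10.   ]]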
def linspace(
start: TensorLike,
stop: TensorLike,
num: TensorLike = 50,
endpoint: bool = True,
retstep: bool = False,
dtype: str | None = None,
axis: int = 0,
end: TensorLike | None = None,
steps: TensorLike | None = None,
) -> TensorVariable | tuple[TensorVariable, TensorVariable]:
"""
Return evenly spaced numbers over a specified interval.
Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`].
The endpoint of the interval can optionally be excluded.
Parameters
----------
start: int, float, or TensorVariable
The starting value of the sequence.
stop: int, float or TensorVariable
The end value of the sequence, unless `endpoint` is set to False.
In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `stop` is excluded.
num: int
Number of samples to generate. Must be non-negative.
endpoint: bool
Whether to include the endpoint in the range.
retstep: bool
If true, returns both the samples and an array of steps between samples.
dtype: str, optional
dtype of the output tensor(s). If None, the dtype is inferred from that of the values provided to the `start`
and `stop` arguments.
axis: int
Axis along which to generate samples. Ignored if both `start` and `stop` have dimension 0. By default, axis=0
will insert the samples on a new left-most dimension. To insert samples on a right-most dimension, use axis=-1.
end: int, float or TensorVariable
.. warning::
The "end" parameter is deprecated and will be removed in a future version. Use "stop" instead.
The end value of the sequence, unless `endpoint` is set to False.
In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `end` is
excluded.
steps: float, int, or TensorVariable
.. warning::
The "steps" parameter is deprecated and will be removed in a future version. Use "num" instead.
Number of samples to generate. Must be non-negative.
Returns
-------
samples: TensorVariable
Tensor containing `num` evenly-spaced values in the interval [start, stop]. The range is inclusive if `endpoint` is True.
step: TensorVariable
Tensor containing the spacing between samples. Only returned if `retstep` is True.
"""
if dtype is None:
dtype = pytensor.config.floatX
stop, num = _check_deprecated_inputs(stop, end, num, steps)
start, stop = broadcast_arrays(start, stop)
ls = _linspace_core(
start=start,
stop=stop,
num=num,
endpoint=endpoint,
retstep=retstep,
axis=axis,
)
if retstep:
samples, step = ls
return samples.astype(dtype), step
return ls.astype(dtype)
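# Usage sketch (editor's illustration; assumes a pytensor build containing
# this diff, where linspace is exposed as pt.linspace):
import pytensor.tensor as pt

x = pt.linspace(0, 1, num=5)
print(x.eval())  # [0.   0.25 0.5  0.75 1.  ]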
def geomspace(
start: TensorLike,
stop: TensorLike,
num: int = 50,
base: float = 10.0,
endpoint: bool = True,
dtype: str | None = None,
axis: int = 0,
end: TensorLike | None = None,
steps: TensorLike | None = None,
) -> TensorVariable:
"""
Return numbers spaced evenly on a log scale (a geometric progression).
This is similar to logspace, but with endpoints specified directly. Each output sample is a constant multiple of
the previous.
Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`].
The endpoint of the interval can optionally be excluded.
Parameters
----------
start: int, float, or TensorVariable
The starting value of the sequence.
stop: int, float or TensorVariable
The end value of the sequence, unless `endpoint` is set to False.
In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `stop` is excluded.
num: int
Number of samples to generate. Must be non-negative.
base: float
The base of the log space.
endpoint: bool
Whether to include the endpoint in the range.
dtype: str, optional
dtype of the output tensor(s). If None, the dtype is inferred from that of the values provided to the `start`
and `stop` arguments.
axis: int
Axis along which to generate samples. Ignored if both `start` and `stop` have dimension 0. By default, axis=0
will insert the samples on a new left-most dimension. To insert samples on a right-most dimension, use axis=-1.
end: int, float or TensorVariable
.. warning::
The "end" parameter is deprecated and will be removed in a future version. Use "stop" instead.
The end value of the sequence, unless `endpoint` is set to False.
In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `end` is
excluded.
steps: float, int, or TensorVariable
.. warning::
The "steps" parameter is deprecated and will be removed in a future version. Use "num" instead.
Number of samples to generate. Must be non-negative.
Returns
-------
samples: TensorVariable
Tensor containing `num` evenly-spaced (in log space) values in the interval [start, stop]. The range is inclusive if
`endpoint` is True.
"""
if dtype is None:
dtype = pytensor.config.floatX
stop, num = _check_deprecated_inputs(stop, end, num, steps)
start, stop = broadcast_arrays(start, stop)
start, stop, base = _broadcast_base_with_inputs(start, stop, base, axis)
out_sign = sign(start)
log_start, log_stop = (
log(start * out_sign) / log(base),
log(stop * out_sign) / log(base),
)
result = _linspace_core(
start=log_start,
stop=log_stop,
num=num,
endpoint=endpoint,
axis=0,
retstep=False,
)
result = base**result
result = switch(gt(num, 0), set_subtensor(result[0, ...], start), result)
if endpoint:
result = switch(gt(num, 1), set_subtensor(result[-1, ...], stop), result)
result = result * out_sign
if axis != 0:
result = ptb.moveaxis(result, 0, axis)
return result.astype(dtype)
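# Editor's note (illustrative snippet, not part of the diff): the sign
# bookkeeping above mirrors np.geomspace, which supports all-negative
# endpoints by working in log space on the absolute values:
import numpy as np

print(np.geomspace(-1000.0, -1.0, num=4))  # [-1000.  -100.   -10.    -1.]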
def logspace(
start: TensorLike,
stop: TensorLike,
num: int = 50,
base: float = 10.0,
endpoint: bool = True,
dtype: str | None = None,
axis: int = 0,
end: TensorLike | None = None,
steps: TensorLike | None = None,
) -> TensorVariable:
"""
Return numbers spaced evenly on a log scale.
In linear space, the sequence starts at ``base ** start`` (base to the power of start) and ends with ``base ** stop``
(see ``endpoint`` below).
Parameters
----------
start: int, float, or TensorVariable
``base ** start`` is the starting value of the sequence
stop: int, float or TensorVariable
``base ** stop`` is the endpoint of the sequence, unless ``endpoint`` is set to False.
In that case, ``num + 1`` values are spaced over the interval in log-space, and the first ``num`` are returned.
num: int, default = 50
Number of samples to generate.
base: float, default = 10.0
The base of the log space. The step size between the elements in ``log(samples) / log(base)``
(or ``log_base(samples)``) is uniform.
endpoint: bool
Whether to include the endpoint in the range.
dtype: str, optional
dtype of the output tensor(s). If None, the dtype is inferred from that of the values provided to the `start`
and `stop` arguments.
axis: int
Axis along which to generate samples. Ignored if both `start` and `stop` have dimension 0. By default, axis=0
will insert the samples on a new left-most dimension. To insert samples on a right-most dimension, use axis=-1.
end: int, float, or TensorVariable
.. warning::
The "end" parameter is deprecated and will be removed in a future version. Use "stop" instead.
The end value of the sequence, unless `endpoint` is set to False.
In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `end` is
excluded.
steps: int or TensorVariable
.. warning::
The "steps" parameter is deprecated and will be removed in a future version. Use "num" instead.
Number of samples to generate. Must be non-negative.
Returns
-------
samples: TensorVariable
Tensor containing `num` evenly-spaced (in log-space) values in the interval [start, stop]. The range is inclusive if
`endpoint` is True.
"""
if dtype is None:
dtype = pytensor.config.floatX
stop, num = _check_deprecated_inputs(stop, end, num, steps)
start, stop = broadcast_arrays(start, stop)
start, stop, base = _broadcast_base_with_inputs(start, stop, base, axis)
ls = _linspace_core(
start=start,
stop=stop,
num=num,
endpoint=endpoint,
axis=axis,
retstep=False,
)
return (base**ls).astype(dtype)
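# Editor's note (illustrative snippet, not part of the diff): the numpy
# reference behavior being matched:
import numpy as np

print(np.logspace(0, 3, num=4))                  # [   1.   10.  100. 1000.]
print(np.logspace(0, 3, num=4, endpoint=False))  # ~[1. 5.623 31.623 177.828]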
def broadcast_to(
......
from collections.abc import Callable
from functools import partial
from typing import Literal, cast
from pytensor.compile.builders import OpFromGraph
from pytensor.ifelse import ifelse
from pytensor.scan import scan
from pytensor.tensor import TensorLike
from pytensor.tensor.basic import (
TensorVariable,
as_tensor,
concatenate,
expand_dims,
moveaxis,
switch,
zeros,
)
from pytensor.tensor.extra_ops import broadcast_to, linspace
from pytensor.tensor.math import divmod as pt_divmod
from pytensor.tensor.math import eq, gt, mean, minimum
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import min as pt_min
from pytensor.tensor.shape import specify_broadcastable
from pytensor.tensor.subtensor import flip, set_subtensor, slice_at_axis
PadMode = Literal[
"constant",
"edge",
"linear_ramp",
"maximum",
"minimum",
"mean",
"median",
"wrap",
"symmetric",
"reflect",
]
stat_funcs = {"maximum": pt_max, "minimum": pt_min, "mean": mean}
allowed_kwargs = {
"edge": [],
"wrap": [],
"constant": ["constant_values"],
"linear_ramp": ["end_values"],
"maximum": ["stat_length"],
"mean": ["stat_length"],
"median": ["stat_length"],
"minimum": ["stat_length"],
"reflect": ["reflect_type"],
"symmetric": ["reflect_type"],
}
def _get_edges(
padded: TensorVariable, axis: int, width_pair: tuple[TensorVariable, TensorVariable]
) -> tuple[TensorVariable, TensorVariable]:
"""
Retrieve edge values from empty-padded array in given dimension.
Copied from numpy.lib.arraypad._get_edges
https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L154
Parameters
----------
padded : TensorVariable
Empty-padded array.
axis : int
Dimension in which the edges are considered.
width_pair : (TensorVariable, TensorVariable)
Pair of widths that mark the pad area on both sides in the given
dimension.
Returns
-------
left_edge, right_edge : TensorVariable
Edge values of the valid area in `padded` in the given dimension. Its
shape will always match `padded` except for the dimension given by
`axis` which will have a length of 1.
"""
left_index = width_pair[0]
left_slice = slice_at_axis(slice(left_index, left_index + 1), axis)
left_edge = padded[left_slice]
right_index = padded.shape[axis] - width_pair[1]
right_slice = slice_at_axis(slice(right_index - 1, right_index), axis)
right_edge = padded[right_slice]
return left_edge, right_edge
def _symbolic_pad(
x: TensorVariable, pad_width: TensorVariable
) -> tuple[TensorVariable, tuple[slice, ...], TensorVariable]:
pad_width = broadcast_to(pad_width, as_tensor((x.ndim, 2)))
new_shape = as_tensor(
[pad_width[i][0] + size + pad_width[i][1] for i, size in enumerate(x.shape)]
)
original_area_slice = tuple(
slice(pad_width[i][0], pad_width[i][0] + size) for i, size in enumerate(x.shape)
)
padded: TensorVariable = set_subtensor(zeros(new_shape)[original_area_slice], x)
return padded, original_area_slice, pad_width
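# Editor's note (illustrative snippet, not part of the diff): a plain-numpy
# analogue of _symbolic_pad above — embed x into a zero tensor enlarged by
# pad_width on every side, and remember where the original values sit.
import numpy as np

x = np.arange(6.0).reshape(2, 3)
pad_width = np.broadcast_to(np.array([1, 2]), (x.ndim, 2))
new_shape = tuple(l + s + r for (l, r), s in zip(pad_width, x.shape))
padded = np.zeros(new_shape)
area = tuple(slice(l, l + s) for (l, _), s in zip(pad_width, x.shape))
padded[area] = x
print(padded.shape)  # (5, 6)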
def _get_padding_slices(
dim_shape: TensorVariable,
width_pair: tuple[TensorVariable, TensorVariable],
axis: int,
) -> tuple[tuple[slice, ...], tuple[slice, ...]]:
left_slice = slice_at_axis(slice(None, width_pair[0]), axis)
right_slice = slice_at_axis(slice(dim_shape - width_pair[1], None), axis)
return left_slice, right_slice
def _constant_pad(
x: TensorVariable, pad_width: TensorVariable, constant_values: TensorVariable
) -> TensorVariable:
padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
values = broadcast_to(constant_values, as_tensor((padded.ndim, 2)))
for axis in range(padded.ndim):
width_pair = pad_width[axis]
value_pair = values[axis]
dim_shape = padded.shape[axis]
left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)
padded = set_subtensor(padded[left_slice], value_pair[0])
padded = set_subtensor(padded[right_slice], value_pair[1])
return padded
def _edge_pad(x: TensorVariable, pad_width: TensorVariable) -> TensorVariable:
padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
for axis in range(padded.ndim):
width_pair = pad_width[axis]
dim_shape = padded.shape[axis]
left_edge, right_edge = _get_edges(padded, axis, width_pair)
left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)
padded = set_subtensor(padded[left_slice], left_edge)
padded = set_subtensor(padded[right_slice], right_edge)
return padded
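# Editor's note (illustrative snippet, not part of the diff): numpy reference
# for the edge mode implemented above:
import numpy as np

print(np.pad(np.array([1.0, 2.0, 3.0]), (2, 2), mode="edge"))
# [1. 1. 1. 2. 3. 3. 3.]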
def _get_stats(
padded: TensorVariable,
axis: int,
width_pair: TensorVariable,
length_pair: tuple[TensorVariable, TensorVariable] | tuple[None, None],
stat_func: Callable,
):
"""
Calculate statistic for the empty-padded array in given dimension.
Copied from numpy.lib.arraypad._get_stats
https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L230
Parameters
----------
padded : TensorVariable
Empty-padded array.
axis : int
Dimension in which the statistic is calculated.
width_pair : (TensorVariable, TensorVariable)
Pair of widths that mark the pad area on both sides in the given dimension.
length_pair : 2-element sequence of None or TensorVariable
Gives the number of values in valid area from each side that is taken into account when calculating the
statistic. If None the entire valid area in `padded` is considered.
stat_func : function
Function to compute statistic. The expected signature is
``stat_func(x: TensorVariable, axis: int, keepdims: bool) -> TensorVariable``.
Returns
-------
left_stat, right_stat : TensorVariable
Calculated statistic for both sides of `padded`.
"""
# Calculate indices of the edges of the area with original values
left_index = width_pair[0]
right_index = padded.shape[axis] - width_pair[1]
# as well as its length
max_length = right_index - left_index
# Limit stat_lengths to max_length
left_length, right_length = length_pair
# Record whether the full valid area is used on both sides *before* the None
# values are replaced below; otherwise the early-return check can never fire
use_full_area = left_length is None and right_length is None
# Calculate statistic for the left side
left_length = (
minimum(left_length, max_length) if left_length is not None else max_length
)
left_slice = slice_at_axis(slice(left_index, left_index + left_length), axis)
left_chunk = padded[left_slice]
left_stat = stat_func(left_chunk, axis=axis, keepdims=True)
if use_full_area:
# We could also return early in the more general case of left_length == right_length, but we don't necessarily
# know these shapes.
# TODO: Add rewrite to simplify in this case
return left_stat, left_stat
# Calculate statistic for the right side
right_length = (
minimum(right_length, max_length) if right_length is not None else max_length
)
right_slice = slice_at_axis(slice(right_index - right_length, right_index), axis)
right_chunk = padded[right_slice]
right_stat = stat_func(right_chunk, axis=axis, keepdims=True)
return left_stat, right_stat
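# Editor's note (illustrative snippet, not part of the diff): stat_length
# limits how much of the valid area feeds the statistic, as in numpy:
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
print(np.pad(x, 2, mode="mean"))                 # both sides use mean 2.5
print(np.pad(x, 2, mode="mean", stat_length=1))  # edges only: 1.0 and 4.0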
def _stat_pad(
x: TensorVariable,
pad_width: TensorVariable,
stat_func: Callable,
stat_length: TensorVariable | None,
):
padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
if stat_length is None:
stat_length = [[None, None]] * padded.ndim # type: ignore
else:
stat_length = broadcast_to(stat_length, as_tensor((padded.ndim, 2)))
for axis in range(padded.ndim):
width_pair = pad_width[axis]
length_pair = stat_length[axis] # type: ignore
dim_shape = padded.shape[axis]
left_stat, right_stat = _get_stats(
padded, axis, width_pair, length_pair, stat_func
)
left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)
padded = set_subtensor(padded[left_slice], left_stat)
padded = set_subtensor(padded[right_slice], right_stat)
return padded
def _linear_ramp_pad(
x: TensorVariable, pad_width: TensorVariable, end_values: TensorVariable | int = 0
) -> TensorVariable:
padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
end_values = as_tensor(end_values)
end_values = broadcast_to(end_values, as_tensor((padded.ndim, 2)))
for axis in range(padded.ndim):
width_pair = pad_width[axis]
end_value_pair = end_values[axis]
edge_pair = _get_edges(padded, axis, width_pair)
dim_shape = padded.shape[axis]
left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)
left_ramp, right_ramp = (
linspace(
start=end_value,
stop=specify_broadcastable(edge, axis).squeeze(axis),
num=width,
endpoint=False,
dtype=padded.dtype,
axis=axis,
)
for end_value, edge, width in zip(end_value_pair, edge_pair, width_pair)
)
# Reverse the direction of the ramp for the "right" side
right_ramp = right_ramp[slice_at_axis(slice(None, None, -1), axis)] # type: ignore
padded = set_subtensor(padded[left_slice], left_ramp)
padded = set_subtensor(padded[right_slice], right_ramp)
return padded
def _wrap_pad(x: TensorVariable, pad_width: TensorVariable) -> TensorVariable:
pad_width = broadcast_to(pad_width, as_tensor((x.ndim, 2)))
for axis in range(x.ndim):
size = x.shape[axis]
# Compute how many complete copies of the input will be padded on this dimension, along with the amount of
# overflow on the final copy
repeats, (left_remainder, right_remainder) = pt_divmod(pad_width[axis], size)
# In the next step we will generate extra copies of the input, and then trim them down to the correct size.
left_trim = size - left_remainder
right_trim = size - right_remainder
# The total number of copies needed is always the sum of the number of complete copies to add, plus the original
# input itself, plus the two edge copies that will be trimmed down.
total_repeats = repeats.sum() + 3
# Create a batch dimension and clone the input the required number of times
parts = expand_dims(x, (0,)).repeat(total_repeats, axis=0)
# Move the batch dimension to the active dimension
parts = moveaxis(parts, 0, axis)
# Ravel the active dimension while preserving the shapes of the inactive dimensions. This will expand the
# active dimension to have the correctly padded shape, plus excess to be trimmed
new_shape = [-1 if i == axis else x.shape[i] for i in range(x.ndim)]
x = parts.reshape(new_shape)
# Trim the excess on the active dimension
trim_slice = slice_at_axis(slice(left_trim, -right_trim), axis)
x = x[trim_slice]
return x
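# Editor's note (illustrative snippet, not part of the diff): wrap padding
# tiles the input periodically; numpy reference:
import numpy as np

print(np.pad(np.array([1, 2, 3]), (2, 4), mode="wrap"))
# [2 3 1 2 3 1 2 3 1]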
def _build_padding_one_direction(array, array_flipped, repeats, *, inner_func, axis):
[_, parts], _ = scan(
inner_func,
non_sequences=[array, array_flipped],
outputs_info=[0, None],
n_steps=repeats,
)
parts = moveaxis(parts, 0, axis)
new_shape = [-1 if i == axis else array.shape[i] for i in range(array.ndim)]
padding = parts.reshape(new_shape)
return padding
def _symmetric_pad(x, pad_width):
def _symmetric_inner(i, x, x_flipped, padding_left):
return i + 1, ifelse(eq(i % 2, int(padding_left)), x_flipped, x)
pad_width = broadcast_to(pad_width, as_tensor((x.ndim, 2)))
for axis in range(x.ndim):
x_flipped = flip(x, axis=axis)
original_size = x.shape[axis]
repeats, remainders = pt_divmod(pad_width[axis], original_size)
has_remainder = gt(remainders, 0)
repeats = repeats + has_remainder
left_padding = _build_padding_one_direction(
x,
x_flipped,
repeats[0],
axis=axis,
inner_func=partial(_symmetric_inner, padding_left=True),
)
right_padding = _build_padding_one_direction(
x,
x_flipped,
repeats[1],
axis=axis,
inner_func=partial(_symmetric_inner, padding_left=False),
)
x = concatenate([flip(left_padding, axis), x, right_padding], axis=axis)
(left_trim, right_trim) = switch(
has_remainder, original_size - remainders, remainders
)
right_trim = x.shape[axis] - right_trim
trim_slice = slice_at_axis(slice(left_trim, right_trim), axis)
x = x[trim_slice]
return x
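# Editor's note (illustrative snippet, not part of the diff): symmetric
# repeats the edge value, while reflect (implemented below) does not:
import numpy as np

x = np.array([1, 2, 3])
print(np.pad(x, 2, mode="symmetric"))  # [2 1 1 2 3 3 2]
print(np.pad(x, 2, mode="reflect"))    # [3 2 1 2 3 2 1]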
def _reflect_pad(x, pad_width):
def _reflect_inner(i, x, x_flipped, padding_left):
return i + 1, ifelse(eq(i % 2, int(padding_left)), x_flipped, x)
pad_width = broadcast_to(pad_width, as_tensor((x.ndim, 2)))
for axis in range(x.ndim):
trimmed_size = x.shape[axis] - 1
trim_slice = slice_at_axis(slice(None, -1), axis)
x_trimmed = x[trim_slice]
x_flipped = flip(x, axis=axis)[trim_slice]
repeats, remainders = pt_divmod(pad_width[axis], trimmed_size)
repeats = repeats + 1
left_padding = _build_padding_one_direction(
x_trimmed,
x_flipped,
repeats[0],
axis=axis,
inner_func=partial(_reflect_inner, padding_left=True),
)
right_padding = _build_padding_one_direction(
x_trimmed,
x_flipped,
repeats[1],
axis=axis,
inner_func=partial(_reflect_inner, padding_left=False),
)
left_trim = slice_at_axis(slice(trimmed_size - remainders[0] - 1, -1), axis)
right_trim = slice_at_axis(
slice(1, right_padding.shape[axis] - trimmed_size + remainders[1] + 1), axis
)
x = concatenate(
[flip(left_padding, axis)[left_trim], x, right_padding[right_trim]],
axis=axis,
)
return x
class Pad(OpFromGraph):
"""
Wrapper Op for Pad graphs
"""
def __init__(
self, inputs, outputs, pad_mode, reflect_type=None, has_stat_length=False
):
self.pad_mode = pad_mode
self.reflect_type = reflect_type
self.has_stat_length = has_stat_length
super().__init__(inputs=inputs, outputs=outputs)
def pad(
x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwargs
) -> TensorVariable:
"""
Pad an array.
Parameters
----------
x : array_like of rank N
The array to pad.
pad_width : sequence, array_like, or int
Number of values padded to the edges of each axis.
``((before_1, after_1), ... (before_N, after_N))`` unique pad widths
for each axis.
``(before, after)`` or ``((before, after),)`` yields same before
and after pad for each axis.
``(pad,)`` or ``int`` is a shortcut for before = after = pad width
for all axes.
mode : str or function, optional
One of the following string values or a user supplied function.
'constant' (default)
Pads with a constant value.
'edge'
Pads with the edge values of array.
'linear_ramp'
Pads with the linear ramp between end_value and the
array edge value.
'maximum'
Pads with the maximum value of all or part of the
vector along each axis.
'mean'
Pads with the mean value of all or part of the
vector along each axis.
'minimum'
Pads with the minimum value of all or part of the
vector along each axis.
'reflect'
Pads with the reflection of the vector mirrored on
the first and last values of the vector along each
axis.
'symmetric'
Pads with the reflection of the vector mirrored
along the edge of the array.
'wrap'
Pads with the wrap of the vector along the axis.
The first values are used to pad the end and the
end values are used to pad the beginning.
stat_length : sequence or int, optional
Used in 'maximum', 'mean', and 'minimum'. Number of
values at edge of each axis used to calculate the statistic value.
``((before_1, after_1), ... (before_N, after_N))`` unique statistic
lengths for each axis.
``(before, after)`` or ``((before, after),)`` yields same before
and after statistic lengths for each axis.
``(stat_length,)`` or ``int`` is a shortcut for
``before = after = statistic`` length for all axes.
Default is ``None``, to use the entire axis.
constant_values : sequence or scalar, optional
Used in 'constant'. The values to set the padded values for each
axis.
``((before_1, after_1), ... (before_N, after_N))`` unique pad constants
for each axis.
``(before, after)`` or ``((before, after),)`` yields same before
and after constants for each axis.
``(constant,)`` or ``constant`` is a shortcut for
``before = after = constant`` for all axes.
Default is 0.
end_values : sequence or scalar, optional
Used in 'linear_ramp'. The values used for the ending value of the
linear_ramp and that will form the edge of the padded array.
``((before_1, after_1), ... (before_N, after_N))`` unique end values
for each axis.
``(before, after)`` or ``((before, after),)`` yields same before
and after end values for each axis.
``(constant,)`` or ``constant`` is a shortcut for
``before = after = constant`` for all axes.
Default is 0.
reflect_type : str, optional
Only 'even' is currently accepted. Used in 'reflect' and 'symmetric'. The 'even' style is the
default with an unaltered reflection around the edge value.
Returns
-------
padded : TensorVariable
Padded tensor of rank equal to `x` with shape increased
according to `pad_width`.
Examples
--------
.. testcode::
import pytensor.tensor as pt
a = [1, 2, 3, 4, 5]
print(pt.pad(a, (2, 3), 'constant', constant_values=(4, 6)).eval())
.. testoutput::
[4. 4. 1. 2. 3. 4. 5. 6. 6. 6.]
.. testcode::
print(pt.pad(a, (2, 3), 'edge').eval())
.. testoutput::
[1. 1. 1. 2. 3. 4. 5. 5. 5. 5.]
.. testcode::
print(pt.pad(a, (2, 3), 'linear_ramp', end_values=(5, -4)).eval())
.. testoutput::
[ 5. 3. 1. 2. 3. 4. 5. 2. -1. -4.]
.. testcode::
print(pt.pad(a, (2,), 'maximum').eval())
.. testoutput::
[5. 5. 1. 2. 3. 4. 5. 5. 5.]
.. testcode::
print(pt.pad(a, (2,), 'mean').eval())
.. testoutput::
[3. 3. 1. 2. 3. 4. 5. 3. 3.]
.. testcode::
a = [[1, 2], [3, 4]]
print(pt.pad(a, ((3, 2), (2, 3)), 'minimum').eval())
.. testoutput::
[[1. 1. 1. 2. 1. 1. 1.]
[1. 1. 1. 2. 1. 1. 1.]
[1. 1. 1. 2. 1. 1. 1.]
[1. 1. 1. 2. 1. 1. 1.]
[3. 3. 3. 4. 3. 3. 3.]
[1. 1. 1. 2. 1. 1. 1.]
[1. 1. 1. 2. 1. 1. 1.]]
.. testcode::
a = [1, 2, 3, 4, 5]
print(pt.pad(a, (2, 3), 'reflect').eval())
.. testoutput::
[3 2 1 2 3 4 5 4 3 2]
.. testcode::
print(pt.pad(a, (2, 3), 'symmetric').eval())
.. testoutput::
[2 1 1 2 3 4 5 5 4 3]
.. testcode::
print(pt.pad(a, (2, 3), 'wrap').eval())
.. testoutput::
[4 5 1 2 3 4 5 1 2 3]
"""
if any(value not in allowed_kwargs.get(mode, ()) for value in kwargs.keys()):
raise ValueError(
f"Invalid keyword arguments for mode '{mode}': {kwargs.keys()}"
)
x = as_tensor(x, name="x")
pad_width = as_tensor(pad_width, name="pad_width")
inputs = [x, pad_width]
attrs = {}
if mode == "constant":
constant_values = as_tensor(
kwargs.pop("constant_values", 0), name="constant_values"
)
inputs += [constant_values]
outputs = _constant_pad(x, pad_width, constant_values)
elif mode == "edge":
outputs = _edge_pad(x, pad_width)
elif mode in ["maximum", "minimum", "mean", "median"]:
if mode == "median":
# TODO: Revisit this after we implement a quantile function.
# See https://github.com/pymc-devs/pytensor/issues/53
raise NotImplementedError("Median padding not implemented")
stat_func = cast(Callable, stat_funcs[mode])
stat_length = kwargs.get("stat_length")
if stat_length is not None:
attrs.update({"has_stat_length": True})
stat_length = as_tensor(stat_length, name="stat_length")
inputs += [stat_length]
outputs = _stat_pad(x, pad_width, stat_func, stat_length)
elif mode == "linear_ramp":
end_values = kwargs.pop("end_values", 0)
end_values = as_tensor(end_values)
inputs += [end_values]
outputs = _linear_ramp_pad(x, pad_width, end_values)
elif mode == "wrap":
outputs = _wrap_pad(x, pad_width)
elif mode == "symmetric":
reflect_type = kwargs.pop("reflect_type", "even")
if reflect_type == "odd":
raise NotImplementedError(
"Odd reflection not implemented. If you need this feature, please open an "
"issue at https://github.com/pymc-devs/pytensor/issues"
)
attrs.update({"reflect_type": reflect_type})
outputs = _symmetric_pad(x, pad_width)
elif mode == "reflect":
reflect_type = kwargs.pop("reflect_type", "even")
if reflect_type == "odd":
raise NotImplementedError(
"Odd reflection not implemented. If you need this feature, please open an "
"issue at https://github.com/pymc-devs/pytensor/issues"
)
attrs.update({"reflect_type": reflect_type})
outputs = _reflect_pad(x, pad_width)
else:
raise ValueError(f"Invalid mode: {mode}")
op = Pad(inputs=inputs, outputs=[outputs], pad_mode=mode, **attrs)(*inputs)
return cast(TensorVariable, op)
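# Usage sketch with a symbolic input (editor's illustration; assumes a
# pytensor build containing this diff, where pad is exposed as pt.pad):
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.pad(x, pad_width=((1, 1), (2, 2)), mode="constant", constant_values=0)
f = pytensor.function([x], y)
print(f(np.ones((2, 2), dtype=pytensor.config.floatX)).shape)  # (4, 6)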
__all__ = ["pad", "flip"]
@@ -3013,8 +3013,123 @@ def _get_vector_length_Subtensor(op, var):
raise ValueError(f"Length of {var} cannot be determined")
def slice_at_axis(sl: slice, axis: int) -> tuple[slice, ...]:
"""
Construct tuple of slices to slice an array in the given dimension.
Copied from numpy.lib.arraypad._slice_at_axis
https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L33
Parameters
----------
sl : slice
The slice for the given dimension.
axis : int
The axis to which `sl` is applied. All other dimensions are left
"unsliced".
Returns
-------
sl : tuple of slices
A tuple with slices matching `shape` in length.
Examples
--------
.. testcode::
import pytensor.tensor as pt
s = pt.slice_at_axis(slice(None, 1), 1)
print(s)
.. testoutput::
(slice(None, None, None), slice(None, 1, None), Ellipsis)
.. testcode::
import pytensor
import numpy as np
x = pt.tensor('x', shape=(None, None, None))
x_sliced = x[s]
f = pytensor.function([x], x_sliced)
x = np.arange(27).reshape(3, 3, 3)
print(f(x))
.. testoutput::
[[[ 0. 1. 2.]]
[[ 9. 10. 11.]]
[[18. 19. 20.]]]
"""
if axis >= 0:
return (slice(None),) * axis + (sl,) + (...,) # type: ignore
else:
# If axis = -1 we want zero right padding (and so on), so subtract one
axis = abs(axis) - 1
return (...,) + (sl,) + (slice(None),) * axis # type: ignore
def flip(
arr: TensorVariable, axis: int | tuple[int] | TensorVariable | None = None
) -> TensorVariable:
"""
Reverse the order of elements in a tensor along the given axis.
Parameters
----------
arr: TensorVariable
Input tensor.
axis: int | tuple[int] | TensorVariable, optional
Axis or axes along which to flip over. The default is to flip over all of the axes of the input tensor.
Returns
-------
arr: TensorVariable
A view of `arr` with the entries of axis reversed.
Examples
--------
.. testcode::
import pytensor
import pytensor.tensor as pt
x = pt.tensor('x', shape=(None, None))
x_flipped = pt.flip(x, axis=0)
f = pytensor.function([x], x_flipped)
x = [[1, 2], [3, 4]]
print(f(x))
.. testoutput::
[[3. 4.]
[1. 2.]]
"""
if axis is None:
index = (slice(None, None, -1),) * arr.ndim
else:
if isinstance(axis, int):
axis = (axis,)
index = tuple(
slice(None, None, -1) if i in axis else slice(None, None, None)
for i in range(arr.ndim)
)
return cast(TensorVariable, arr[index])
__all__ = [
"take",
"flip",
"slice_at_axis",
"inc_subtensor",
"set_subtensor",
]
import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor.pad import PadMode
from tests.link.jax.test_basic import compare_jax_and_py
jax = pytest.importorskip("jax")
floatX = config.floatX
RTOL = ATOL = 1e-6 if floatX.endswith("64") else 1e-3
@pytest.mark.parametrize(
"mode, kwargs",
[
("constant", {"constant_values": 0}),
("constant", {"constant_values": (1, 2)}),
("edge", {}),
("linear_ramp", {"end_values": 0}),
("linear_ramp", {"end_values": (1, 2)}),
("reflect", {"reflect_type": "even"}),
("wrap", {}),
("symmetric", {"reflect_type": "even"}),
("mean", {"stat_length": None}),
("mean", {"stat_length": (10, 2)}),
("maximum", {"stat_length": None}),
("maximum", {"stat_length": (10, 2)}),
("minimum", {"stat_length": None}),
("minimum", {"stat_length": (10, 2)}),
],
ids=[
"constant_default",
"constant_tuple",
"edge",
"linear_ramp_default",
"linear_ramp_tuple",
"reflect",
"wrap",
"symmetric",
"mean_default",
"mean_tuple",
"maximum_default",
"maximum_tuple",
"minimum_default",
"minimum_tuple",
],
)
def test_jax_pad(mode: PadMode, kwargs):
x_pt = pt.tensor("x", shape=(3, 3))
x = np.random.normal(size=(3, 3))
res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs)
res_fg = FunctionGraph([x_pt], [res])
compare_jax_and_py(
res_fg,
[x],
assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL),
py_mode="FAST_RUN",
)
import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor.pad import PadMode
from tests.link.numba.test_basic import compare_numba_and_py
floatX = config.floatX
RTOL = ATOL = 1e-6 if floatX.endswith("64") else 1e-3
@pytest.mark.parametrize(
"mode, kwargs",
[
("constant", {"constant_values": 0}),
("constant", {"constant_values": (1, 2)}),
pytest.param(
"edge",
{},
marks=pytest.mark.skip(
"This is causing a segfault in NUMBA mode, but I have no idea why"
),
),
("linear_ramp", {"end_values": 0}),
("linear_ramp", {"end_values": (1, 2)}),
("reflect", {"reflect_type": "even"}),
("wrap", {}),
("symmetric", {"reflect_type": "even"}),
("mean", {"stat_length": None}),
("mean", {"stat_length": (10, 2)}),
("maximum", {"stat_length": None}),
("maximum", {"stat_length": (10, 2)}),
("minimum", {"stat_length": None}),
("minimum", {"stat_length": (10, 2)}),
],
ids=[
"constant_default",
"constant_tuple",
"edge",
"linear_ramp_default",
"linear_ramp_tuple",
"reflect",
"wrap",
"symmetric",
"mean_default",
"mean_tuple",
"maximum_default",
"maximum_tuple",
"minimum_default",
"minimum_tuple",
],
)
def test_numba_pad(mode: PadMode, kwargs):
x_pt = pt.tensor("x", shape=(3, 3))
x = np.random.normal(size=(3, 3))
res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs)
res_fg = FunctionGraph([x_pt], [res])
compare_numba_and_py(
res_fg,
[x],
assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL),
py_mode="FAST_RUN",
)
@@ -35,9 +35,6 @@ from pytensor.tensor.extra_ops import (
diff,
fill_diagonal,
fill_diagonal_offset,
ravel_multi_index,
repeat,
searchsorted,
@@ -1281,25 +1278,37 @@ def test_broadcast_arrays():
@pytest.mark.parametrize(
"op",
["linspace", "logspace", "geomspace"],
ids=["linspace", "logspace", "geomspace"],
)
@pytest.mark.parametrize("dtype", [None, "int", "float"], ids=[None, "int", "float"])
@pytest.mark.parametrize(
"start, stop, num_samples, endpoint, axis",
[
(1, 10, 50, True, 0),
(1, 10, 1, True, 0),
(np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25, True, 0),
(np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25, True, 1),
(np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25, False, -1),
(1, np.array([5, 6]), 30, True, 0),
(1, np.array([5, 6]), 30, False, -1),
],
)
def test_space_ops(op, dtype, start, stop, num_samples, endpoint, axis):
pt_func = getattr(pt, op)
np_func = getattr(np, op)
dtype = dtype + config.floatX[-2:] if dtype is not None else dtype
z = pt_func(start, stop, num_samples, endpoint=endpoint, axis=axis, dtype=dtype)
numpy_res = np_func(
start, stop, num=num_samples, endpoint=endpoint, dtype=dtype, axis=axis
)
pytensor_res = function(inputs=[], outputs=z, mode="FAST_COMPILE")()
np.testing.assert_allclose(
pytensor_res,
numpy_res,
atol=1e-6 if config.floatX.endswith("64") else 1e-4,
rtol=1e-6 if config.floatX.endswith("64") else 1e-4,
)
from typing import Literal
import numpy as np
import pytest
import pytensor
from pytensor.tensor.pad import PadMode, pad
floatX = pytensor.config.floatX
RTOL = ATOL = 1e-8 if floatX.endswith("64") else 1e-4
def test_unknown_mode_raises():
x = np.random.normal(size=(3, 3)).astype(floatX)
with pytest.raises(ValueError, match="Invalid mode: unknown"):
pad(x, 1, mode="unknown")
@pytest.mark.parametrize(
"size", [(3,), (3, 3), (3, 3, 3)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize("constant", [0, 0.0], ids=["int", "float"])
@pytest.mark.parametrize(
"pad_width",
[10, (10, 0), (0, 10)],
ids=["symmetrical", "asymmetrical_left", "asymmetric_right"],
)
def test_constant_pad(
size: tuple, constant: int | float, pad_width: int | tuple[int, ...]
):
x = np.random.normal(size=size).astype(floatX)
expected = np.pad(x, pad_width, mode="constant", constant_values=constant)
z = pad(x, pad_width, mode="constant", constant_values=constant)
assert z.owner.op.pad_mode == "constant"
f = pytensor.function([], z, mode="FAST_COMPILE")
np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize(
"size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize(
"pad_width",
[10, (10, 0), (0, 10)],
ids=["symmetrical", "asymmetrical_left", "asymmetric_right"],
)
def test_edge_pad(size: tuple, pad_width: int | tuple[int, ...]):
x = np.random.normal(size=size).astype(floatX)
expected = np.pad(x, pad_width, mode="edge")
z = pad(x, pad_width, mode="edge")
assert z.owner.op.pad_mode == "edge"
f = pytensor.function([], z, mode="FAST_COMPILE")
np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize(
"size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize(
"pad_width",
[10, (10, 0), (0, 10)],
ids=["symmetrical", "asymmetrical_left", "asymmetric_right"],
)
@pytest.mark.parametrize("end_values", [0, -1], ids=["0", "-1"])
def test_linear_ramp_pad(
size: tuple,
pad_width: int | tuple[int, ...],
end_values: int | float | tuple[int | float, ...],
):
x = np.random.normal(size=size).astype(floatX)
expected = np.pad(x, pad_width, mode="linear_ramp", end_values=end_values)
z = pad(x, pad_width, mode="linear_ramp", end_values=end_values)
assert z.owner.op.pad_mode == "linear_ramp"
f = pytensor.function([], z, mode="FAST_COMPILE")
np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize(
"size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize(
"pad_width",
[10, (10, 0), (0, 10)],
ids=["symmetrical", "asymmetrical_left", "asymmetric_right"],
)
@pytest.mark.parametrize("stat", ["mean", "minimum", "maximum"])
@pytest.mark.parametrize("stat_length", [None, 2])
def test_stat_pad(
size: tuple,
pad_width: int | tuple[int, ...],
stat: PadMode,
stat_length: int | None,
):
x = np.random.normal(size=size).astype(floatX)
expected = np.pad(x, pad_width, mode=stat, stat_length=stat_length)
z = pad(x, pad_width, mode=stat, stat_length=stat_length)
assert z.owner.op.pad_mode == stat
f = pytensor.function([], z, mode="FAST_COMPILE")
np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize(
"size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize(
"pad_width",
[10, (10, 0), (0, 10)],
ids=["symmetrical", "asymmetrical_left", "asymmetric_right"],
)
def test_wrap_pad(size: tuple, pad_width: int | tuple[int, ...]):
x = np.random.normal(size=size).astype(floatX)
expected = np.pad(x, pad_width, mode="wrap")
z = pad(x, pad_width, mode="wrap")
assert z.owner.op.pad_mode == "wrap"
f = pytensor.function([], z, mode="FAST_COMPILE")
np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize(
"size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize(
"pad_width",
[10, (10, 0), (0, 10)],
ids=["symmetrical", "asymmetrical_left", "asymmetric_right"],
)
@pytest.mark.parametrize(
"reflect_type",
["even", pytest.param("odd", marks=pytest.mark.xfail(raises=NotImplementedError))],
ids=["even", "odd"],
)
def test_symmetric_pad(
size,
pad_width,
reflect_type: Literal["even", "odd"],
):
x = np.random.normal(size=size).astype(floatX)
expected = np.pad(x, pad_width, mode="symmetric", reflect_type=reflect_type)
z = pad(x, pad_width, mode="symmetric", reflect_type=reflect_type)
assert z.owner.op.pad_mode == "symmetric"
f = pytensor.function([], z, mode="FAST_COMPILE")
np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize(
"size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize(
"pad_width",
[10, (10, 0), (0, 10)],
ids=["symmetrical", "asymmetrical_left", "asymmetric_right"],
)
@pytest.mark.parametrize(
"reflect_type",
["even", pytest.param("odd", marks=pytest.mark.xfail(raises=NotImplementedError))],
ids=["even", "odd"],
)
def test_reflect_pad(
size,
pad_width,
reflect_type: Literal["even", "odd"],
):
x = np.random.normal(size=size).astype(floatX)
expected = np.pad(x, pad_width, mode="reflect", reflect_type=reflect_type)
z = pad(x, pad_width, mode="reflect", reflect_type=reflect_type)
assert z.owner.op.pad_mode == "reflect"
f = pytensor.function([], z, mode="FAST_COMPILE")
np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
@pytest.mark.parametrize(
"mode",
[
"constant",
"edge",
"linear_ramp",
"wrap",
"symmetric",
"reflect",
"mean",
"maximum",
"minimum",
],
)
@pytest.mark.parametrize("padding", ["symmetric", "asymmetric"])
def test_nd_padding(mode, padding):
rng = np.random.default_rng()
n = rng.integers(3, 5)
if padding == "symmetric":
pad_width = [(i, i) for i in rng.integers(1, 5, size=n)]
stat_length = [(i, i) for i in rng.integers(1, 5, size=n)]
else:
pad_width = rng.integers(1, 5, size=(n, 2)).tolist()
stat_length = rng.integers(1, 5, size=(n, 2)).tolist()
test_kwargs = {
"constant": {"constant_values": 0},
"linear_ramp": {"end_values": 0},
"maximum": {"stat_length": stat_length},
"mean": {"stat_length": stat_length},
"minimum": {"stat_length": stat_length},
"reflect": {"reflect_type": "even"},
"symmetric": {"reflect_type": "even"},
}
x = np.random.normal(size=(2,) * n).astype(floatX)
kwargs = test_kwargs.get(mode, {})
expected = np.pad(x, pad_width, mode=mode, **kwargs)
z = pad(x, pad_width, mode=mode, **kwargs)
f = pytensor.function([], z, mode="FAST_COMPILE")
np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
@@ -37,11 +37,13 @@ from pytensor.tensor.subtensor import (
advanced_subtensor1,
as_index_literal,
basic_shape,
flip,
get_canonical_form_slice,
inc_subtensor,
index_vars_to_types,
indexed_result_shape,
set_subtensor,
slice_at_axis,
take,
)
from pytensor.tensor.type import (
@@ -2902,3 +2904,39 @@ def test_vectorize_adv_subtensor(
vectorize_pt(x_test, idx_test),
vectorize_np(x_test, idx_test),
)
def test_slice_at_axis():
x = ptb.tensor("x", shape=(3, 4, 5))
x_sliced = x[slice_at_axis(slice(None, 1), axis=0)]
assert x_sliced.type.shape == (1, 4, 5)
# Negative axis
x_sliced = x[slice_at_axis(slice(None, 1), axis=-2)]
assert x_sliced.type.shape == (3, 1, 5)
@pytest.mark.parametrize(
"size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
def test_flip(size: tuple[int]):
from itertools import combinations
ATOL = RTOL = 1e-8 if config.floatX == "float64" else 1e-4
x = np.random.normal(size=size).astype(config.floatX)
x_pt = pytensor.tensor.tensor(shape=size, name="x")
expected = np.flip(x, axis=None)
z = flip(x_pt, axis=None)
f = pytensor.function([x_pt], z, mode="FAST_COMPILE")
np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL)
# Test all combinations of axes
flip_options = [
axes for i in range(1, x.ndim + 1) for axes in combinations(range(x.ndim), r=i)
]
for axes in flip_options:
expected = np.flip(x, axis=list(axes))
z = flip(x_pt, axis=list(axes))
f = pytensor.function([x_pt], z, mode="FAST_COMPILE")
np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL)