Unverified 提交 4e85676a authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Pytensor-native interpolation functions (#1141)

* add interpolate.py * Add jax dispatch for `searchsorted` * Import user-facing functions in `tensor.__init__`
上级 83c6b44c
...@@ -10,6 +10,7 @@ from pytensor.tensor.extra_ops import ( ...@@ -10,6 +10,7 @@ from pytensor.tensor.extra_ops import (
FillDiagonalOffset, FillDiagonalOffset,
RavelMultiIndex, RavelMultiIndex,
Repeat, Repeat,
SearchsortedOp,
Unique, Unique,
UnravelIndex, UnravelIndex,
) )
...@@ -130,3 +131,13 @@ def jax_funcify_FillDiagonalOffset(op, **kwargs): ...@@ -130,3 +131,13 @@ def jax_funcify_FillDiagonalOffset(op, **kwargs):
# return filldiagonaloffset # return filldiagonaloffset
raise NotImplementedError("flatiter not implemented in JAX") raise NotImplementedError("flatiter not implemented in JAX")
@jax_funcify.register(SearchsortedOp)
def jax_funcify_SearchsortedOp(op, **kwargs):
side = op.side
def searchsorted(a, v, side=side, sorter=None):
return jnp.searchsorted(a=a, v=v, side=side, sorter=sorter)
return searchsorted
...@@ -128,6 +128,7 @@ from numpy import e, euler_gamma, inf, infty, nan, newaxis, pi ...@@ -128,6 +128,7 @@ from numpy import e, euler_gamma, inf, infty, nan, newaxis, pi
from pytensor.tensor.basic import * from pytensor.tensor.basic import *
from pytensor.tensor.blas import batched_dot, batched_tensordot from pytensor.tensor.blas import batched_dot, batched_tensordot
from pytensor.tensor.extra_ops import * from pytensor.tensor.extra_ops import *
from pytensor.tensor.interpolate import interp, interpolate1d
from pytensor.tensor.io import * from pytensor.tensor.io import *
from pytensor.tensor.math import * from pytensor.tensor.math import *
from pytensor.tensor.pad import pad from pytensor.tensor.pad import pad
......
from collections.abc import Callable
from difflib import get_close_matches
from typing import Literal, get_args
from pytensor import Variable
from pytensor.tensor.basic import as_tensor_variable, switch
from pytensor.tensor.extra_ops import searchsorted
from pytensor.tensor.functional import vectorize
from pytensor.tensor.math import clip, eq, le
from pytensor.tensor.sort import argsort
InterpolationMethod = Literal["linear", "nearest", "first", "last", "mean"]
valid_methods = get_args(InterpolationMethod)
def pad_or_return(x, idx, output, left_pad, right_pad, extrapolate):
if extrapolate:
return output
n = x.shape[0]
return switch(eq(idx, 0), left_pad, switch(eq(idx, n), right_pad, output))
def _linear_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True):
clip_idx = clip(idx, 1, x.shape[0] - 1)
slope = (x_hat - x[clip_idx - 1]) / (x[clip_idx] - x[clip_idx - 1])
y_hat = y[clip_idx - 1] + slope * (y[clip_idx] - y[clip_idx - 1])
return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate)
def _nearest_neighbor_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True):
clip_idx = clip(idx, 1, x.shape[0] - 1)
left_distance = x_hat - x[clip_idx - 1]
right_distance = x[clip_idx] - x_hat
y_hat = switch(le(left_distance, right_distance), y[clip_idx - 1], y[clip_idx])
return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate)
def _stepwise_first_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True):
clip_idx = clip(idx - 1, 0, x.shape[0] - 1)
y_hat = y[clip_idx]
return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate)
def _stepwise_last_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True):
clip_idx = clip(idx, 0, x.shape[0] - 1)
y_hat = y[clip_idx]
return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate)
def _stepwise_mean_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True):
clip_idx = clip(idx, 1, x.shape[0] - 1)
y_hat = (y[clip_idx - 1] + y[clip_idx]) / 2
return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate)
def interpolate1d(
x: Variable,
y: Variable,
method: InterpolationMethod = "linear",
left_pad: Variable | None = None,
right_pad: Variable | None = None,
extrapolate: bool = True,
) -> Callable[[Variable], Variable]:
"""
Create a function to interpolate one-dimensional data.
Parameters
----------
x : TensorLike
Input data used to create an interpolation function. Data will be sorted to be monotonically increasing.
y: TensorLike
Output data used to create an interpolation function. Must have the same shape as `x`.
method : InterpolationMethod, optional
Method for interpolation. The following methods are available:
- 'linear': Linear interpolation
- 'nearest': Nearest neighbor interpolation
- 'first': Stepwise interpolation using the closest value to the left of the query point
- 'last': Stepwise interpolation using the closest value to the right of the query point
- 'mean': Stepwise interpolation using the mean of the two closest values to the query point
left_pad: TensorLike, optional
Value to return inputs `x_hat < x[0]`. Default is `y[0]`. Ignored if ``extrapolate == True``; in this
case, values `x_hat < x[0]` will be extrapolated from the endpoints of `x` and `y`.
right_pad: TensorLike, optional
Value to return for inputs `x_hat > x[-1]`. Default is `y[-1]`. Ignored if ``extrapolate == True``; in this
case, values `x_hat > x[-1]` will be extrapolated from the endpoints of `x` and `y`.
extrapolate: bool
Whether to extend the request interpolation function beyond the range of the input-output pairs specified in
`x` and `y.` If False, constant values will be returned for such inputs.
Returns
-------
interpolation_func: OpFromGraph
A function that can be used to interpolate new data. The function takes a single input `x_hat` and returns
the interpolated value `y_hat`. The input `x_hat` must be a 1d array.
"""
x = as_tensor_variable(x)
y = as_tensor_variable(y)
sort_idx = argsort(x)
x = x[sort_idx]
y = y[sort_idx]
if left_pad is None:
left_pad = y[0] # type: ignore
else:
left_pad = as_tensor_variable(left_pad)
if right_pad is None:
right_pad = y[-1] # type: ignore
else:
right_pad = as_tensor_variable(right_pad)
def _scalar_interpolate1d(x_hat):
idx = searchsorted(x, x_hat)
if x.ndim != 1 or y.ndim != 1:
raise ValueError("Inputs must be 1d")
if method == "linear":
y_hat = _linear_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "nearest":
y_hat = _nearest_neighbor_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "first":
y_hat = _stepwise_first_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "mean":
y_hat = _stepwise_mean_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "last":
y_hat = _stepwise_last_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
else:
raise NotImplementedError(
f"Unknown interpolation method: {method}. "
f"Did you mean {get_close_matches(method, valid_methods)}?"
)
return y_hat
return vectorize(_scalar_interpolate1d, signature="()->()")
def interp(x, xp, fp, left=None, right=None, period=None):
"""
One-dimensional linear interpolation. Similar to ``pytensor.interpolate.interpolate1d``, but with a signature that
matches ``np.interp``
Parameters
----------
x : TensorLike
The x-coordinates at which to evaluate the interpolated values.
xp : TensorLike
The x-coordinates of the data points, must be increasing if argument `period` is not specified. Otherwise,
`xp` is internally sorted after normalizing the periodic boundaries with ``xp = xp % period``.
fp : TensorLike
The y-coordinates of the data points, same length as `xp`.
left : float, optional
Value to return for `x < xp[0]`. Default is `fp[0]`.
right : float, optional
Value to return for `x > xp[-1]`. Default is `fp[-1]`.
period : None
Not supported. Included to ensure the signature of this function matches ``numpy.interp``.
Returns
-------
y : Variable
The interpolated values, same shape as `x`.
"""
xp = as_tensor_variable(xp)
fp = as_tensor_variable(fp)
x = as_tensor_variable(x)
f = interpolate1d(
xp, fp, method="linear", left_pad=left, right_pad=right, extrapolate=False
)
return f(x)
...@@ -6,6 +6,7 @@ from pytensor.configdefaults import config ...@@ -6,6 +6,7 @@ from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value from pytensor.graph.op import get_test_value
from pytensor.tensor import extra_ops as pt_extra_ops from pytensor.tensor import extra_ops as pt_extra_ops
from pytensor.tensor.sort import argsort
from pytensor.tensor.type import matrix, tensor from pytensor.tensor.type import matrix, tensor
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -55,6 +56,13 @@ def test_extra_ops(): ...@@ -55,6 +56,13 @@ def test_extra_ops():
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
) )
v = ptb.as_tensor_variable(6.0)
sorted_idx = argsort(a.ravel())
out = pt_extra_ops.searchsorted(a.ravel()[sorted_idx], v)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [a_test])
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes") @pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
def test_bartlett_dynamic_shape(): def test_bartlett_dynamic_shape():
......
import numpy as np
import pytest
from numpy.testing import assert_allclose
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.interpolate import (
InterpolationMethod,
interp,
interpolate1d,
valid_methods,
)
floatX = pytensor.config.floatX
def test_interp():
xp = [1.0, 2.0, 3.0]
fp = [3.0, 2.0, 0.0]
x = [0, 1, 1.5, 2.72, 3.14]
out = interp(x, xp, fp).eval()
np_out = np.interp(x, xp, fp)
assert_allclose(out, np_out)
def test_interp_padded():
xp = [1.0, 2.0, 3.0]
fp = [3.0, 2.0, 0.0]
assert interp(3.14, xp, fp, right=-99.0).eval() == -99.0
assert_allclose(
interp([-1.0, -2.0, -3.0], xp, fp, left=1000.0).eval(), [1000.0, 1000.0, 1000.0]
)
assert_allclose(
interp([-1.0, 10.0], xp, fp, left=-10, right=10).eval(), [-10, 10.0]
)
@pytest.mark.parametrize("method", valid_methods, ids=str)
@pytest.mark.parametrize(
"left_pad, right_pad", [(None, None), (None, 100), (-100, None), (-100, 100)]
)
def test_interpolate_scalar_no_extrapolate(
method: InterpolationMethod, left_pad, right_pad
):
x = np.linspace(-2, 6, 10)
y = np.sin(x)
f_op = interpolate1d(
x, y, method, extrapolate=False, left_pad=left_pad, right_pad=right_pad
)
x_hat_pt = pt.dscalar("x_hat")
f = pytensor.function([x_hat_pt], f_op(x_hat_pt), mode="FAST_RUN")
# Data points should be returned exactly, except when method == mean
if method not in ["mean", "first"]:
assert f(x[3]) == y[3]
elif method == "first":
assert f(x[3]) == y[2]
else:
# method == 'mean
assert f(x[3]) == (y[2] + y[3]) / 2
# When extrapolate=False, points beyond the data envelope should be constant
left_pad = y[0] if left_pad is None else left_pad
right_pad = y[-1] if right_pad is None else right_pad
assert f(-10) == left_pad
assert f(100) == right_pad
@pytest.mark.parametrize("method", valid_methods, ids=str)
def test_interpolate_scalar_extrapolate(method: InterpolationMethod):
x = np.linspace(-2, 6, 10)
y = np.sin(x)
f_op = interpolate1d(x, y, method)
x_hat_pt = pt.dscalar("x_hat")
f = pytensor.function([x_hat_pt], f_op(x_hat_pt), mode="FAST_RUN")
left_test_point = -5
right_test_point = 100
if method == "linear":
# Linear will compute a slope from the endpoints and continue it
left_slope = (left_test_point - x[0]) / (x[1] - x[0])
right_slope = (right_test_point - x[-2]) / (x[-1] - x[-2])
assert f(left_test_point) == y[0] + left_slope * (y[1] - y[0])
assert f(right_test_point) == y[-2] + right_slope * (y[-1] - y[-2])
elif method == "mean":
left_expected = (y[0] + y[1]) / 2
right_expected = (y[-1] + y[-2]) / 2
assert f(left_test_point) == left_expected
assert f(right_test_point) == right_expected
else:
assert f(left_test_point) == y[0]
assert f(right_test_point) == y[-1]
# For interior points, "first" and "last" should disagree. First should take the left side of the interval,
# and last should take the right.
interior_point = x[3] + 0.1
assert f(interior_point) == (y[4] if method == "last" else y[3])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论