Unverified 提交 0b56ed9c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Implement batched convolve1d (#1318)

上级 a038c8ee
...@@ -14,6 +14,7 @@ import pytensor.link.jax.dispatch.random ...@@ -14,6 +14,7 @@ import pytensor.link.jax.dispatch.random
import pytensor.link.jax.dispatch.scalar import pytensor.link.jax.dispatch.scalar
import pytensor.link.jax.dispatch.scan import pytensor.link.jax.dispatch.scan
import pytensor.link.jax.dispatch.shape import pytensor.link.jax.dispatch.shape
import pytensor.link.jax.dispatch.signal
import pytensor.link.jax.dispatch.slinalg import pytensor.link.jax.dispatch.slinalg
import pytensor.link.jax.dispatch.sort import pytensor.link.jax.dispatch.sort
import pytensor.link.jax.dispatch.sparse import pytensor.link.jax.dispatch.sparse
......
import pytensor.link.jax.dispatch.signal.conv
import jax
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.signal.conv import Conv1d
@jax_funcify.register(Conv1d)
def jax_funcify_Conv1d(op, node, **kwargs):
    """Build the JAX implementation of a ``Conv1d`` op.

    The convolution mode is read from ``op.mode`` once, when the function is
    built, and captured by the returned closure.
    """
    conv_mode = op.mode

    def conv1d(signal, taps):
        # Delegate to JAX's 1-D convolution with the captured mode.
        result = jax.numpy.convolve(signal, taps, mode=conv_mode)
        return result

    return conv1d
...@@ -9,9 +9,11 @@ import pytensor.link.numba.dispatch.nlinalg ...@@ -9,9 +9,11 @@ import pytensor.link.numba.dispatch.nlinalg
import pytensor.link.numba.dispatch.random import pytensor.link.numba.dispatch.random
import pytensor.link.numba.dispatch.scan import pytensor.link.numba.dispatch.scan
import pytensor.link.numba.dispatch.scalar import pytensor.link.numba.dispatch.scalar
import pytensor.link.numba.dispatch.signal
import pytensor.link.numba.dispatch.slinalg import pytensor.link.numba.dispatch.slinalg
import pytensor.link.numba.dispatch.sparse import pytensor.link.numba.dispatch.sparse
import pytensor.link.numba.dispatch.subtensor import pytensor.link.numba.dispatch.subtensor
import pytensor.link.numba.dispatch.tensor_basic import pytensor.link.numba.dispatch.tensor_basic
# isort: on # isort: on
import pytensor.link.numba.dispatch.signal.conv
import numpy as np
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.tensor.signal.conv import Conv1d
@numba_funcify.register(Conv1d)
def numba_funcify_Conv1d(op, node, **kwargs):
    """Build the Numba implementation of a ``Conv1d`` op.

    The convolution mode is read from ``op.mode`` once, when the function is
    built, and captured by the jitted closure.
    """
    conv_mode = op.mode

    @numba_njit
    def conv1d(signal, taps):
        # Delegate to NumPy's 1-D convolution with the captured mode.
        return np.convolve(signal, taps, mode=conv_mode)

    return conv1d
...@@ -116,6 +116,7 @@ from pytensor.tensor import ( ...@@ -116,6 +116,7 @@ from pytensor.tensor import (
# isort: off # isort: off
from pytensor.tensor import linalg from pytensor.tensor import linalg
from pytensor.tensor import special from pytensor.tensor import special
from pytensor.tensor import signal
# For backward compatibility # For backward compatibility
from pytensor.tensor import nlinalg from pytensor.tensor import nlinalg
......
from pytensor.tensor.signal.conv import convolve1d
# Names exported by `from <this module> import *`.
__all__ = ("convolve1d",)
from typing import TYPE_CHECKING, Literal, cast
from numpy import convolve as numpy_convolve
from pytensor.graph import Apply, Op
from pytensor.scalar.basic import upcast
from pytensor.tensor.basic import as_tensor_variable, join, zeros
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import maximum, minimum
from pytensor.tensor.type import vector
from pytensor.tensor.variable import TensorVariable
if TYPE_CHECKING:
from pytensor.tensor import TensorLike
class Conv1d(Op):
    """Discrete linear convolution of two vectors.

    This is the core (unbatched) operation: both inputs must be
    one-dimensional. Only the ``"full"`` and ``"valid"`` modes are
    implemented here.

    Parameters
    ----------
    mode : {"full", "valid"}
        Determines the output length for inputs of length ``n`` and ``k``:
        ``n + k - 1`` for ``"full"`` and ``max(n, k) - min(n, k) + 1`` for
        ``"valid"``.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``"full"`` nor ``"valid"``.
    """

    __props__ = ("mode",)
    # Core signature: two vectors in, one vector out.
    gufunc_signature = "(n),(k)->(o)"

    def __init__(self, mode: Literal["full", "valid"] = "full"):
        if mode not in ("full", "valid"):
            raise ValueError(f"Invalid mode: {mode}")
        self.mode = mode

    def make_node(self, in1, in2):
        """Create the Apply node for the convolution of two vectors.

        Raises
        ------
        ValueError
            If either input is not one-dimensional.
        """
        in1 = as_tensor_variable(in1)
        in2 = as_tensor_variable(in2)

        # Validate with an explicit exception rather than `assert`, so the
        # check is not stripped when Python runs with optimizations (-O).
        if in1.ndim != 1 or in2.ndim != 1:
            raise ValueError(
                "Conv1d expects one-dimensional inputs, got inputs with "
                f"ndim={in1.ndim} and ndim={in2.ndim}"
            )

        dtype = upcast(in1.dtype, in2.dtype)

        n = in1.type.shape[0]
        k = in2.type.shape[0]

        # The static output length is only known when both static input
        # lengths are known.
        if n is None or k is None:
            out_shape = (None,)
        elif self.mode == "full":
            out_shape = (n + k - 1,)
        else:  # mode == "valid":
            out_shape = (max(n, k) - min(n, k) + 1,)

        out = vector(dtype=dtype, shape=out_shape)
        return Apply(self, [in1, in2], [out])

    def perform(self, node, inputs, outputs):
        """Compute the convolution with NumPy and store it in ``outputs[0][0]``."""
        # We use numpy_convolve as that's what scipy would use if method="direct" was passed.
        # And mode != "same", which this Op doesn't cover anyway.
        outputs[0][0] = numpy_convolve(*inputs, mode=self.mode)

    def infer_shape(self, fgraph, node, shapes):
        """Return the symbolic output shape given the symbolic input shapes."""
        in1_shape, in2_shape = shapes
        n = in1_shape[0]
        k = in2_shape[0]
        if self.mode == "full":
            shape = n + k - 1
        else:  # mode == "valid":
            shape = maximum(n, k) - minimum(n, k) + 1
        return [[shape]]

    def L_op(self, inputs, outputs, output_grads):
        """Return the gradient expressions with respect to both inputs.

        Each gradient is itself a convolution of the output gradient with
        the other input reversed.
        """
        in1, in2 = inputs
        [grad] = output_grads

        if self.mode == "full":
            valid_conv = type(self)(mode="valid")
            in1_bar = valid_conv(grad, in2[::-1])
            in2_bar = valid_conv(grad, in1[::-1])

        else:  # mode == "valid":
            full_conv = type(self)(mode="full")
            n = in1.shape[0]
            k = in2.shape[0]
            kmn = maximum(0, k - n)
            nkm = maximum(0, n - k)
            # We need mode="full" if k >= n else "valid" for `in1_bar` (opposite for `in2_bar`), but mode is not symbolic.
            # Instead, we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter.
            in1_bar = full_conv(grad, in2[::-1])
            in1_bar = in1_bar[kmn : in1_bar.shape[0] - kmn]
            in2_bar = full_conv(grad, in1[::-1])
            in2_bar = in2_bar[nkm : in2_bar.shape[0] - nkm]

        return [in1_bar, in2_bar]
def convolve1d(
    in1: "TensorLike",
    in2: "TensorLike",
    mode: Literal["full", "valid", "same"] = "full",
) -> TensorVariable:
    """Convolve two one-dimensional arrays.

    Convolve in1 and in2, with the output size determined by the mode argument.
    The convolution is applied along the last axis of each input; any
    leading axes are treated as batch dimensions.

    Parameters
    ----------
    in1 : (..., N,) tensor_like
        First input.
    in2 : (..., M,) tensor_like
        Second input.
    mode : {'full', 'valid', 'same'}, optional
        A string indicating the size of the output:
        - 'full': The output is the full discrete linear convolution of the inputs, with shape (..., N+M-1,).
        - 'valid': The output consists only of elements that do not rely on zero-padding, with shape (..., max(N, M) - min(N, M) + 1,).
        - 'same': The output is the same size as in1, centered with respect to the 'full' output.

    Returns
    -------
    out: tensor_variable
        The discrete linear convolution of in1 with in2.

    Raises
    ------
    ValueError
        If `mode` is not one of 'full', 'valid' or 'same'.
    """
    in1 = as_tensor_variable(in1)
    in2 = as_tensor_variable(in2)

    if mode == "same":
        # We implement "same" as "valid" with padded `in1`.
        in1_batch_shape = tuple(in1.shape)[:-1]
        # The padding depends on the core (last) length of `in2`. Using the
        # first axis would pick up a batch dimension when `in2` is batched.
        kernel_size = in2.shape[-1]
        zeros_left = kernel_size // 2
        zeros_right = (kernel_size - 1) // 2
        in1 = join(
            -1,
            zeros((*in1_batch_shape, zeros_left), dtype=in2.dtype),
            in1,
            zeros((*in1_batch_shape, zeros_right), dtype=in2.dtype),
        )
        mode = "valid"

    return cast(TensorVariable, Blockwise(Conv1d(mode=mode))(in1, in2))
import numpy as np
import pytest
from pytensor.tensor import dmatrix
from pytensor.tensor.signal import convolve1d
from tests.link.jax.test_basic import compare_jax_and_py
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode):
    """Check that the JAX backend agrees with the Python backend for batched inputs."""
    x = dmatrix("x")
    y = dmatrix("y")
    # Insert new axes so the two inputs have different batch shapes:
    # (1, 3) for `x[None]` and (7, 1) for `y[:, None]` with the values below.
    out = convolve1d(x[None], y[:, None], mode=mode)

    # Fixed seed so that any failure can be reproduced exactly.
    rng = np.random.default_rng(1318)
    test_x = rng.normal(size=(3, 5))
    test_y = rng.normal(size=(7, 11))
    compare_jax_and_py([x, y], out, [test_x, test_y])
import numpy as np
import pytest
from pytensor.tensor import dmatrix
from pytensor.tensor.signal import convolve1d
from tests.link.numba.test_basic import compare_numba_and_py
# Turn every warning emitted by the tests in this module into an error.
pytestmark = pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode):
    """Check that the Numba backend agrees with the Python backend for batched inputs."""
    x = dmatrix("x")
    y = dmatrix("y")
    # Insert new axes so the two inputs have different batch shapes:
    # (1, 3) for `x[None]` and (7, 1) for `y[:, None]` with the values below.
    out = convolve1d(x[None], y[:, None], mode=mode)

    # Fixed seed so that any failure can be reproduced exactly.
    rng = np.random.default_rng(1318)
    test_x = rng.normal(size=(3, 5))
    test_y = rng.normal(size=(7, 11))
    # Blockwise dispatch for numba can't be run on object mode
    compare_numba_and_py([x, y], out, [test_x, test_y], eval_obj_mode=False)
from functools import partial
import numpy as np
import pytest
from scipy.signal import convolve as scipy_convolve
from pytensor import config, function
from pytensor.tensor import matrix, vector
from pytensor.tensor.signal.conv import convolve1d
from tests import unittest_tools as utt
@pytest.mark.parametrize("kernel_shape", [3, 5, 8], ids=lambda x: f"kernel_shape={x}")
@pytest.mark.parametrize("data_shape", [3, 5, 8], ids=lambda x: f"data_shape={x}")
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode, data_shape, kernel_shape):
    """Compare `convolve1d` against SciPy's result and verify its gradient."""
    data = vector("data")
    kernel = vector("kernel")
    op = partial(convolve1d, mode=mode)

    # The seed is derived from the test parameters, so every parametrization
    # draws its own reproducible values.
    mode_code = sum(ord(char) for char in mode)
    rng = np.random.default_rng((26, kernel_shape, data_shape, mode_code))
    data_val = rng.normal(size=data_shape).astype(data.dtype)
    kernel_val = rng.normal(size=kernel_shape).astype(kernel.dtype)

    fn = function([data, kernel], op(data, kernel))
    actual = fn(data_val, kernel_val)
    expected = scipy_convolve(data_val, kernel_val, mode=mode)

    if config.floatX == "float32":
        rtol = 1e-6
    else:
        rtol = 1e-15
    np.testing.assert_allclose(actual, expected, rtol=rtol)

    def conv_wrt_data(x):
        # The gradient is checked with respect to the data input only; the
        # kernel is held fixed at its numerical value.
        return op(x, kernel_val)

    utt.verify_grad(op=conv_wrt_data, pt=[data_val])
def test_convolve1d_batch():
    """Check that `convolve1d` convolves matrix inputs row by row."""
    x = matrix("data")
    y = matrix("kernel")
    out = convolve1d(x, y)

    rng = np.random.default_rng(38)
    x_test = rng.normal(size=(2, 8)).astype(x.dtype)
    # Reversing the row order makes the rows of `y` the rows of `x` swapped.
    y_test = x_test[::-1]

    res = out.eval({x: x_test, y: y_test})

    if config.floatX == "float32":
        rtol = 1e-6
    else:
        rtol = 1e-15

    # Second entry of x, y are just y, x respectively,
    # so res[0] and res[1] should be identical.
    res_np = np.convolve(x_test[0], y_test[0])
    for batch_result in (res[0], res[1]):
        np.testing.assert_allclose(batch_result, res_np, rtol=rtol)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论