Unverified commit 0b56ed9c, authored by Ricardo Vieira, committed by GitHub

Implement batched convolve1d (#1318)

Parent: a038c8ee
......@@ -14,6 +14,7 @@ import pytensor.link.jax.dispatch.random
import pytensor.link.jax.dispatch.scalar
import pytensor.link.jax.dispatch.scan
import pytensor.link.jax.dispatch.shape
import pytensor.link.jax.dispatch.signal
import pytensor.link.jax.dispatch.slinalg
import pytensor.link.jax.dispatch.sort
import pytensor.link.jax.dispatch.sparse
......
import pytensor.link.jax.dispatch.signal.conv
import jax
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.signal.conv import Conv1d
@jax_funcify.register(Conv1d)
def jax_funcify_Conv1d(op, node, **kwargs):
    """Return a JAX implementation of the ``Conv1d`` Op.

    The Op's convolution mode ("full" or "valid") is captured in a
    closure and forwarded to ``jax.numpy.convolve``.
    """
    convolve_mode = op.mode

    def conv1d(data, kernel):
        return jax.numpy.convolve(data, kernel, mode=convolve_mode)

    return conv1d
......@@ -9,9 +9,11 @@ import pytensor.link.numba.dispatch.nlinalg
import pytensor.link.numba.dispatch.random
import pytensor.link.numba.dispatch.scan
import pytensor.link.numba.dispatch.scalar
import pytensor.link.numba.dispatch.signal
import pytensor.link.numba.dispatch.slinalg
import pytensor.link.numba.dispatch.sparse
import pytensor.link.numba.dispatch.subtensor
import pytensor.link.numba.dispatch.tensor_basic
# isort: on
import pytensor.link.numba.dispatch.signal.conv
import numpy as np
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.tensor.signal.conv import Conv1d
@numba_funcify.register(Conv1d)
def numba_funcify_Conv1d(op, node, **kwargs):
    """Return a Numba-jitted implementation of the ``Conv1d`` Op.

    The Op's convolution mode ("full" or "valid") is captured in a
    closure and forwarded to ``np.convolve``.
    """
    convolve_mode = op.mode

    @numba_njit
    def conv1d(data, kernel):
        return np.convolve(data, kernel, mode=convolve_mode)

    return conv1d
......@@ -116,6 +116,7 @@ from pytensor.tensor import (
# isort: off
from pytensor.tensor import linalg
from pytensor.tensor import special
from pytensor.tensor import signal
# For backward compatibility
from pytensor.tensor import nlinalg
......
from pytensor.tensor.signal.conv import convolve1d
__all__ = ("convolve1d",)
from typing import TYPE_CHECKING, Literal, cast
from numpy import convolve as numpy_convolve
from pytensor.graph import Apply, Op
from pytensor.scalar.basic import upcast
from pytensor.tensor.basic import as_tensor_variable, join, zeros
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import maximum, minimum
from pytensor.tensor.type import vector
from pytensor.tensor.variable import TensorVariable
if TYPE_CHECKING:
from pytensor.tensor import TensorLike
class Conv1d(Op):
    """Discrete linear convolution of two one-dimensional tensors.

    This is the core (non-batched) Op; batching over leading dimensions is
    handled by wrapping it in ``Blockwise`` (see ``convolve1d``).

    ``mode`` follows ``numpy.convolve``:
      - ``"full"``: output length is ``n + k - 1``.
      - ``"valid"``: output length is ``max(n, k) - min(n, k) + 1``.

    ``"same"`` is deliberately not supported here; ``convolve1d`` lowers it
    to ``"valid"`` on a zero-padded first input.
    """

    __props__ = ("mode",)
    # Core gufunc signature used by Blockwise to vectorize over batch dims.
    gufunc_signature = "(n),(k)->(o)"

    def __init__(self, mode: Literal["full", "valid"] = "full"):
        # Only "full"/"valid" are legal; "same" is rewritten away upstream.
        if mode not in ("full", "valid"):
            raise ValueError(f"Invalid mode: {mode}")
        self.mode = mode

    def make_node(self, in1, in2):
        """Build the Apply node; output dtype is the upcast of the inputs."""
        in1 = as_tensor_variable(in1)
        in2 = as_tensor_variable(in2)

        # Core Op only accepts vectors; Blockwise handles extra dims.
        assert in1.ndim == 1
        assert in2.ndim == 1

        dtype = upcast(in1.dtype, in2.dtype)

        # Static lengths (may be None when unknown at graph-build time).
        n = in1.type.shape[0]
        k = in2.type.shape[0]

        if n is None or k is None:
            # Either static length unknown -> output length unknown.
            out_shape = (None,)
        elif self.mode == "full":
            out_shape = (n + k - 1,)
        else:  # mode == "valid":
            out_shape = (max(n, k) - min(n, k) + 1,)

        out = vector(dtype=dtype, shape=out_shape)
        return Apply(self, [in1, in2], [out])

    def perform(self, node, inputs, outputs):
        # We use numpy_convolve as that's what scipy would use if method="direct" was passed.
        # And mode != "same", which this Op doesn't cover anyway.
        outputs[0][0] = numpy_convolve(*inputs, mode=self.mode)

    def infer_shape(self, fgraph, node, shapes):
        """Symbolic counterpart of the static-shape logic in ``make_node``."""
        in1_shape, in2_shape = shapes
        n = in1_shape[0]
        k = in2_shape[0]
        if self.mode == "full":
            shape = n + k - 1
        else:  # mode == "valid":
            shape = maximum(n, k) - minimum(n, k) + 1
        return [[shape]]

    def L_op(self, inputs, outputs, output_grads):
        """Gradients: convolve the output gradient with the flipped other input."""
        in1, in2 = inputs
        [grad] = output_grads
        if self.mode == "full":
            # Adjoint of "full" convolution is a "valid" convolution.
            valid_conv = type(self)(mode="valid")
            in1_bar = valid_conv(grad, in2[::-1])
            in2_bar = valid_conv(grad, in1[::-1])
        else:  # mode == "valid":
            full_conv = type(self)(mode="full")
            n = in1.shape[0]
            k = in2.shape[0]
            # Symbolic length differences, clipped at zero.
            kmn = maximum(0, k - n)
            nkm = maximum(0, n - k)
            # We need mode="full" if k >= n else "valid" for `in1_bar` (opposite for `in2_bar`), but mode is not symbolic.
            # Instead, we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter.
            in1_bar = full_conv(grad, in2[::-1])
            in1_bar = in1_bar[kmn : in1_bar.shape[0] - kmn]
            in2_bar = full_conv(grad, in1[::-1])
            in2_bar = in2_bar[nkm : in2_bar.shape[0] - nkm]
        return [in1_bar, in2_bar]
def convolve1d(
    in1: "TensorLike",
    in2: "TensorLike",
    mode: Literal["full", "valid", "same"] = "full",
) -> TensorVariable:
    """Convolve two one-dimensional arrays.

    Convolve in1 and in2, with the output size determined by the mode argument.

    Parameters
    ----------
    in1 : (..., N,) tensor_like
        First input.
    in2 : (..., M,) tensor_like
        Second input.
    mode : {'full', 'valid', 'same'}, optional
        A string indicating the size of the output:
        - 'full': The output is the full discrete linear convolution of the inputs, with shape (..., N+M-1,).
        - 'valid': The output consists only of elements that do not rely on zero-padding, with shape (..., max(N, M) - min(N, M) + 1,).
        - 'same': The output is the same size as in1, centered with respect to the 'full' output.

    Returns
    -------
    out: tensor_variable
        The discrete linear convolution of in1 with in2.
    """
    in1 = as_tensor_variable(in1)
    in2 = as_tensor_variable(in2)

    if mode == "same":
        # We implement "same" as "valid" with padded `in1`: padding by
        # M//2 zeros on the left and (M-1)//2 on the right yields a "valid"
        # output of length N, centered like the middle of the "full" output.
        in1_batch_shape = tuple(in1.shape)[:-1]
        # Bug fix: the kernel length is the *core* (last) dimension of `in2`.
        # Using `in2.shape[0]` would read a batch dimension whenever `in2`
        # is batched (ndim > 1), producing wrong padding.
        kernel_len = in2.shape[-1]
        zeros_left = kernel_len // 2
        zeros_right = (kernel_len - 1) // 2
        in1 = join(
            -1,
            zeros((*in1_batch_shape, zeros_left), dtype=in2.dtype),
            in1,
            zeros((*in1_batch_shape, zeros_right), dtype=in2.dtype),
        )
        mode = "valid"

    # Blockwise vectorizes the core vector Op over any leading batch dims.
    return cast(TensorVariable, Blockwise(Conv1d(mode=mode))(in1, in2))
import numpy as np
import pytest
from pytensor.tensor import dmatrix
from pytensor.tensor.signal import convolve1d
from tests.link.jax.test_basic import compare_jax_and_py
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode):
    """JAX backend agrees with the Python backend on a batched convolve1d."""
    data = dmatrix("x")
    kernel = dmatrix("y")
    # Cross-broadcast the two batch dimensions: (1, 3, 5) with (7, 1, 11).
    result = convolve1d(data[None], kernel[:, None], mode=mode)

    random_gen = np.random.default_rng()
    data_val = random_gen.normal(size=(3, 5))
    kernel_val = random_gen.normal(size=(7, 11))
    compare_jax_and_py([data, kernel], result, [data_val, kernel_val])
import numpy as np
import pytest
from pytensor.tensor import dmatrix
from pytensor.tensor.signal import convolve1d
from tests.link.numba.test_basic import compare_numba_and_py
pytestmark = pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode):
    """Numba backend agrees with the Python backend on a batched convolve1d."""
    data = dmatrix("x")
    kernel = dmatrix("y")
    # Cross-broadcast the two batch dimensions: (1, 3, 5) with (7, 1, 11).
    result = convolve1d(data[None], kernel[:, None], mode=mode)

    random_gen = np.random.default_rng()
    data_val = random_gen.normal(size=(3, 5))
    kernel_val = random_gen.normal(size=(7, 11))
    # Blockwise dispatch for numba can't be run on object mode
    compare_numba_and_py([data, kernel], result, [data_val, kernel_val], eval_obj_mode=False)
from functools import partial
import numpy as np
import pytest
from scipy.signal import convolve as scipy_convolve
from pytensor import config, function
from pytensor.tensor import matrix, vector
from pytensor.tensor.signal.conv import convolve1d
from tests import unittest_tools as utt
@pytest.mark.parametrize("kernel_shape", [3, 5, 8], ids=lambda x: f"kernel_shape={x}")
@pytest.mark.parametrize("data_shape", [3, 5, 8], ids=lambda x: f"data_shape={x}")
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode, data_shape, kernel_shape):
    """convolve1d matches scipy's direct convolution and has correct gradients."""
    data = vector("data")
    kernel = vector("kernel")
    conv = partial(convolve1d, mode=mode)

    seed = (26, kernel_shape, data_shape, sum(map(ord, mode)))
    rng = np.random.default_rng(seed)
    data_val = rng.normal(size=data_shape).astype(data.dtype)
    kernel_val = rng.normal(size=kernel_shape).astype(kernel.dtype)

    fn = function([data, kernel], conv(data, kernel))
    expected = scipy_convolve(data_val, kernel_val, mode=mode)
    tolerance = 1e-6 if config.floatX == "float32" else 1e-15
    np.testing.assert_allclose(fn(data_val, kernel_val), expected, rtol=tolerance)

    utt.verify_grad(op=lambda x: conv(x, kernel_val), pt=[data_val])
def test_convolve1d_batch():
    """Batched convolve1d: commuting the operands gives identical rows."""
    lhs = matrix("data")
    rhs = matrix("kernel")
    result = convolve1d(lhs, rhs)

    rng = np.random.default_rng(38)
    lhs_val = rng.normal(size=(2, 8)).astype(lhs.dtype)
    # Reversing the batch axis makes row 1 of rhs equal row 0 of lhs and
    # vice versa, so by commutativity both output rows must agree.
    rhs_val = lhs_val[::-1]

    out_val = result.eval({lhs: lhs_val, rhs: rhs_val})
    expected = np.convolve(lhs_val[0], rhs_val[0])
    tol = 1e-6 if config.floatX == "float32" else 1e-15
    np.testing.assert_allclose(out_val[0], expected, rtol=tol)
    np.testing.assert_allclose(out_val[1], expected, rtol=tol)
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment