Unverified commit 10d225dd authored by Jesse Grabowski, committed by GitHub

Implement `Convolve2D` Op (#1397)

* Implement simple wrapper Op for `scipy.signal.convolve2d` * Better shape inference * conv2d gradient v0 * Implement Convolve2d and gradients * Add same mode, and FFT support * Relax test tolerance in float32 * Change seed * Relax test tolerance in float32 * Add `convolve2d` to `signal.__init__` --------- Co-authored-by: ricardoV94 <ricardo.vieira1994@gmail.com>
Parent commit: 1d13f8c4
...@@ -921,7 +921,7 @@ def zeros_like(model, dtype=None, opt=False): ...@@ -921,7 +921,7 @@ def zeros_like(model, dtype=None, opt=False):
return fill(_model, ret) return fill(_model, ret)
def zeros(shape, dtype=None): def zeros(shape, dtype=None) -> TensorVariable:
"""Create a `TensorVariable` filled with zeros, closer to NumPy's syntax than ``alloc``.""" """Create a `TensorVariable` filled with zeros, closer to NumPy's syntax than ``alloc``."""
if not ( if not (
isinstance(shape, np.ndarray | Sequence) isinstance(shape, np.ndarray | Sequence)
...@@ -933,7 +933,7 @@ def zeros(shape, dtype=None): ...@@ -933,7 +933,7 @@ def zeros(shape, dtype=None):
return alloc(np.array(0, dtype=dtype), *shape) return alloc(np.array(0, dtype=dtype), *shape)
def ones(shape, dtype=None): def ones(shape, dtype=None) -> TensorVariable:
"""Create a `TensorVariable` filled with ones, closer to NumPy's syntax than ``alloc``.""" """Create a `TensorVariable` filled with ones, closer to NumPy's syntax than ``alloc``."""
if not ( if not (
isinstance(shape, np.ndarray | Sequence) isinstance(shape, np.ndarray | Sequence)
......
from pytensor.tensor.signal.conv import convolve1d from pytensor.tensor.signal.conv import convolve1d, convolve2d
__all__ = ("convolve1d",) __all__ = ("convolve1d", "convolve2d")
from typing import TYPE_CHECKING, Literal, cast from typing import TYPE_CHECKING, Literal
from typing import cast as type_cast
import numpy as np import numpy as np
from numpy import convolve as numpy_convolve from numpy import convolve as numpy_convolve
from scipy.signal import convolve as scipy_convolve
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph import Apply, Constant from pytensor.graph import Apply, Constant
from pytensor.graph.op import Op
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.scalar import as_scalar from pytensor.scalar import as_scalar
from pytensor.scalar.basic import upcast from pytensor.scalar.basic import upcast
from pytensor.tensor.basic import as_tensor_variable, join, zeros from pytensor.tensor.basic import as_tensor_variable, join, zeros
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import maximum, minimum, switch from pytensor.tensor.math import maximum, minimum, switch
from pytensor.tensor.type import vector from pytensor.tensor.pad import pad
from pytensor.tensor.subtensor import flip
from pytensor.tensor.type import tensor
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -19,54 +24,69 @@ if TYPE_CHECKING: ...@@ -19,54 +24,69 @@ if TYPE_CHECKING:
from pytensor.tensor import TensorLike from pytensor.tensor import TensorLike
class AbstractConvolveNd:
    """Shared machinery for N-dimensional convolution Ops.

    Concrete subclasses set ``ndim`` and implement ``perform``. This mixin
    provides the Blockwise gufunc signature, graph-node construction with
    static output-shape inference, runtime shape inference, and the gradient
    connection pattern. The third Op input is a boolean scalar selecting
    "full" (True) versus "valid" (False) convolution.
    """

    __props__ = ()
    # Number of core (non-batch) dimensions; defined by each subclass.
    ndim: int

    @property
    def gufunc_signature(self):
        # e.g. for ndim=2: "(n0,n1),(k0,k1),()->(o0,o1)"
        def dims(prefix):
            return ",".join(f"{prefix}{i}" for i in range(self.ndim))

        return f"({dims('n')}),({dims('k')}),()->({dims('o')})"

    def make_node(self, in1, in2, full_mode):
        in1 = as_tensor_variable(in1)
        in2 = as_tensor_variable(in2)
        full_mode = as_scalar(full_mode)

        ndim = self.ndim
        if in1.ndim != ndim or in2.ndim != ndim:
            raise ValueError(
                f"Convolution inputs must have ndim={ndim}, got: in1={in1.ndim}, in2={in2.ndim}"
            )
        if full_mode.dtype != "bool":
            raise ValueError("Convolution full_mode flag must be a boolean type")

        # Only when the mode flag is a compile-time constant can we resolve
        # the mode here and compute static output shapes.
        if isinstance(full_mode, Constant):
            static_mode = "full" if full_mode.data else "valid"
        else:
            static_mode = None

        if static_mode is None:
            out_shape = (None,) * ndim
        else:
            # TODO: Raise if static shapes are not valid (one input size doesn't dominate the other)
            def static_dim(n, k):
                if n is None or k is None:
                    return None
                if static_mode == "full":
                    return n + k - 1
                # static_mode == "valid"
                return max(n, k) - min(n, k) + 1

            out_shape = tuple(
                static_dim(n, k) for n, k in zip(in1.type.shape, in2.type.shape)
            )

        dtype = upcast(in1.dtype, in2.dtype)
        out = tensor(dtype=dtype, shape=out_shape)
        return Apply(self, [in1, in2, full_mode], [out])

    def infer_shape(self, fgraph, node, shapes):
        _, _, full_mode = node.inputs
        in1_shape, in2_shape, _ = shapes
        out_shape = []
        for n, k in zip(in1_shape, in2_shape):
            full_dim = n + k - 1
            valid_dim = maximum(n, k) - minimum(n, k) + 1
            # Symbolic switch: the mode may only be known at runtime.
            out_shape.append(switch(full_mode, full_dim, valid_dim))
        return [out_shape]

    def connection_pattern(self, node):
        # Gradients flow to both convolution inputs, never to the mode flag.
        return [[True], [True], [False]]
...@@ -75,22 +95,34 @@ class Convolve1d(COp): ...@@ -75,22 +95,34 @@ class Convolve1d(COp):
in1, in2, full_mode = inputs in1, in2, full_mode = inputs
[grad] = output_grads [grad] = output_grads
n = in1.shape[0] n = in1.shape
k = in2.shape[0] k = in2.shape
# Note: this assumes the shape of one input dominates the other over all dimensions (which is required for a valid forward)
# If mode is "full", or mode is "valid" and k >= n, then in1_bar mode should use "valid" convolve # If mode is "full", or mode is "valid" and k >= n, then in1_bar mode should use "valid" convolve
# The expression below is equivalent to ~(full_mode | (k >= n)) # The expression below is equivalent to ~(full_mode | (k >= n))
full_mode_in1_bar = ~full_mode & (k < n) full_mode_in1_bar = ~full_mode & (k < n).any()
# If mode is "full", or mode is "valid" and n >= k, then in2_bar mode should use "valid" convolve # If mode is "full", or mode is "valid" and n >= k, then in2_bar mode should use "valid" convolve
# The expression below is equivalent to ~(full_mode | (n >= k)) # The expression below is equivalent to ~(full_mode | (n >= k))
full_mode_in2_bar = ~full_mode & (n < k) full_mode_in2_bar = ~full_mode & (n < k).any()
return [ return [
self(grad, in2[::-1], full_mode_in1_bar), self(grad, flip(in2), full_mode_in1_bar),
self(grad, in1[::-1], full_mode_in2_bar), self(grad, flip(in1), full_mode_in2_bar),
DisconnectedType()(), DisconnectedType()(),
] ]
class Convolve1d(AbstractConvolveNd, COp):  # type: ignore[misc]
    """One-dimensional convolution Op with a C implementation."""

    __props__ = ()
    ndim = 1

    def perform(self, node, inputs, outputs):
        # We use numpy_convolve as that's what scipy would use if method="direct" was passed.
        # And mode != "same", which this Op doesn't cover anyway.
        data, kernel, full_mode = inputs
        mode = "full" if full_mode else "valid"
        outputs[0][0] = numpy_convolve(data, kernel, mode=mode)

    def c_code_cache_version(self):
        return (2,)
...@@ -210,4 +242,104 @@ def convolve1d( ...@@ -210,4 +242,104 @@ def convolve1d(
mode = "valid" mode = "valid"
full_mode = as_scalar(np.bool_(mode == "full")) full_mode = as_scalar(np.bool_(mode == "full"))
return cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode)) return type_cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode))
class Convolve2d(AbstractConvolveNd, Op):  # type: ignore[misc]
    """Two-dimensional convolution Op backed by ``scipy.signal.convolve``."""

    __props__ = ("method",)  # type: ignore[assignment]
    ndim = 2

    def __init__(self, method: Literal["direct", "fft", "auto"] = "auto"):
        self.method = method

    def perform(self, node, inputs, outputs):
        data, kernel, full_mode = inputs
        # TODO: Why is .item() needed? Presumably the flag arrives as a 0-d
        # array rather than a plain bool — confirm.
        if full_mode.item():
            mode: Literal["full", "valid", "same"] = "full"
        else:
            mode = "valid"
        outputs[0][0] = scipy_convolve(data, kernel, mode=mode, method=self.method)
def convolve2d(
    in1: "TensorLike",
    in2: "TensorLike",
    mode: Literal["full", "valid", "same"] = "full",
    boundary: Literal["fill", "wrap", "symm"] = "fill",
    fillvalue: float | int = 0,
    method: Literal["direct", "fft", "auto"] = "auto",
) -> TensorVariable:
    """Convolve two two-dimensional arrays.

    Convolve `in1` and `in2`, with the output size determined by `mode`.

    Parameters
    ----------
    in1 : (..., N, M) tensor_like
        First input.
    in2 : (..., K, L) tensor_like
        Second input.
    mode : {'full', 'valid', 'same'}, optional
        Size of the output:
        - 'full': full discrete linear convolution, shape (..., N+K-1, M+L-1).
        - 'valid': only elements that do not rely on zero-padding, with shape
          (..., max(N, K) - min(N, K) + 1, max(M, L) - min(M, L) + 1).
        - 'same': same size as `in1`, centered with respect to the 'full' output.
    boundary : {'fill', 'wrap', 'symm'}, optional
        How to handle boundaries: 'fill' pads with `fillvalue`, 'wrap' wraps
        circularly, 'symm' reflects symmetrically.
    fillvalue : float or int, optional
        Padding value when `boundary` is 'fill'. Default is 0.
    method : str, one of 'direct', 'fft', or 'auto'
        Computation method: direct convolution, FFT-based convolution, or let
        the implementation choose at runtime.

    Returns
    -------
    out: tensor_variable
        The discrete linear convolution of `in1` with `in2`.
    """
    in1 = as_tensor_variable(in1)
    in2 = as_tensor_variable(in2)
    ndim = max(in1.type.ndim, in2.type.ndim)

    def _pad_input(input_tensor, pad_width):
        # Apply the requested boundary treatment via explicit padding.
        if boundary == "fill":
            return pad(
                input_tensor,
                pad_width=pad_width,
                mode="constant",
                constant_values=fillvalue,
            )
        elif boundary == "wrap":
            return pad(input_tensor, pad_width=pad_width, mode="wrap")
        elif boundary == "symm":
            return pad(input_tensor, pad_width=pad_width, mode="symmetric")
        raise ValueError(f"Unsupported boundary mode: {boundary}")

    if mode == "same":
        # Same mode is implemented as "valid" with a padded input.
        pad_width = zeros((ndim, 2), dtype="int64")
        for axis in (-2, -1):
            kernel_size = in2.shape[axis]
            pad_width = pad_width[axis, 0].set(kernel_size // 2)
            pad_width = pad_width[axis, 1].set((kernel_size - 1) // 2)
        in1 = _pad_input(in1, pad_width)
        mode = "valid"

    if mode != "valid" and (boundary != "fill" or fillvalue != 0):
        # Non-default boundary in "full" mode: pad in1 by the full kernel
        # overhang on each side, then run a "valid" convolution on it.
        *_, kern_rows, kern_cols = in2.shape
        pad_width = zeros((ndim, 2), dtype="int64")
        pad_width = pad_width[-2, :].set(kern_rows - 1)
        pad_width = pad_width[-1, :].set(kern_cols - 1)
        in1 = _pad_input(in1, pad_width)
        mode = "valid"

    full_mode = as_scalar(np.bool_(mode == "full"))
    return type_cast(
        TensorVariable, Blockwise(Convolve2d(method=method))(in1, in2, full_mode)
    )
...@@ -3,13 +3,15 @@ from functools import partial ...@@ -3,13 +3,15 @@ from functools import partial
import numpy as np import numpy as np
import pytest import pytest
from scipy.signal import convolve as scipy_convolve from scipy.signal import convolve as scipy_convolve
from scipy.signal import convolve2d as scipy_convolve2d
from pytensor import config, function, grad from pytensor import config, function, grad
from pytensor.graph.rewriting import rewrite_graph from pytensor.graph.rewriting import rewrite_graph
from pytensor.graph.traversal import ancestors, io_toposort from pytensor.graph.traversal import ancestors, io_toposort
from pytensor.tensor import matrix, tensor, vector from pytensor.tensor import matrix, tensor, vector
from pytensor.tensor.basic import expand_dims
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.signal.conv import Convolve1d, convolve1d from pytensor.tensor.signal.conv import Convolve1d, convolve1d, convolve2d
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -46,7 +48,7 @@ def test_convolve1d_batch(): ...@@ -46,7 +48,7 @@ def test_convolve1d_batch():
res = out.eval({x: x_test, y: y_test}) res = out.eval({x: x_test, y: y_test})
# Second entry of x, y are just y, x respectively, # Second entry of x, y are just y, x respectively,
# so res[0] and res[1] should be identical. # so res[0] and res[1] should be identical.
rtol = 1e-6 if config.floatX == "float32" else 1e-15 rtol = 1e-6 if config.floatX == "float32" else 1e-12
res_np = np.convolve(x_test[0], y_test[0]) res_np = np.convolve(x_test[0], y_test[0])
np.testing.assert_allclose(res[0], res_np, rtol=rtol) np.testing.assert_allclose(res[0], res_np, rtol=rtol)
np.testing.assert_allclose(res[1], res_np, rtol=rtol) np.testing.assert_allclose(res[1], res_np, rtol=rtol)
...@@ -100,11 +102,13 @@ def test_convolve1d_valid_grad(static_shape): ...@@ -100,11 +102,13 @@ def test_convolve1d_valid_grad(static_shape):
"local_useless_unbatched_blockwise", "local_useless_unbatched_blockwise",
), ),
) )
[conv_node] = [ [conv_node] = [
node node
for node in io_toposort([larger, smaller], [grad_out]) for node in io_toposort([larger, smaller], [grad_out])
if isinstance(node.op, Convolve1d) if isinstance(node.op, Convolve1d)
] ]
full_mode = conv_node.inputs[-1] full_mode = conv_node.inputs[-1]
# If shape is static we get constant mode == "valid", otherwise it depends on the input shapes # If shape is static we get constant mode == "valid", otherwise it depends on the input shapes
# ignoring E712 because np.True_ and np.False_ need to be compared with `==` to produce a valid boolean # ignoring E712 because np.True_ and np.False_ need to be compared with `==` to produce a valid boolean
...@@ -137,3 +141,74 @@ def convolve1d_grad_benchmarker(convolve_mode, mode, benchmark): ...@@ -137,3 +141,74 @@ def convolve1d_grad_benchmarker(convolve_mode, mode, benchmark):
@pytest.mark.parametrize("convolve_mode", ["full", "valid"])
def test_convolve1d_grad_benchmark_c(convolve_mode, benchmark):
    # Benchmark the convolve1d gradient under the C backend ("FAST_RUN").
    convolve1d_grad_benchmarker(convolve_mode, "FAST_RUN", benchmark)
@pytest.mark.parametrize(
    "kernel_shape", [(3, 3), (5, 3), (5, 8)], ids=lambda x: f"kernel_shape={x}"
)
@pytest.mark.parametrize(
    "data_shape", [(3, 3), (5, 5), (8, 8)], ids=lambda x: f"data_shape={x}"
)
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
@pytest.mark.parametrize(
    "boundary, boundary_kwargs",
    [
        ("fill", {"fillvalue": 0}),
        ("fill", {"fillvalue": 0.5}),
        ("wrap", {}),
        ("symm", {}),
    ],
)
def test_convolve2d(kernel_shape, data_shape, mode, boundary, boundary_kwargs):
    """Forward and gradient of convolve2d agree with scipy.signal.convolve2d."""
    data = matrix("data")
    kernel = matrix("kernel")

    conv_op = partial(convolve2d, mode=mode, boundary=boundary, **boundary_kwargs)
    fn = function([data, kernel], conv_op(data, kernel))

    rng = np.random.default_rng((172, kernel_shape, data_shape, sum(map(ord, mode))))
    data_val = rng.normal(size=data_shape).astype(data.dtype)
    kernel_val = rng.normal(size=kernel_shape).astype(kernel.dtype)

    tol = 1e-5 if config.floatX == "float32" else 1e-13
    expected = scipy_convolve2d(
        data_val, kernel_val, mode=mode, boundary=boundary, **boundary_kwargs
    )
    np.testing.assert_allclose(
        fn(data_val, kernel_val),
        expected,
        atol=tol,
        rtol=tol,
    )
    # Check the gradient with respect to the kernel numerically.
    utt.verify_grad(lambda k: conv_op(data_val, k).sum(), [kernel_val])
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
@pytest.mark.parametrize("method", ["direct", "fft"])
def test_batched_1d_agrees_with_2d_row_filter(mode, method):
    """A (1, K) 2D kernel applied with convolve2d must match convolve1d
    broadcast over the rows of the data, in both the forward pass and the
    gradient with respect to the kernel."""
    data = matrix("data")
    kernel_1d = vector("kernel_1d")
    kernel_2d = expand_dims(kernel_1d, 0)

    output_1d = convolve1d(data, kernel_1d, mode=mode)
    output_2d = convolve2d(data, kernel_2d, mode=mode, method=method)

    grad_1d = grad(output_1d.sum(), kernel_1d).ravel()
    # BUGFIX: this previously differentiated output_1d again, making the
    # comparison vacuous — the 2D gradient path was never exercised.
    grad_2d = grad(output_2d.sum(), kernel_1d).ravel()

    fn = function([data, kernel_1d], [output_1d, output_2d, grad_1d, grad_2d])

    data_val = np.random.normal(size=(10, 8)).astype(config.floatX)
    kernel_1d_val = np.random.normal(size=(3,)).astype(config.floatX)

    forward_1d, forward_2d, backward_1d, backward_2d = fn(data_val, kernel_1d_val)

    rtol = 1e-5 if config.floatX == "float32" else 1e-13
    np.testing.assert_allclose(forward_1d, forward_2d, rtol=rtol)
    np.testing.assert_allclose(backward_1d, backward_2d, rtol=rtol)
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment