Unverified 提交 10d225dd authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Implement `Convolve2D` Op (#1397)

* Implement simple wrapper Op for `scipy.signal.convolve2d` * Better shape inference * conv2d gradient v0 * Implement Convolve2d and gradients * Add same mode, and FFT support * Relax test tolerance in float32 * Change seed * Relax test tolerance in float32 * Add `convolve2d` to `signal.__init__` --------- Co-authored-by: 's avatarricardoV94 <ricardo.vieira1994@gmail.com>
上级 1d13f8c4
......@@ -921,7 +921,7 @@ def zeros_like(model, dtype=None, opt=False):
return fill(_model, ret)
def zeros(shape, dtype=None):
def zeros(shape, dtype=None) -> TensorVariable:
"""Create a `TensorVariable` filled with zeros, closer to NumPy's syntax than ``alloc``."""
if not (
isinstance(shape, np.ndarray | Sequence)
......@@ -933,7 +933,7 @@ def zeros(shape, dtype=None):
return alloc(np.array(0, dtype=dtype), *shape)
def ones(shape, dtype=None):
def ones(shape, dtype=None) -> TensorVariable:
"""Create a `TensorVariable` filled with ones, closer to NumPy's syntax than ``alloc``."""
if not (
isinstance(shape, np.ndarray | Sequence)
......
from pytensor.tensor.signal.conv import convolve1d
from pytensor.tensor.signal.conv import convolve1d, convolve2d
__all__ = ("convolve1d",)
__all__ = ("convolve1d", "convolve2d")
......@@ -3,13 +3,15 @@ from functools import partial
import numpy as np
import pytest
from scipy.signal import convolve as scipy_convolve
from scipy.signal import convolve2d as scipy_convolve2d
from pytensor import config, function, grad
from pytensor.graph.rewriting import rewrite_graph
from pytensor.graph.traversal import ancestors, io_toposort
from pytensor.tensor import matrix, tensor, vector
from pytensor.tensor.basic import expand_dims
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.signal.conv import Convolve1d, convolve1d
from pytensor.tensor.signal.conv import Convolve1d, convolve1d, convolve2d
from tests import unittest_tools as utt
......@@ -46,7 +48,7 @@ def test_convolve1d_batch():
res = out.eval({x: x_test, y: y_test})
# Second entry of x, y are just y, x respectively,
# so res[0] and res[1] should be identical.
rtol = 1e-6 if config.floatX == "float32" else 1e-15
rtol = 1e-6 if config.floatX == "float32" else 1e-12
res_np = np.convolve(x_test[0], y_test[0])
np.testing.assert_allclose(res[0], res_np, rtol=rtol)
np.testing.assert_allclose(res[1], res_np, rtol=rtol)
......@@ -100,11 +102,13 @@ def test_convolve1d_valid_grad(static_shape):
"local_useless_unbatched_blockwise",
),
)
[conv_node] = [
node
for node in io_toposort([larger, smaller], [grad_out])
if isinstance(node.op, Convolve1d)
]
full_mode = conv_node.inputs[-1]
# If shape is static we get constant mode == "valid", otherwise it depends on the input shapes
# ignoring E712 because np.True_ and np.False_ need to be compared with `==` to produce a valid boolean
......@@ -137,3 +141,74 @@ def convolve1d_grad_benchmarker(convolve_mode, mode, benchmark):
@pytest.mark.parametrize("convolve_mode", ["full", "valid"])
def test_convolve1d_grad_benchmark_c(convolve_mode, benchmark):
convolve1d_grad_benchmarker(convolve_mode, "FAST_RUN", benchmark)
@pytest.mark.parametrize(
"kernel_shape", [(3, 3), (5, 3), (5, 8)], ids=lambda x: f"kernel_shape={x}"
)
@pytest.mark.parametrize(
"data_shape", [(3, 3), (5, 5), (8, 8)], ids=lambda x: f"data_shape={x}"
)
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
@pytest.mark.parametrize(
"boundary, boundary_kwargs",
[
("fill", {"fillvalue": 0}),
("fill", {"fillvalue": 0.5}),
("wrap", {}),
("symm", {}),
],
)
def test_convolve2d(kernel_shape, data_shape, mode, boundary, boundary_kwargs):
data = matrix("data")
kernel = matrix("kernel")
op = partial(convolve2d, mode=mode, boundary=boundary, **boundary_kwargs)
conv_result = op(data, kernel)
fn = function([data, kernel], conv_result)
rng = np.random.default_rng((172, kernel_shape, data_shape, sum(map(ord, mode))))
data_val = rng.normal(size=data_shape).astype(data.dtype)
kernel_val = rng.normal(size=kernel_shape).astype(kernel.dtype)
np.testing.assert_allclose(
fn(data_val, kernel_val),
scipy_convolve2d(
data_val, kernel_val, mode=mode, boundary=boundary, **boundary_kwargs
),
atol=1e-5 if config.floatX == "float32" else 1e-13,
rtol=1e-5 if config.floatX == "float32" else 1e-13,
)
utt.verify_grad(lambda k: op(data_val, k).sum(), [kernel_val])
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
@pytest.mark.parametrize("method", ["direct", "fft"])
def test_batched_1d_agrees_with_2d_row_filter(mode, method):
data = matrix("data")
kernel_1d = vector("kernel_1d")
kernel_2d = expand_dims(kernel_1d, 0)
output_1d = convolve1d(data, kernel_1d, mode=mode)
output_2d = convolve2d(data, kernel_2d, mode=mode, method=method)
grad_1d = grad(output_1d.sum(), kernel_1d).ravel()
grad_2d = grad(output_1d.sum(), kernel_1d).ravel()
fn = function([data, kernel_1d], [output_1d, output_2d, grad_1d, grad_2d])
data_val = np.random.normal(size=(10, 8)).astype(config.floatX)
kernel_1d_val = np.random.normal(size=(3,)).astype(config.floatX)
forward_1d, forward_2d, backward_1d, backward_2d = fn(data_val, kernel_1d_val)
np.testing.assert_allclose(
forward_1d,
forward_2d,
rtol=1e-5 if config.floatX == "float32" else 1e-13,
)
np.testing.assert_allclose(
backward_1d,
backward_2d,
rtol=1e-5 if config.floatX == "float32" else 1e-13,
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论