Unverified 提交 afb76951 authored 作者: Trey Wenger's avatar Trey Wenger 提交者: GitHub

Fix indexing in `convolve1d` with `mode="same"` (#1337)

上级 3af923ba
...@@ -119,8 +119,8 @@ def convolve1d( ...@@ -119,8 +119,8 @@ def convolve1d(
if mode == "same": if mode == "same":
# We implement "same" as "valid" with padded `in1`. # We implement "same" as "valid" with padded `in1`.
in1_batch_shape = tuple(in1.shape)[:-1] in1_batch_shape = tuple(in1.shape)[:-1]
zeros_left = in2.shape[0] // 2 zeros_left = in2.shape[-1] // 2
zeros_right = (in2.shape[0] - 1) // 2 zeros_right = (in2.shape[-1] - 1) // 2
in1 = join( in1 = join(
-1, -1,
zeros((*in1_batch_shape, zeros_left), dtype=in2.dtype), zeros((*in1_batch_shape, zeros_left), dtype=in2.dtype),
......
...@@ -47,3 +47,16 @@ def test_convolve1d_batch(): ...@@ -47,3 +47,16 @@ def test_convolve1d_batch():
res_np = np.convolve(x_test[0], y_test[0]) res_np = np.convolve(x_test[0], y_test[0])
np.testing.assert_allclose(res[0], res_np, rtol=rtol) np.testing.assert_allclose(res[0], res_np, rtol=rtol)
np.testing.assert_allclose(res[1], res_np, rtol=rtol) np.testing.assert_allclose(res[1], res_np, rtol=rtol)
def test_convolve1d_batch_same():
x = matrix("data")
y = matrix("kernel")
out = convolve1d(x, y, mode="same")
rng = np.random.default_rng(38)
x_test = rng.normal(size=(2, 8)).astype(x.dtype)
y_test = rng.normal(size=(2, 8)).astype(x.dtype)
res = out.eval({x: x_test, y: y_test})
assert res.shape == (2, 8)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论