Commit 4378d482 authored by Ricardo Vieira, committed by Ricardo Vieira

Rewrite sliced full convolutions as valid

These show up in the gradient of Convolve1D
Parent 2ada4b66
......@@ -3,6 +3,7 @@ import pytensor.tensor.rewriting.blas
import pytensor.tensor.rewriting.blas_c
import pytensor.tensor.rewriting.blas_scipy
import pytensor.tensor.rewriting.blockwise
import pytensor.tensor.rewriting.conv
import pytensor.tensor.rewriting.einsum
import pytensor.tensor.rewriting.elemwise
import pytensor.tensor.rewriting.extra_ops
......
from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.rewriting.basic import register_specialize, register_stabilize
from pytensor.tensor.signal import convolve1d
from pytensor.tensor.signal.conv import Convolve1d
from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor
@register_stabilize
@register_specialize
@node_rewriter([Subtensor])
def local_sliced_full_conv_to_valid_conv(fgraph, node):
    """Rewrite a sliced full convolution that is equivalent to a valid one.

    The gradient of a valid Convolve1d always implements the worst-case
    scenario - a full convolution - because at graph-build time it cannot
    know which of its two inputs is larger. When static shapes (or earlier
    rewrites) reveal the larger input, the slice taken out of the full
    output is exactly the valid output, so we can call the direct valid
    implementation, which can be orders of magnitude faster::

        # if x.shape[-1] >= y.shape[-1], with z = convolve1d(x, y, mode="full"):
        # z[..., y.shape[-1] - 1 : x.shape[-1]] -> convolve1d(x, y, mode="valid")
    """
    conv, *other_idx_vars = node.inputs

    # Only fire on a Blockwise-wrapped Convolve1d running in "full" mode.
    conv_node = conv.owner
    if conv_node is None:
        return None
    conv_op = conv_node.op
    if not (
        isinstance(conv_op, Blockwise)
        and isinstance(conv_op.core_op, Convolve1d)
        and conv_op.core_op.mode == "full"
    ):
        return None

    # The subtensor must index every axis and end with a (start:stop) slice
    # (no step) on the last (convolution) axis.
    idx_list = node.op.idx_list
    if not (len(idx_list) == conv.type.ndim and isinstance(idx_list[-1], slice)):
        return None
    last_slice = idx_list[-1]
    if (
        last_slice.start is None
        or last_slice.stop is None
        or last_slice.step is not None
    ):
        return None

    # The start/stop of the last slice are the final two symbolic inputs of
    # the Subtensor; they must both be constants for us to reason about them.
    *other_idx_vars, start_var, stop_var = other_idx_vars
    if not (isinstance(start_var, Constant) and isinstance(stop_var, Constant)):
        return None

    x, y = conv_node.inputs
    len_x = x.type.shape[-1]
    len_y = y.type.shape[-1]
    if len_x is None or len_y is None:
        # Without static lengths we cannot tell which input is larger.
        return None

    if len_x < len_y:
        # Convolution is symmetric with input order; normalize so x is larger.
        x, y = y, x
        len_x, len_y = len_y, len_x

    start = start_var.data
    stop = stop_var.data
    # The valid region of the full output spans [len_y - 1, len_x), i.e.
    # stop = start + (len_x - len_y) + 1 = conv.shape[-1] - len_y + 1.
    if not (start == len_y - 1 and stop == start + (len_x - len_y) + 1):
        return None

    new_conv = convolve1d(x, y, mode="valid")
    copy_stack_trace(conv, new_conv)

    if other_idx_vars:
        # There were non-trivial indices on the leading axes; reapply them
        # on top of the valid convolution.
        new_indices = indices_from_subtensor(idx_list[:-1], other_idx_vars)
        new_conv = new_conv[new_indices]
        copy_stack_trace(node.out, new_conv)

    return [new_conv]
......@@ -75,13 +75,14 @@ class Convolve1d(Op):
n = in1.shape[0]
k = in2.shape[0]
kmn = maximum(0, k - n)
nkm = maximum(0, n - k)
nmk = maximum(0, n - k)
# We need mode="full" if k >= n else "valid" for `in1_bar` (opposite for `in2_bar`), but mode is not symbolic.
# Instead, we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter.
# There is a rewrite that optimizes this case when n, k are static
in1_bar = full_conv(grad, in2[::-1])
in1_bar = in1_bar[kmn : in1_bar.shape[0] - kmn]
in2_bar = full_conv(grad, in1[::-1])
in2_bar = in2_bar[nkm : in2_bar.shape[0] - nkm]
in2_bar = in2_bar[nmk : in2_bar.shape[0] - nmk]
return [in1_bar, in2_bar]
......
......@@ -5,7 +5,8 @@ import pytest
from scipy.signal import convolve as scipy_convolve
from pytensor import config, function, grad
from pytensor.graph import ancestors, rewrite_graph
from pytensor.graph.basic import ancestors, io_toposort
from pytensor.graph.rewriting import rewrite_graph
from pytensor.tensor import matrix, vector
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.signal.conv import Convolve1d, convolve1d
......@@ -82,3 +83,29 @@ def test_convolve1d_batch_graph(mode):
]
# Check any Blockwise are just Conv1d
assert all(isinstance(node.op.core_op, Convolve1d) for node in blockwise_nodes)
@pytest.mark.parametrize("static_shape", [False, True])
def test_convolve1d_valid_grad_rewrite(static_shape):
    """Test that we don't do a useless full convolve1d when taking the gradient of a valid convolve wrt to the smallest input.

    This can only be achieved when the two inputs have static shapes, so we know which one is larger
    """
    # With static shapes the rewrite can prove which input is larger and
    # replace the sliced full convolution by a direct valid one.
    larger = vector("larger", shape=(128 if static_shape else None,))
    smaller = vector("smaller", shape=(64 if static_shape else None,))
    out = convolve1d(larger, smaller, mode="valid")

    rewritten_grad = rewrite_graph(
        grad(out.sum(), wrt=smaller),
        include=(
            "ShapeOpt",
            "canonicalize",
            "stabilize",
            "local_useless_unbatched_blockwise",
        ),
    )

    # Exactly one convolution must remain in the rewritten gradient graph.
    [conv_op] = [
        node.op
        for node in io_toposort([larger, smaller], [rewritten_grad])
        if isinstance(node.op, Convolve1d)
    ]
    expected_mode = "valid" if static_shape else "full"
    assert conv_op.mode == expected_mode
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment