Commit 1d82fb46 authored by Ricardo Vieira, committed by Ricardo Vieira

Make convolve mode symbolic to avoid unnecessary large convolution in gradient

Parent a62e785d
import jax
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.signal.conv import Convolve1d
@jax_funcify.register(Convolve1d)
def jax_funcify_Convolve1d(op, node, **kwargs):
    """JAX dispatch for Convolve1d.

    The convolution mode ("full"/"valid") is a symbolic boolean input of the
    node; JAX needs it at trace time, so compilation is only supported when
    it can be resolved to a constant.
    """
    _, _, full_mode = node.inputs
    try:
        full_mode = get_underlying_scalar_constant_value(full_mode)
    except NotScalarConstantError:
        raise NotImplementedError(
            "Cannot compile Convolve1D to jax without static mode"
        )
    static_mode = "full" if full_mode else "valid"

    def conv1d(data, kernel, _runtime_full_mode):
        # _runtime_full_mode is not used, as we only support static mode
        return jax.numpy.convolve(data, kernel, mode=static_mode)

    return conv1d
......@@ -9,13 +9,11 @@ from pytensor.tensor.signal.conv import Convolve1d
@numba_funcify.register(Convolve1d)
def numba_funcify_Convolve1d(op, node, **kwargs):
# This specialized version is faster than the overloaded numba np.convolve
mode = op.mode
a_dtype, b_dtype = node.inputs[0].type.dtype, node.inputs[1].type.dtype
out_dtype = node.outputs[0].type.dtype
innerprod = _get_inner_prod(a_dtype, b_dtype)
if mode == "valid":
@numba_njit
def valid_convolve1d(x, y):
nx = len(x)
ny = len(y)
......@@ -32,10 +30,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
return ret
return numba_njit(valid_convolve1d)
elif mode == "full":
@numba_njit
def full_convolve1d(x, y):
nx = len(x)
ny = len(y)
......@@ -64,7 +59,11 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
return ret
return numba_njit(full_convolve1d)
@numba_njit
def convolve_1d(x, y, mode):
if mode:
return full_convolve1d(x, y)
else:
raise ValueError(f"Unsupported mode: {mode}")
return valid_convolve1d(x, y)
return convolve_1d
......@@ -360,12 +360,12 @@ class Blockwise(COp):
dummy_fgraph, dummy_core_node, core_input_shapes
)
# Set to None those core_shapes that depend on dummy_core_inputs,
# meaning their value may not be constant across batch dims of the Blockwise
if not dummy_core_inputs:
# All inputs are unbatched, so the core_shape can be used as is
return core_output_shapes
else:
# Set to None those core_shapes that depend on dummy_core_inputs,
# meaning their value may not be constant across batch dims of the Blockwise
set_dummy_core_inputs = set(dummy_core_inputs)
safe_core_output_shapes = [list(shape) for shape in core_output_shapes]
for core_out_shape in safe_core_output_shapes:
......
......@@ -3,7 +3,6 @@ import pytensor.tensor.rewriting.blas
import pytensor.tensor.rewriting.blas_c
import pytensor.tensor.rewriting.blas_scipy
import pytensor.tensor.rewriting.blockwise
import pytensor.tensor.rewriting.conv
import pytensor.tensor.rewriting.einsum
import pytensor.tensor.rewriting.elemwise
import pytensor.tensor.rewriting.extra_ops
......
from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.rewriting.basic import register_specialize, register_stabilize
from pytensor.tensor.signal import convolve1d
from pytensor.tensor.signal.conv import Convolve1d
from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor
@register_stabilize
@register_specialize
@node_rewriter([Subtensor])
def local_sliced_full_conv_to_valid_conv(fgraph, node):
    """Rewrite a sliced "full" convolution that is equivalent to a "valid" one.

    The gradient of a valid Conv1d always implements the worst case scenario - full convolution -
    because it would need to know which input is larger to do something smarter.
    If we find out (through rewrites or static shape) we provide the direct implementation
    which can be orders of magnitude faster.

    # if x.shape[-1] > y.shape[-1]
    # z = convolve1d(x, y, mode="full")
    # z[..., y.shape[-1] - 1: z.shape[-1] - y.shape[-1] - 1] -> convolve1d(x, y, mode="valid")
    """
    conv, *other_idx_vars = node.inputs

    # Only apply to a Subtensor whose input is a Blockwise'd full-mode Convolve1d.
    # NOTE(review): this reads `core_op.mode` as an Op attribute — confirm Convolve1d
    # still exposes it (the mode may have been turned into a symbolic input).
    if not (
        conv.owner is not None
        and isinstance(conv.owner.op, Blockwise)
        and isinstance(conv.owner.op.core_op, Convolve1d)
        and conv.owner.op.core_op.mode == "full"
    ):
        return None

    # Check we have an (a:b) constant slice at the last axis of the input
    idx_list = node.op.idx_list
    if not (len(idx_list) == conv.type.ndim and isinstance(idx_list[-1], slice)):
        return None

    last_slice = idx_list[-1]
    # Only a plain start:stop slice (no step) can match the valid-conv window.
    if not (
        last_slice.start is not None
        and last_slice.stop is not None
        and last_slice.step is None
    ):
        return None

    # The last two index variables are the start/stop of the final slice;
    # they must be compile-time constants for the comparison below.
    *other_idx_vars, start, stop = other_idx_vars
    if not (isinstance(start, Constant) and isinstance(stop, Constant)):
        return None

    # NOTE(review): assumes the conv node has exactly two inputs — verify this
    # against the Convolve1d definition if a mode input was added.
    x, y = conv.owner.inputs
    len_x = x.type.shape[-1]
    len_y = y.type.shape[-1]
    # Static lengths are required to decide whether the slice matches.
    if len_x is None or len_y is None:
        return None

    start, stop = start.data, stop.data
    if len_x < len_y:
        # Convolution is symmetric with input order
        x, y = y, x
        len_x, len_y = len_y, len_x

    # A full conv sliced to [len_y - 1 : len_x] is exactly the valid conv output.
    if (
        start == len_y - 1
        # equivalent to stop = conv.shape[-1] - len_y - 1
        and stop == start + (len_x - len_y) + 1
    ):
        new_conv = convolve1d(x, y, mode="valid")
        copy_stack_trace(conv, new_conv)

        if other_idx_vars:
            # If there were more than just empty slices besides the last one
            new_indices = indices_from_subtensor(idx_list[:-1], other_idx_vars)
            new_conv = new_conv[new_indices]
            copy_stack_trace(node.out, new_conv)

        return [new_conv]
from typing import TYPE_CHECKING, Literal, cast
import numpy as np
from numpy import convolve as numpy_convolve
from pytensor.graph import Apply
from pytensor.gradient import DisconnectedType
from pytensor.graph import Apply, Constant
from pytensor.link.c.op import COp
from pytensor.scalar import as_scalar
from pytensor.scalar.basic import upcast
from pytensor.tensor.basic import as_tensor_variable, join, zeros
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import maximum, minimum
from pytensor.tensor.math import maximum, minimum, switch
from pytensor.tensor.type import vector
from pytensor.tensor.variable import TensorVariable
......@@ -17,92 +20,83 @@ if TYPE_CHECKING:
class Convolve1d(COp):
    """1D convolution Op.

    The "full"/"valid" mode is a symbolic boolean third input of the node
    (True -> "full"), not an Op attribute, so a single Op instance serves
    both modes and the mode can be decided at runtime.
    """

    # No props: the mode is an input, so all instances are interchangeable.
    __props__ = ()
    gufunc_signature = "(n),(k),()->(o)"
def make_node(self, in1, in2):
def make_node(self, in1, in2, full_mode):
in1 = as_tensor_variable(in1)
in2 = as_tensor_variable(in2)
full_mode = as_scalar(full_mode)
assert in1.ndim == 1
assert in2.ndim == 1
if not (in1.ndim == 1 and in2.ndim == 1):
raise ValueError("Convolution inputs must be vector (ndim=1)")
if not full_mode.dtype == "bool":
raise ValueError("Convolution mode must be a boolean type")
dtype = upcast(in1.dtype, in2.dtype)
n = in1.type.shape[0]
k = in2.type.shape[0]
match full_mode:
case Constant():
static_mode = "full" if full_mode.data else "valid"
case _:
static_mode = None
if n is None or k is None:
if n is None or k is None or static_mode is None:
out_shape = (None,)
elif self.mode == "full":
elif static_mode == "full":
out_shape = (n + k - 1,)
else: # mode == "valid":
out_shape = (max(n, k) - min(n, k) + 1,)
out = vector(dtype=dtype, shape=out_shape)
return Apply(self, [in1, in2], [out])
return Apply(self, [in1, in2, full_mode], [out])
def perform(self, node, inputs, outputs):
    """Compute the convolution, selecting full/valid mode at runtime."""
    in1, in2, full_mode = inputs
    # We use numpy_convolve as that's what scipy would use if method="direct" was passed.
    # And mode != "same", which this Op doesn't cover anyway.
    outputs[0][0] = numpy_convolve(in1, in2, mode="full" if full_mode else "valid")
def infer_shape(self, fgraph, node, shapes):
    """Symbolic output length: n + k - 1 for "full", |n - k| + 1 for "valid"."""
    _, _, full_mode = node.inputs
    in1_shape, in2_shape, _ = shapes
    n = in1_shape[0]
    k = in2_shape[0]
    shape_valid = maximum(n, k) - minimum(n, k) + 1
    shape_full = n + k - 1
    # The mode is symbolic, so the output shape must also be selected at runtime.
    shape = switch(full_mode, shape_full, shape_valid)
    return [[shape]]
def connection_pattern(self, node):
    # Gradients flow to the two data inputs; the boolean mode flag is disconnected.
    return [[True], [True], [False]]
def L_op(self, inputs, outputs, output_grads):
    """Gradient of the convolution wrt its two data inputs.

    The gradient of a convolution is itself a convolution of the output
    gradient with a flipped input; which mode it needs depends both on the
    forward mode and on which input is longer, so the mode of each gradient
    convolution is computed symbolically.
    """
    in1, in2, full_mode = inputs
    [grad] = output_grads
    n = in1.shape[0]
    k = in2.shape[0]
    # If mode is "full", or mode is "valid" and k >= n, then in1_bar mode should use "valid" convolve
    # The expression below is equivalent to ~(full_mode | (k >= n))
    full_mode_in1_bar = ~full_mode & (k < n)
    # If mode is "full", or mode is "valid" and n >= k, then in2_bar mode should use "valid" convolve
    # The expression below is equivalent to ~(full_mode | (n >= k))
    full_mode_in2_bar = ~full_mode & (n < k)
    return [
        self(grad, in2[::-1], full_mode_in1_bar),
        self(grad, in1[::-1], full_mode_in2_bar),
        # The boolean mode input carries no gradient.
        DisconnectedType()(),
    ]
def c_code_cache_version(self):
    # Bumped to (2,) when the mode became a runtime input of the C code.
    return (2,)
def c_code(self, node, name, inputs, outputs, sub):
# raise NotImplementedError()
in1, in2 = inputs
in1, in2, full_mode = inputs
[out] = outputs
mode_str = self.mode
if mode_str == "full":
np_mode_val = 2 # NPY_CONVOLVE_FULL
elif mode_str == "valid":
np_mode_val = 0 # NPY_CONVOLVE_VALID
else:
# This case should ideally be prevented by __init__ or make_node
raise ValueError(f"Unsupported mode {mode_str}")
code = f"""
{{
......@@ -158,7 +152,7 @@ class Convolve1d(COp):
// TODO: Use lower level implementation that allows reusing the output buffer
Py_XDECREF({out});
{out} = (PyArrayObject*) PyArray_Correlate2((PyObject*){in1}, (PyObject*)in2_flipped_view, {np_mode_val});
{out} = (PyArrayObject*) PyArray_Correlate2((PyObject*){in1}, (PyObject*)in2_flipped_view, {full_mode} ? 2 : 0);
Py_XDECREF(in2_flipped_view); // Clean up the view if correlate fails
if (!{out}) {{
// PyArray_Correlate already set an error
......@@ -169,6 +163,9 @@ class Convolve1d(COp):
return code
# Module-level Blockwise-wrapped Convolve1d, reused by the convolve1d() helper.
blockwise_convolve_1d = Blockwise(Convolve1d())
def convolve1d(
in1: "TensorLike",
in2: "TensorLike",
......@@ -212,4 +209,5 @@ def convolve1d(
)
mode = "valid"
return cast(TensorVariable, Blockwise(Convolve1d(mode=mode))(in1, in2))
full_mode = as_scalar(np.bool_(mode == "full"))
return cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode))
......@@ -7,6 +7,7 @@ from pytensor import function
from pytensor.tensor import dmatrix, tensor
from pytensor.tensor.signal import convolve1d
from tests.link.numba.test_basic import compare_numba_and_py
from tests.tensor.signal.test_conv import convolve1d_grad_benchmarker
pytestmark = pytest.mark.filterwarnings("error")
......@@ -31,15 +32,8 @@ def test_convolve1d(x_smaller, mode):
@pytest.mark.parametrize("mode", ("full", "valid"), ids=lambda x: f"mode={x}")
@pytest.mark.parametrize("batch", (False, True), ids=lambda x: f"batch={x}")
def test_convolve1d_benchmark(batch, mode, benchmark):
x = tensor(
shape=(
7,
183,
)
if batch
else (183,)
)
def test_convolve1d_benchmark_numba(batch, mode, benchmark):
x = tensor(shape=(7, 183) if batch else (183,))
y = tensor(shape=(7, 6) if batch else (6,))
out = convolve1d(x, y, mode=mode)
fn = function([x, y], out, mode="NUMBA", trust_input=True)
......@@ -57,3 +51,8 @@ def test_convolve1d_benchmark(batch, mode, benchmark):
np_convolve1d(x_test, y_test),
)
benchmark(fn, x_test, y_test)
@pytest.mark.parametrize("convolve_mode", ["full", "valid"])
def test_convolve1d_grad_benchmark_numba(convolve_mode, benchmark):
    # Delegate to the shared benchmarker, using the NUMBA backend.
    convolve1d_grad_benchmarker(convolve_mode, "NUMBA", benchmark)
......@@ -7,7 +7,7 @@ from scipy.signal import convolve as scipy_convolve
from pytensor import config, function, grad
from pytensor.graph.basic import ancestors, io_toposort
from pytensor.graph.rewriting import rewrite_graph
from pytensor.tensor import matrix, vector
from pytensor.tensor import matrix, tensor, vector
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.signal.conv import Convolve1d, convolve1d
from tests import unittest_tools as utt
......@@ -86,11 +86,8 @@ def test_convolve1d_batch_graph(mode):
@pytest.mark.parametrize("static_shape", [False, True])
def test_convolve1d_valid_grad_rewrite(static_shape):
"""Test that we don't do a useless full convolve1d when taking the gradient of a valid convolve wrt to the smallest input.
This can only be achieved when the two inputs have static shapes, so we know which one is larger
"""
def test_convolve1d_valid_grad(static_shape):
"""Test we don't do a full convolve in the gradient of the smaller input to a valid convolve."""
larger = vector("larger", shape=(128 if static_shape else None,))
smaller = vector("smaller", shape=(64 if static_shape else None,))
out = convolve1d(larger, smaller, mode="valid")
......@@ -103,9 +100,40 @@ def test_convolve1d_valid_grad_rewrite(static_shape):
"local_useless_unbatched_blockwise",
),
)
[conv_op] = [
node.op
[conv_node] = [
node
for node in io_toposort([larger, smaller], [grad_out])
if isinstance(node.op, Convolve1d)
]
assert conv_op.mode == ("valid" if static_shape else "full")
full_mode = conv_node.inputs[-1]
# If shape is static we get constant mode == "valid", otherwise it depends on the input shapes
# ignoring E712 because np.True_ and np.False_ need to be compared with `==` to produce a valid boolean
if static_shape:
assert full_mode.eval() == False # noqa: E712
else:
dtype = larger.dtype
larger_test = np.zeros((128,), dtype=dtype)
smaller_test = np.zeros((64,), dtype=dtype)
assert full_mode.eval({larger: larger_test, smaller: smaller_test}) == False # noqa: E712
assert full_mode.eval({larger: smaller_test, smaller: larger_test}) == True # noqa: E712
def convolve1d_grad_benchmarker(convolve_mode, mode, benchmark):
    """Benchmark the gradient of a batched convolve1d wrt its smaller input.

    `convolve_mode` is the convolution mode ("full"/"valid"); `mode` is the
    PyTensor compilation mode/backend to benchmark.
    """
    # Use None core shape so PyTensor doesn't know which mode to use until runtime.
    big = tensor("larger", shape=(8, None))
    small = tensor("smaller", shape=(8, None))

    loss = convolve1d(big, small, mode=convolve_mode).sum()
    grad_wrt_smaller = grad(loss, wrt=small)
    fn = function([big, small], grad_wrt_smaller, trust_input=True, mode=mode)

    # Seed depends on the backend so the two parametrizations don't share data.
    rng = np.random.default_rng([119, mode == "full"])
    big_val = rng.normal(size=(8, 1024)).astype(big.type.dtype)
    small_val = rng.normal(size=(8, 16)).astype(small.type.dtype)
    benchmark(fn, big_val, small_val)
@pytest.mark.parametrize("convolve_mode", ["full", "valid"])
def test_convolve1d_grad_benchmark_c(convolve_mode, benchmark):
    # Same benchmark as the numba variant, but through the default C backend (FAST_RUN).
    convolve1d_grad_benchmarker(convolve_mode, "FAST_RUN", benchmark)
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment