提交 6fb515d0 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

C implementation of Convolve1d

上级 6557682b
...@@ -2,7 +2,8 @@ from typing import TYPE_CHECKING, Literal, cast ...@@ -2,7 +2,8 @@ from typing import TYPE_CHECKING, Literal, cast
from numpy import convolve as numpy_convolve from numpy import convolve as numpy_convolve
from pytensor.graph import Apply, Op from pytensor.graph import Apply
from pytensor.link.c.op import COp
from pytensor.scalar.basic import upcast from pytensor.scalar.basic import upcast
from pytensor.tensor.basic import as_tensor_variable, join, zeros from pytensor.tensor.basic import as_tensor_variable, join, zeros
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
...@@ -15,7 +16,7 @@ if TYPE_CHECKING: ...@@ -15,7 +16,7 @@ if TYPE_CHECKING:
from pytensor.tensor import TensorLike from pytensor.tensor import TensorLike
class Convolve1d(Op): class Convolve1d(COp):
__props__ = ("mode",) __props__ = ("mode",)
gufunc_signature = "(n),(k)->(o)" gufunc_signature = "(n),(k)->(o)"
...@@ -86,6 +87,87 @@ class Convolve1d(Op): ...@@ -86,6 +87,87 @@ class Convolve1d(Op):
return [in1_bar, in2_bar] return [in1_bar, in2_bar]
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inputs, outputs, sub):
# raise NotImplementedError()
in1, in2 = inputs
[out] = outputs
mode_str = self.mode
if mode_str == "full":
np_mode_val = 2 # NPY_CONVOLVE_FULL
elif mode_str == "valid":
np_mode_val = 0 # NPY_CONVOLVE_VALID
else:
# This case should ideally be prevented by __init__ or make_node
raise ValueError(f"Unsupported mode {mode_str}")
code = f"""
{{
PyArrayObject* in2_flipped_view = NULL;
if (PyArray_NDIM({in1}) != 1 || PyArray_NDIM({in2}) != 1) {{
PyErr_SetString(PyExc_ValueError, "Convolve1d C code expects 1D arrays.");
{sub['fail']};
}}
npy_intp n_in2 = PyArray_DIM({in2}, 0);
// Create a reversed view of in2
if (n_in2 == 0) {{
PyErr_SetString(PyExc_ValueError, "Convolve1d: second input (kernel) cannot be empty.");
{sub['fail']};
}} else {{
npy_intp view_dims[1];
view_dims[0] = n_in2;
npy_intp view_strides[1];
view_strides[0] = -PyArray_STRIDES({in2})[0];
void* view_data = (char*)PyArray_DATA({in2}) + (n_in2 - 1) * PyArray_STRIDES({in2})[0];
Py_INCREF(PyArray_DESCR({in2}));
in2_flipped_view = (PyArrayObject*)PyArray_NewFromDescr(
Py_TYPE({in2}),
PyArray_DESCR({in2}),
1, // ndim
view_dims,
view_strides,
view_data,
(PyArray_FLAGS({in2}) & ~NPY_ARRAY_WRITEABLE),
NULL
);
if (!in2_flipped_view) {{
PyErr_SetString(PyExc_RuntimeError, "Failed to create flipped kernel view for Convolve1d.");
{sub['fail']};
}}
Py_INCREF({in2});
if (PyArray_SetBaseObject(in2_flipped_view, (PyObject*){in2}) < 0) {{
Py_DECREF({in2}); // SetBaseObject failed, release the extra INCREF
Py_DECREF(in2_flipped_view);
in2_flipped_view = NULL;
PyErr_SetString(PyExc_RuntimeError, "Failed to set base object for flipped kernel view in Convolve1d.");
{sub['fail']};
}}
PyArray_UpdateFlags(in2_flipped_view, (NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS));
}}
// TODO: Use lower level implementation that allows reusing the output buffer
Py_XDECREF({out});
{out} = (PyArrayObject*) PyArray_Correlate2((PyObject*){in1}, (PyObject*)in2_flipped_view, {np_mode_val});
Py_XDECREF(in2_flipped_view); // Clean up the view if correlate fails
if (!{out}) {{
// PyArray_Correlate already set an error
{sub['fail']};
}}
}}
"""
return code
def convolve1d( def convolve1d(
in1: "TensorLike", in1: "TensorLike",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论