提交 9530ffcc authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Rename core Conv1d to Convolve1d

上级 35c69991
import jax import jax
from pytensor.link.jax.dispatch import jax_funcify from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.signal.conv import Conv1d from pytensor.tensor.signal.conv import Convolve1d
@jax_funcify.register(Conv1d) @jax_funcify.register(Convolve1d)
def jax_funcify_Conv1d(op, node, **kwargs): def jax_funcify_Convolve1d(op, node, **kwargs):
mode = op.mode mode = op.mode
def conv1d(data, kernel): def conv1d(data, kernel):
......
...@@ -2,11 +2,11 @@ import numpy as np ...@@ -2,11 +2,11 @@ import numpy as np
from pytensor.link.numba.dispatch import numba_funcify from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.tensor.signal.conv import Conv1d from pytensor.tensor.signal.conv import Convolve1d
@numba_funcify.register(Conv1d) @numba_funcify.register(Convolve1d)
def numba_funcify_Conv1d(op, node, **kwargs): def numba_funcify_Convolve1d(op, node, **kwargs):
mode = op.mode mode = op.mode
@numba_njit @numba_njit
......
...@@ -15,7 +15,7 @@ if TYPE_CHECKING: ...@@ -15,7 +15,7 @@ if TYPE_CHECKING:
from pytensor.tensor import TensorLike from pytensor.tensor import TensorLike
class Conv1d(Op): class Convolve1d(Op):
__props__ = ("mode",) __props__ = ("mode",)
gufunc_signature = "(n),(k)->(o)" gufunc_signature = "(n),(k)->(o)"
...@@ -129,4 +129,4 @@ def convolve1d( ...@@ -129,4 +129,4 @@ def convolve1d(
) )
mode = "valid" mode = "valid"
return cast(TensorVariable, Blockwise(Conv1d(mode=mode))(in1, in2)) return cast(TensorVariable, Blockwise(Convolve1d(mode=mode))(in1, in2))
...@@ -8,7 +8,7 @@ from pytensor import config, function, grad ...@@ -8,7 +8,7 @@ from pytensor import config, function, grad
from pytensor.graph import ancestors, rewrite_graph from pytensor.graph import ancestors, rewrite_graph
from pytensor.tensor import matrix, vector from pytensor.tensor import matrix, vector
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.signal.conv import Conv1d, convolve1d from pytensor.tensor.signal.conv import Convolve1d, convolve1d
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -81,4 +81,4 @@ def test_convolve1d_batch_graph(mode): ...@@ -81,4 +81,4 @@ def test_convolve1d_batch_graph(mode):
if var.owner is not None and isinstance(var.owner.op, Blockwise) if var.owner is not None and isinstance(var.owner.op, Blockwise)
] ]
# Check any Blockwise are just Conv1d # Check any Blockwise are just Conv1d
assert all(isinstance(node.op.core_op, Conv1d) for node in blockwise_nodes) assert all(isinstance(node.op.core_op, Convolve1d) for node in blockwise_nodes)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论