Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
1d82fb46
提交
1d82fb46
authored
7月 04, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
7月 08, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make convolve mode symbolic to avoid unnecessary large convolution in gradient
上级
a62e785d
显示空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
122 行增加
和
167 行删除
+122
-167
conv.py
pytensor/link/jax/dispatch/signal/conv.py
+13
-3
conv.py
pytensor/link/numba/dispatch/signal/conv.py
+9
-10
blockwise.py
pytensor/tensor/blockwise.py
+2
-2
__init__.py
pytensor/tensor/rewriting/__init__.py
+0
-1
conv.py
pytensor/tensor/rewriting/conv.py
+0
-78
conv.py
pytensor/tensor/signal/conv.py
+53
-55
test_conv.py
tests/link/numba/signal/test_conv.py
+8
-9
test_conv.py
tests/tensor/signal/test_conv.py
+37
-9
没有找到文件。
pytensor/link/jax/dispatch/signal/conv.py
浏览文件 @
1d82fb46
import
jax
from
pytensor.link.jax.dispatch
import
jax_funcify
from
pytensor.tensor.basic
import
get_underlying_scalar_constant_value
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.signal.conv
import
Convolve1d
@jax_funcify.register
(
Convolve1d
)
def
jax_funcify_Convolve1d
(
op
,
node
,
**
kwargs
):
mode
=
op
.
mode
_
,
_
,
full_mode
=
node
.
inputs
try
:
full_mode
=
get_underlying_scalar_constant_value
(
full_mode
)
except
NotScalarConstantError
:
raise
NotImplementedError
(
"Cannot compile Convolve1D to jax without static mode"
)
static_mode
=
"full"
if
full_mode
else
"valid"
def
conv1d
(
data
,
kernel
):
return
jax
.
numpy
.
convolve
(
data
,
kernel
,
mode
=
mode
)
def
conv1d
(
data
,
kernel
,
_runtime_full_mode
):
# _runtime_full_mode is not used, as we only support static mode
return
jax
.
numpy
.
convolve
(
data
,
kernel
,
mode
=
static_mode
)
return
conv1d
pytensor/link/numba/dispatch/signal/conv.py
浏览文件 @
1d82fb46
...
...
@@ -9,13 +9,11 @@ from pytensor.tensor.signal.conv import Convolve1d
@numba_funcify.register
(
Convolve1d
)
def
numba_funcify_Convolve1d
(
op
,
node
,
**
kwargs
):
# This specialized version is faster than the overloaded numba np.convolve
mode
=
op
.
mode
a_dtype
,
b_dtype
=
node
.
inputs
[
0
]
.
type
.
dtype
,
node
.
inputs
[
1
]
.
type
.
dtype
out_dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
innerprod
=
_get_inner_prod
(
a_dtype
,
b_dtype
)
if
mode
==
"valid"
:
@numba_njit
def
valid_convolve1d
(
x
,
y
):
nx
=
len
(
x
)
ny
=
len
(
y
)
...
...
@@ -32,10 +30,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
return
ret
return
numba_njit
(
valid_convolve1d
)
elif
mode
==
"full"
:
@numba_njit
def
full_convolve1d
(
x
,
y
):
nx
=
len
(
x
)
ny
=
len
(
y
)
...
...
@@ -64,7 +59,11 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
return
ret
return
numba_njit
(
full_convolve1d
)
@numba_njit
def
convolve_1d
(
x
,
y
,
mode
):
if
mode
:
return
full_convolve1d
(
x
,
y
)
else
:
raise
ValueError
(
f
"Unsupported mode: {mode}"
)
return
valid_convolve1d
(
x
,
y
)
return
convolve_1d
pytensor/tensor/blockwise.py
浏览文件 @
1d82fb46
...
...
@@ -360,12 +360,12 @@ class Blockwise(COp):
dummy_fgraph
,
dummy_core_node
,
core_input_shapes
)
# Set to None those core_shapes that depend on dummy_core_inputs,
# meaning their value may not be constant across batch dims of the Blockwise
if
not
dummy_core_inputs
:
# All inputs are unbatched, so the core_shape can be used as is
return
core_output_shapes
else
:
# Set to None those core_shapes that depend on dummy_core_inputs,
# meaning their value may not be constant across batch dims of the Blockwise
set_dummy_core_inputs
=
set
(
dummy_core_inputs
)
safe_core_output_shapes
=
[
list
(
shape
)
for
shape
in
core_output_shapes
]
for
core_out_shape
in
safe_core_output_shapes
:
...
...
pytensor/tensor/rewriting/__init__.py
浏览文件 @
1d82fb46
...
...
@@ -3,7 +3,6 @@ import pytensor.tensor.rewriting.blas
import
pytensor.tensor.rewriting.blas_c
import
pytensor.tensor.rewriting.blas_scipy
import
pytensor.tensor.rewriting.blockwise
import
pytensor.tensor.rewriting.conv
import
pytensor.tensor.rewriting.einsum
import
pytensor.tensor.rewriting.elemwise
import
pytensor.tensor.rewriting.extra_ops
...
...
pytensor/tensor/rewriting/conv.py
deleted
100644 → 0
浏览文件 @
a62e785d
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
node_rewriter
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.rewriting.basic
import
register_specialize
,
register_stabilize
from
pytensor.tensor.signal
import
convolve1d
from
pytensor.tensor.signal.conv
import
Convolve1d
from
pytensor.tensor.subtensor
import
Subtensor
,
indices_from_subtensor
@register_stabilize
@register_specialize
@node_rewriter
([
Subtensor
])
def
local_sliced_full_conv_to_valid_conv
(
fgraph
,
node
):
"""Rewrite sliced full conv that are equivalent to valid.
The gradient of a valid Conv1d always implements the worst case scenario - full convolution -
because it would need to know which input is larger to do something smarter.
If we find out (through rewrites or static shape) we provide the direct implementation
which can be orders of magnitude faster.
# if x.shape[-1] > y.shape[-1]
# z = convolve1d(x, y, mode="full")
# z[..., y.shape[-1] - 1: z.shape[-1] - y.shape[-1] + 1] -> convolve1d(x, y, mode="valid")
"""
conv
,
*
other_idx_vars
=
node
.
inputs
if
not
(
conv
.
owner
is
not
None
and
isinstance
(
conv
.
owner
.
op
,
Blockwise
)
and
isinstance
(
conv
.
owner
.
op
.
core_op
,
Convolve1d
)
and
conv
.
owner
.
op
.
core_op
.
mode
==
"full"
):
return
None
# Check we have an (a:b) constant slice at the last axis of the input
idx_list
=
node
.
op
.
idx_list
if
not
(
len
(
idx_list
)
==
conv
.
type
.
ndim
and
isinstance
(
idx_list
[
-
1
],
slice
)):
return
None
last_slice
=
idx_list
[
-
1
]
if
not
(
last_slice
.
start
is
not
None
and
last_slice
.
stop
is
not
None
and
last_slice
.
step
is
None
):
return
None
*
other_idx_vars
,
start
,
stop
=
other_idx_vars
if
not
(
isinstance
(
start
,
Constant
)
and
isinstance
(
stop
,
Constant
)):
return
None
x
,
y
=
conv
.
owner
.
inputs
len_x
=
x
.
type
.
shape
[
-
1
]
len_y
=
y
.
type
.
shape
[
-
1
]
if
len_x
is
None
or
len_y
is
None
:
return
None
start
,
stop
=
start
.
data
,
stop
.
data
if
len_x
<
len_y
:
# Convolution is symmetric with input order
x
,
y
=
y
,
x
len_x
,
len_y
=
len_y
,
len_x
if
(
start
==
len_y
-
1
# equivalent to stop = conv.shape[-1] - len_y + 1
and
stop
==
start
+
(
len_x
-
len_y
)
+
1
):
new_conv
=
convolve1d
(
x
,
y
,
mode
=
"valid"
)
copy_stack_trace
(
conv
,
new_conv
)
if
other_idx_vars
:
# If there were more than just empty slices besides the last one
new_indices
=
indices_from_subtensor
(
idx_list
[:
-
1
],
other_idx_vars
)
new_conv
=
new_conv
[
new_indices
]
copy_stack_trace
(
node
.
out
,
new_conv
)
return
[
new_conv
]
pytensor/tensor/signal/conv.py
浏览文件 @
1d82fb46
from
typing
import
TYPE_CHECKING
,
Literal
,
cast
import
numpy
as
np
from
numpy
import
convolve
as
numpy_convolve
from
pytensor.graph
import
Apply
from
pytensor.gradient
import
DisconnectedType
from
pytensor.graph
import
Apply
,
Constant
from
pytensor.link.c.op
import
COp
from
pytensor.scalar
import
as_scalar
from
pytensor.scalar.basic
import
upcast
from
pytensor.tensor.basic
import
as_tensor_variable
,
join
,
zeros
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.math
import
maximum
,
minimum
from
pytensor.tensor.math
import
maximum
,
minimum
,
switch
from
pytensor.tensor.type
import
vector
from
pytensor.tensor.variable
import
TensorVariable
...
...
@@ -17,92 +20,83 @@ if TYPE_CHECKING:
class
Convolve1d
(
COp
):
__props__
=
(
"mode"
,
)
gufunc_signature
=
"(n),(k)->(o)"
__props__
=
()
gufunc_signature
=
"(n),(k)
,()
->(o)"
def
__init__
(
self
,
mode
:
Literal
[
"full"
,
"valid"
]
=
"full"
):
if
mode
not
in
(
"full"
,
"valid"
):
raise
ValueError
(
f
"Invalid mode: {mode}"
)
self
.
mode
=
mode
def
make_node
(
self
,
in1
,
in2
):
def
make_node
(
self
,
in1
,
in2
,
full_mode
):
in1
=
as_tensor_variable
(
in1
)
in2
=
as_tensor_variable
(
in2
)
full_mode
=
as_scalar
(
full_mode
)
assert
in1
.
ndim
==
1
assert
in2
.
ndim
==
1
if
not
(
in1
.
ndim
==
1
and
in2
.
ndim
==
1
):
raise
ValueError
(
"Convolution inputs must be vector (ndim=1)"
)
if
not
full_mode
.
dtype
==
"bool"
:
raise
ValueError
(
"Convolution mode must be a boolean type"
)
dtype
=
upcast
(
in1
.
dtype
,
in2
.
dtype
)
n
=
in1
.
type
.
shape
[
0
]
k
=
in2
.
type
.
shape
[
0
]
match
full_mode
:
case
Constant
():
static_mode
=
"full"
if
full_mode
.
data
else
"valid"
case
_
:
static_mode
=
None
if
n
is
None
or
k
is
None
:
if
n
is
None
or
k
is
None
or
static_mode
is
None
:
out_shape
=
(
None
,)
elif
self
.
mode
==
"full"
:
elif
static_mode
==
"full"
:
out_shape
=
(
n
+
k
-
1
,)
else
:
# mode == "valid":
out_shape
=
(
max
(
n
,
k
)
-
min
(
n
,
k
)
+
1
,)
out
=
vector
(
dtype
=
dtype
,
shape
=
out_shape
)
return
Apply
(
self
,
[
in1
,
in2
],
[
out
])
return
Apply
(
self
,
[
in1
,
in2
,
full_mode
],
[
out
])
def
perform
(
self
,
node
,
inputs
,
outputs
):
# We use numpy_convolve because that is what scipy would use when method="direct" is passed
# and mode != "same" (a mode this Op doesn't cover anyway).
outputs
[
0
][
0
]
=
numpy_convolve
(
*
inputs
,
mode
=
self
.
mode
)
in1
,
in2
,
full_mode
=
inputs
outputs
[
0
][
0
]
=
numpy_convolve
(
in1
,
in2
,
mode
=
"full"
if
full_mode
else
"valid"
)
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
in1_shape
,
in2_shape
=
shapes
_
,
_
,
full_mode
=
node
.
inputs
in1_shape
,
in2_shape
,
_
=
shapes
n
=
in1_shape
[
0
]
k
=
in2_shape
[
0
]
if
self
.
mode
==
"full"
:
shape
=
n
+
k
-
1
else
:
# mode == "valid":
shape
=
maximum
(
n
,
k
)
-
minimum
(
n
,
k
)
+
1
shape_valid
=
maximum
(
n
,
k
)
-
minimum
(
n
,
k
)
+
1
shape_full
=
n
+
k
-
1
shape
=
switch
(
full_mode
,
shape_full
,
shape_valid
)
return
[[
shape
]]
def
connection_pattern
(
self
,
node
):
return
[[
True
],
[
True
],
[
False
]]
def
L_op
(
self
,
inputs
,
outputs
,
output_grads
):
in1
,
in2
=
inputs
in1
,
in2
,
full_mode
=
inputs
[
grad
]
=
output_grads
if
self
.
mode
==
"full"
:
valid_conv
=
type
(
self
)(
mode
=
"valid"
)
in1_bar
=
valid_conv
(
grad
,
in2
[::
-
1
])
in2_bar
=
valid_conv
(
grad
,
in1
[::
-
1
])
else
:
# mode == "valid":
full_conv
=
type
(
self
)(
mode
=
"full"
)
n
=
in1
.
shape
[
0
]
k
=
in2
.
shape
[
0
]
kmn
=
maximum
(
0
,
k
-
n
)
nmk
=
maximum
(
0
,
n
-
k
)
# We need mode="full" if k >= n else "valid" for `in1_bar` (opposite for `in2_bar`), but mode is not symbolic.
# Instead, we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter.
# There is a rewrite that optimizes this case when n, k are static
in1_bar
=
full_conv
(
grad
,
in2
[::
-
1
])
in1_bar
=
in1_bar
[
kmn
:
in1_bar
.
shape
[
0
]
-
kmn
]
in2_bar
=
full_conv
(
grad
,
in1
[::
-
1
])
in2_bar
=
in2_bar
[
nmk
:
in2_bar
.
shape
[
0
]
-
nmk
]
return
[
in1_bar
,
in2_bar
]
# If mode is "full", or mode is "valid" and k >= n, then in1_bar mode should use "valid" convolve
# The expression below is equivalent to ~(full_mode | (k >= n))
full_mode_in1_bar
=
~
full_mode
&
(
k
<
n
)
# If mode is "full", or mode is "valid" and n >= k, then in2_bar mode should use "valid" convolve
# The expression below is equivalent to ~(full_mode | (n >= k))
full_mode_in2_bar
=
~
full_mode
&
(
n
<
k
)
return
[
self
(
grad
,
in2
[::
-
1
],
full_mode_in1_bar
),
self
(
grad
,
in1
[::
-
1
],
full_mode_in2_bar
),
DisconnectedType
()(),
]
def
c_code_cache_version
(
self
):
return
(
1
,)
return
(
2
,)
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
# raise NotImplementedError()
in1
,
in2
=
inputs
in1
,
in2
,
full_mode
=
inputs
[
out
]
=
outputs
mode_str
=
self
.
mode
if
mode_str
==
"full"
:
np_mode_val
=
2
# NPY_CONVOLVE_FULL
elif
mode_str
==
"valid"
:
np_mode_val
=
0
# NPY_CONVOLVE_VALID
else
:
# This case should ideally be prevented by __init__ or make_node
raise
ValueError
(
f
"Unsupported mode {mode_str}"
)
code
=
f
"""
{{
...
...
@@ -158,7 +152,7 @@ class Convolve1d(COp):
// TODO: Use lower level implementation that allows reusing the output buffer
Py_XDECREF({out});
{out} = (PyArrayObject*) PyArray_Correlate2((PyObject*){in1}, (PyObject*)in2_flipped_view, {
np_mode_val}
);
{out} = (PyArrayObject*) PyArray_Correlate2((PyObject*){in1}, (PyObject*)in2_flipped_view, {
full_mode} ? 2 : 0
);
Py_XDECREF(in2_flipped_view); // Clean up the view if correlate fails
if (!{out}) {{
// PyArray_Correlate already set an error
...
...
@@ -169,6 +163,9 @@ class Convolve1d(COp):
return
code
blockwise_convolve_1d
=
Blockwise
(
Convolve1d
())
def
convolve1d
(
in1
:
"TensorLike"
,
in2
:
"TensorLike"
,
...
...
@@ -212,4 +209,5 @@ def convolve1d(
)
mode
=
"valid"
return
cast
(
TensorVariable
,
Blockwise
(
Convolve1d
(
mode
=
mode
))(
in1
,
in2
))
full_mode
=
as_scalar
(
np
.
bool_
(
mode
==
"full"
))
return
cast
(
TensorVariable
,
blockwise_convolve_1d
(
in1
,
in2
,
full_mode
))
tests/link/numba/signal/test_conv.py
浏览文件 @
1d82fb46
...
...
@@ -7,6 +7,7 @@ from pytensor import function
from
pytensor.tensor
import
dmatrix
,
tensor
from
pytensor.tensor.signal
import
convolve1d
from
tests.link.numba.test_basic
import
compare_numba_and_py
from
tests.tensor.signal.test_conv
import
convolve1d_grad_benchmarker
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
...
...
@@ -31,15 +32,8 @@ def test_convolve1d(x_smaller, mode):
@pytest.mark.parametrize
(
"mode"
,
(
"full"
,
"valid"
),
ids
=
lambda
x
:
f
"mode={x}"
)
@pytest.mark.parametrize
(
"batch"
,
(
False
,
True
),
ids
=
lambda
x
:
f
"batch={x}"
)
def
test_convolve1d_benchmark
(
batch
,
mode
,
benchmark
):
x
=
tensor
(
shape
=
(
7
,
183
,
)
if
batch
else
(
183
,)
)
def
test_convolve1d_benchmark_numba
(
batch
,
mode
,
benchmark
):
x
=
tensor
(
shape
=
(
7
,
183
)
if
batch
else
(
183
,))
y
=
tensor
(
shape
=
(
7
,
6
)
if
batch
else
(
6
,))
out
=
convolve1d
(
x
,
y
,
mode
=
mode
)
fn
=
function
([
x
,
y
],
out
,
mode
=
"NUMBA"
,
trust_input
=
True
)
...
...
@@ -57,3 +51,8 @@ def test_convolve1d_benchmark(batch, mode, benchmark):
np_convolve1d
(
x_test
,
y_test
),
)
benchmark
(
fn
,
x_test
,
y_test
)
@pytest.mark.parametrize
(
"convolve_mode"
,
[
"full"
,
"valid"
])
def
test_convolve1d_grad_benchmark_numba
(
convolve_mode
,
benchmark
):
convolve1d_grad_benchmarker
(
convolve_mode
,
"NUMBA"
,
benchmark
)
tests/tensor/signal/test_conv.py
浏览文件 @
1d82fb46
...
...
@@ -7,7 +7,7 @@ from scipy.signal import convolve as scipy_convolve
from
pytensor
import
config
,
function
,
grad
from
pytensor.graph.basic
import
ancestors
,
io_toposort
from
pytensor.graph.rewriting
import
rewrite_graph
from
pytensor.tensor
import
matrix
,
vector
from
pytensor.tensor
import
matrix
,
tensor
,
vector
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.signal.conv
import
Convolve1d
,
convolve1d
from
tests
import
unittest_tools
as
utt
...
...
@@ -86,11 +86,8 @@ def test_convolve1d_batch_graph(mode):
@pytest.mark.parametrize
(
"static_shape"
,
[
False
,
True
])
def
test_convolve1d_valid_grad_rewrite
(
static_shape
):
"""Test that we don't do a useless full convolve1d when taking the gradient of a valid convolve wrt to the smallest input.
This can only be achieved when the two inputs have static shapes, so we know which one is larger
"""
def
test_convolve1d_valid_grad
(
static_shape
):
"""Test we don't do a full convolve in the gradient of the smaller input to a valid convolve."""
larger
=
vector
(
"larger"
,
shape
=
(
128
if
static_shape
else
None
,))
smaller
=
vector
(
"smaller"
,
shape
=
(
64
if
static_shape
else
None
,))
out
=
convolve1d
(
larger
,
smaller
,
mode
=
"valid"
)
...
...
@@ -103,9 +100,40 @@ def test_convolve1d_valid_grad_rewrite(static_shape):
"local_useless_unbatched_blockwise"
,
),
)
[
conv_
op
]
=
[
node
.
op
[
conv_
node
]
=
[
node
for
node
in
io_toposort
([
larger
,
smaller
],
[
grad_out
])
if
isinstance
(
node
.
op
,
Convolve1d
)
]
assert
conv_op
.
mode
==
(
"valid"
if
static_shape
else
"full"
)
full_mode
=
conv_node
.
inputs
[
-
1
]
# If shape is static we get constant mode == "valid", otherwise it depends on the input shapes
# ignoring E712 because np.True_ and np.False_ need to be compared with `==` to produce a valid boolean
if
static_shape
:
assert
full_mode
.
eval
()
==
False
# noqa: E712
else
:
dtype
=
larger
.
dtype
larger_test
=
np
.
zeros
((
128
,),
dtype
=
dtype
)
smaller_test
=
np
.
zeros
((
64
,),
dtype
=
dtype
)
assert
full_mode
.
eval
({
larger
:
larger_test
,
smaller
:
smaller_test
})
==
False
# noqa: E712
assert
full_mode
.
eval
({
larger
:
smaller_test
,
smaller
:
larger_test
})
==
True
# noqa: E712
def
convolve1d_grad_benchmarker
(
convolve_mode
,
mode
,
benchmark
):
# Use None core shape so PyTensor doesn't know which mode to use until runtime.
larger
=
tensor
(
"larger"
,
shape
=
(
8
,
None
))
smaller
=
tensor
(
"smaller"
,
shape
=
(
8
,
None
))
grad_wrt_smaller
=
grad
(
convolve1d
(
larger
,
smaller
,
mode
=
convolve_mode
)
.
sum
(),
wrt
=
smaller
)
fn
=
function
([
larger
,
smaller
],
grad_wrt_smaller
,
trust_input
=
True
,
mode
=
mode
)
rng
=
np
.
random
.
default_rng
([
119
,
mode
==
"full"
])
test_larger
=
rng
.
normal
(
size
=
(
8
,
1024
))
.
astype
(
larger
.
type
.
dtype
)
test_smaller
=
rng
.
normal
(
size
=
(
8
,
16
))
.
astype
(
smaller
.
type
.
dtype
)
benchmark
(
fn
,
test_larger
,
test_smaller
)
@pytest.mark.parametrize
(
"convolve_mode"
,
[
"full"
,
"valid"
])
def
test_convolve1d_grad_benchmark_c
(
convolve_mode
,
benchmark
):
convolve1d_grad_benchmarker
(
convolve_mode
,
"FAST_RUN"
,
benchmark
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论