Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
4378d482
提交
4378d482
authored
4月 24, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
4月 27, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rewrite sliced full convolutions as valid
These show up in the gradient of Convolve1D
上级
2ada4b66
显示空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
110 行增加
和
3 行删除
+110
-3
__init__.py
pytensor/tensor/rewriting/__init__.py
+1
-0
conv.py
pytensor/tensor/rewriting/conv.py
+78
-0
conv.py
pytensor/tensor/signal/conv.py
+3
-2
test_conv.py
tests/tensor/signal/test_conv.py
+28
-1
没有找到文件。
pytensor/tensor/rewriting/__init__.py
浏览文件 @
4378d482
...
@@ -3,6 +3,7 @@ import pytensor.tensor.rewriting.blas
...
@@ -3,6 +3,7 @@ import pytensor.tensor.rewriting.blas
import
pytensor.tensor.rewriting.blas_c
import
pytensor.tensor.rewriting.blas_c
import
pytensor.tensor.rewriting.blas_scipy
import
pytensor.tensor.rewriting.blas_scipy
import
pytensor.tensor.rewriting.blockwise
import
pytensor.tensor.rewriting.blockwise
import
pytensor.tensor.rewriting.conv
import
pytensor.tensor.rewriting.einsum
import
pytensor.tensor.rewriting.einsum
import
pytensor.tensor.rewriting.elemwise
import
pytensor.tensor.rewriting.elemwise
import
pytensor.tensor.rewriting.extra_ops
import
pytensor.tensor.rewriting.extra_ops
...
...
pytensor/tensor/rewriting/conv.py
0 → 100644
浏览文件 @
4378d482
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
node_rewriter
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.rewriting.basic
import
register_specialize
,
register_stabilize
from
pytensor.tensor.signal
import
convolve1d
from
pytensor.tensor.signal.conv
import
Convolve1d
from
pytensor.tensor.subtensor
import
Subtensor
,
indices_from_subtensor
@register_stabilize
@register_specialize
@node_rewriter([Subtensor])
def local_sliced_full_conv_to_valid_conv(fgraph, node):
    """Replace a constant slice of a "full" convolution by the equivalent "valid" one.

    The gradient of a valid Conv1d always falls back to the worst case - a full
    convolution followed by a slice - because at graph-build time it cannot know
    which input is longer. Once static shapes (or earlier rewrites) reveal the
    lengths, the sliced full convolution can be collapsed into a direct valid
    convolution, which can be orders of magnitude faster.

    # if x.shape[-1] > y.shape[-1]
    # z = convolve1d(x, y, mode="full")
    # z[..., y.shape[-1] - 1: z.shape[-1] - y.shape[-1] - 1] -> convolve1d(x, y, mode="valid")
    """
    conv, *other_idx_vars = node.inputs

    # Only fire on a Blockwise(Convolve1d(mode="full")) producer.
    producer = conv.owner
    if producer is None:
        return None
    if not isinstance(producer.op, Blockwise):
        return None
    core_op = producer.op.core_op
    if not (isinstance(core_op, Convolve1d) and core_op.mode == "full"):
        return None

    # The Subtensor must index every dimension, ending with a (start:stop) slice.
    idx_list = node.op.idx_list
    if len(idx_list) != conv.type.ndim or not isinstance(idx_list[-1], slice):
        return None

    last_slice = idx_list[-1]
    if last_slice.start is None or last_slice.stop is None:
        return None
    if last_slice.step is not None:
        return None

    # The slice bounds must be compile-time constants.
    *other_idx_vars, start, stop = other_idx_vars
    if not isinstance(start, Constant) or not isinstance(stop, Constant):
        return None

    x, y = producer.inputs
    len_x = x.type.shape[-1]
    len_y = y.type.shape[-1]
    # Without static lengths we cannot tell which input is larger.
    if len_x is None or len_y is None:
        return None

    start, stop = start.data, stop.data
    if len_x < len_y:
        # Convolution is symmetric in its inputs, so orient x as the longer one.
        x, y = y, x
        len_x, len_y = len_y, len_x

    # Valid region of a full conv: [len_y - 1, len_y - 1 + (len_x - len_y + 1)),
    # i.e. stop == conv.shape[-1] - len_y + 1.
    if not (start == len_y - 1 and stop == start + (len_x - len_y) + 1):
        return None

    new_conv = convolve1d(x, y, mode="valid")
    copy_stack_trace(conv, new_conv)

    if other_idx_vars:
        # Reapply any leading (non-trivial) indices on the other dimensions.
        new_indices = indices_from_subtensor(idx_list[:-1], other_idx_vars)
        new_conv = new_conv[new_indices]
    copy_stack_trace(node.out, new_conv)

    return [new_conv]
pytensor/tensor/signal/conv.py
浏览文件 @
4378d482
...
@@ -75,13 +75,14 @@ class Convolve1d(Op):
...
@@ -75,13 +75,14 @@ class Convolve1d(Op):
n
=
in1
.
shape
[
0
]
n
=
in1
.
shape
[
0
]
k
=
in2
.
shape
[
0
]
k
=
in2
.
shape
[
0
]
kmn
=
maximum
(
0
,
k
-
n
)
kmn
=
maximum
(
0
,
k
-
n
)
n
km
=
maximum
(
0
,
n
-
k
)
n
mk
=
maximum
(
0
,
n
-
k
)
# We need mode="full" if k >= n else "valid" for `in1_bar` (opposite for `in2_bar`), but mode is not symbolic.
# We need mode="full" if k >= n else "valid" for `in1_bar` (opposite for `in2_bar`), but mode is not symbolic.
# Instead, we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter.
# Instead, we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter.
# There is a rewrite that optimizes this case when n, k are static
in1_bar
=
full_conv
(
grad
,
in2
[::
-
1
])
in1_bar
=
full_conv
(
grad
,
in2
[::
-
1
])
in1_bar
=
in1_bar
[
kmn
:
in1_bar
.
shape
[
0
]
-
kmn
]
in1_bar
=
in1_bar
[
kmn
:
in1_bar
.
shape
[
0
]
-
kmn
]
in2_bar
=
full_conv
(
grad
,
in1
[::
-
1
])
in2_bar
=
full_conv
(
grad
,
in1
[::
-
1
])
in2_bar
=
in2_bar
[
n
km
:
in2_bar
.
shape
[
0
]
-
nkm
]
in2_bar
=
in2_bar
[
n
mk
:
in2_bar
.
shape
[
0
]
-
nmk
]
return
[
in1_bar
,
in2_bar
]
return
[
in1_bar
,
in2_bar
]
...
...
tests/tensor/signal/test_conv.py
浏览文件 @
4378d482
...
@@ -5,7 +5,8 @@ import pytest
...
@@ -5,7 +5,8 @@ import pytest
from
scipy.signal
import
convolve
as
scipy_convolve
from
scipy.signal
import
convolve
as
scipy_convolve
from
pytensor
import
config
,
function
,
grad
from
pytensor
import
config
,
function
,
grad
from
pytensor.graph
import
ancestors
,
rewrite_graph
from
pytensor.graph.basic
import
ancestors
,
io_toposort
from
pytensor.graph.rewriting
import
rewrite_graph
from
pytensor.tensor
import
matrix
,
vector
from
pytensor.tensor
import
matrix
,
vector
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.signal.conv
import
Convolve1d
,
convolve1d
from
pytensor.tensor.signal.conv
import
Convolve1d
,
convolve1d
...
@@ -82,3 +83,29 @@ def test_convolve1d_batch_graph(mode):
...
@@ -82,3 +83,29 @@ def test_convolve1d_batch_graph(mode):
]
]
# Check any Blockwise are just Conv1d
# Check any Blockwise are just Conv1d
assert
all
(
isinstance
(
node
.
op
.
core_op
,
Convolve1d
)
for
node
in
blockwise_nodes
)
assert
all
(
isinstance
(
node
.
op
.
core_op
,
Convolve1d
)
for
node
in
blockwise_nodes
)
@pytest.mark.parametrize("static_shape", [False, True])
def test_convolve1d_valid_grad_rewrite(static_shape):
    """Test that we don't do a useless full convolve1d when taking the gradient of a valid convolve wrt to the smallest input.

    This can only be achieved when the two inputs have static shapes, so we know which one is larger
    """
    big = vector("larger", shape=(128 if static_shape else None,))
    small = vector("smaller", shape=(64 if static_shape else None,))
    valid_conv = convolve1d(big, small, mode="valid")

    # Rewrite the gradient graph with just enough passes for the
    # sliced-full-conv -> valid-conv rewrite to be able to trigger.
    rewritten_grad = rewrite_graph(
        grad(valid_conv.sum(), wrt=small),
        include=(
            "ShapeOpt",
            "canonicalize",
            "stabilize",
            "local_useless_unbatched_blockwise",
        ),
    )

    conv_ops = [
        apply.op
        for apply in io_toposort([big, small], [rewritten_grad])
        if isinstance(apply.op, Convolve1d)
    ]
    # Exactly one convolution should remain in the gradient graph.
    [conv_op] = conv_ops
    # With static shapes the rewrite turns the full conv into a valid one.
    assert conv_op.mode == ("valid" if static_shape else "full")
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论