Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
0b731c27
Unverified
提交
0b731c27
authored
11月 14, 2025
作者:
Tat Chan
提交者:
GitHub
11月 14, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rewrite concatenate([x, x]) as tile (#1714)
上级
ee568260
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
124 行增加
和
30 行删除
+124
-30
basic.py
pytensor/tensor/rewriting/basic.py
+48
-0
test_basic.py
tests/tensor/rewriting/test_basic.py
+76
-11
test_basic.py
tests/tensor/test_basic.py
+0
-19
没有找到文件。
pytensor/tensor/rewriting/basic.py
浏览文件 @
0b731c27
...
@@ -77,6 +77,7 @@ from pytensor.tensor.basic import (
...
@@ -77,6 +77,7 @@ from pytensor.tensor.basic import (
register_infer_shape
,
register_infer_shape
,
switch
,
switch
,
tensor_copy
,
tensor_copy
,
tile
,
zeros
,
zeros
,
zeros_like
,
zeros_like
,
)
)
...
@@ -910,6 +911,53 @@ def local_join_make_vector(fgraph, node):
...
@@ -910,6 +911,53 @@ def local_join_make_vector(fgraph, node):
return
[
ret
]
return
[
ret
]
@register_canonicalize
@node_rewriter([Join])
def local_join_to_repeat(fgraph, node):
    """Rewrite ``Join(axis, x, x, x, ...)`` as ``tile(x, reps)``.

    When the same tensor is concatenated multiple times along an axis,
    replace the Join with a single tile operation, which is more efficient.

    Examples
    --------
    join(0, x, x, x) -> tile(x, (3, 1, 1, ...))
    join(1, x, x)    -> tile(x, (1, 2, 1, ...))
    """
    # Extract axis and the tensors being joined
    axis, *tensors = node.inputs

    # Optimization only applies when the axis is known at rewrite time
    if not isinstance(axis, Constant):
        return None

    # Need at least 2 tensors for a repeat to make sense
    if len(tensors) <= 1:
        return None

    # All joined operands must be the *same* graph variable.  Use identity
    # (`is`) rather than `==`: identity is exactly what the rewrite needs,
    # and it does not depend on how Variable.__eq__ happens to be defined.
    first_tensor = tensors[0]
    if not all(t is first_tensor for t in tensors[1:]):
        return None

    ndim = first_tensor.ndim

    # Normalize a possibly negative constant axis (e.g. -1).  Without this,
    # a negative axis would never match any index in range(ndim) below, so
    # reps would be all ones and the concatenation would be silently dropped.
    axis_val = int(axis.data)
    if not (-ndim <= axis_val < ndim):
        return None
    axis_val %= ndim

    # Build reps to repeat only along the join axis.
    # For shape (a, b, c) joined at axis 1: reps = (1, n_reps, 1),
    # which directly concatenates n_reps copies along axis_val.
    n_reps = len(tensors)
    reps = tuple(n_reps if i == axis_val else 1 for i in range(ndim))
    result = tile(first_tensor, reps)

    # Preserve debugging information from the replaced output
    copy_stack_trace(node.outputs[0], result)
    return [result]
@register_specialize
@register_specialize
@register_canonicalize
@register_canonicalize
@register_useless
@register_useless
...
...
tests/tensor/rewriting/test_basic.py
浏览文件 @
0b731c27
...
@@ -1237,33 +1237,98 @@ def test_local_join_1():
...
@@ -1237,33 +1237,98 @@ def test_local_join_1():
assert
len
([
n
for
n
in
e
if
isinstance
(
n
.
op
,
Join
)])
==
0
assert
len
([
n
for
n
in
e
if
isinstance
(
n
.
op
,
Join
)])
==
0
assert
f
.
maker
.
fgraph
.
outputs
[
0
]
.
dtype
==
config
.
floatX
assert
f
.
maker
.
fgraph
.
outputs
[
0
]
.
dtype
==
config
.
floatX
#
Test we don't apply when there are 2 inputs
#
Test that join with 2 different inputs remains (not optimized away)
s
=
join
(
1
,
a
,
a
)
s
=
join
(
1
,
a
,
a
[:,
::
-
1
]
)
f
=
function
([
a
],
s
,
mode
=
rewrite_mode
)
f
=
function
([
a
],
s
,
mode
=
rewrite_mode
)
val
=
f
([[
1
]])
val
=
f
([[
1
,
2
]])
assert
np
.
all
(
val
==
[[
1
]])
assert
np
.
all
(
val
==
[[
1
,
2
,
2
,
1
]])
# joined along axis 1
e
=
f
.
maker
.
fgraph
.
toposort
()
e
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
([
n
for
n
in
e
if
isinstance
(
n
.
op
,
Join
)])
==
1
assert
len
([
n
for
n
in
e
if
isinstance
(
n
.
op
,
Join
)])
==
1
# join remains
assert
f
.
maker
.
fgraph
.
outputs
[
0
]
.
dtype
==
config
.
floatX
assert
f
.
maker
.
fgraph
.
outputs
[
0
]
.
dtype
==
config
.
floatX
def test_local_join_to_tile():
    """Join(axis, x, x, ...) is rewritten to tile(x, reps) with reps[axis] = k.

    This optimization applies whenever we concatenate the *same* tensor multiple
    times along a given axis. It replaces the Join/concatenate with a Tile op.
    """

    def _join_nodes(compiled):
        # Collect the Join nodes remaining in the compiled graph.
        return [
            node
            for node in compiled.maker.fgraph.toposort()
            if isinstance(node.op, Join)
        ]

    # ---- Case 1: joining same vector along axis 0 ----
    x = vector("x")
    s = join(0, x, x, x)  # (3n,)
    f = function([x], s, mode=rewrite_mode)
    test_val = np.array([1.0, 2.0], dtype=config.floatX)
    result = f(test_val)
    expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX)
    assert np.allclose(result, expected)
    # Join should be optimized away
    assert not _join_nodes(f)

    # ---- Case 2: joining same matrix along axis 0 ----
    a = matrix("a")
    s = join(0, a, a)  # (2m, n)
    f = function([a], s, mode=rewrite_mode)
    test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
    result = f(test_mat)
    expected = np.vstack([test_mat, test_mat])
    assert np.allclose(result, expected)
    assert not _join_nodes(f)

    # ---- Case 3: joining same matrix along axis 1 ----
    s = join(1, a, a, a)  # (m, 3n)
    f = function([a], s, mode=rewrite_mode)
    result = f(test_mat)
    expected = np.hstack([test_mat, test_mat, test_mat])
    assert np.allclose(result, expected)
    assert not _join_nodes(f)

    # ---- Case 4: different tensors -> should NOT optimize ----
    y = vector("y")
    s = join(0, x, y)  # inputs differ
    f = function([x, y], s, mode=rewrite_mode)
    test_vec1 = np.array([1.0, 2.0], dtype=config.floatX)
    test_vec2 = np.array([3.0, 4.0], dtype=config.floatX)
    result = f(test_vec1, test_vec2)
    expected = np.array([1.0, 2.0, 3.0, 4.0], dtype=config.floatX)
    assert np.allclose(result, expected)
    # Join should still be present since inputs aren't identical
    assert _join_nodes(f)
def
test_local_join_empty
():
def
test_local_join_empty
():
# Vector case
# Vector case
- empty tensors should be removed
empty_vec
=
np
.
asarray
([],
dtype
=
config
.
floatX
)
empty_vec
=
np
.
asarray
([],
dtype
=
config
.
floatX
)
vec
=
vector
(
"vec"
)
vec
=
vector
(
"vec"
)
s
=
pt
.
join
(
0
,
vec
,
vec
,
empty_vec
)
s
=
pt
.
join
(
0
,
vec
,
vec
[::
-
1
]
,
empty_vec
)
new_s
=
rewrite_graph
(
s
)
new_s
=
rewrite_graph
(
s
)
assert
equal_computations
([
new_s
],
[
join
(
0
,
vec
,
vec
)])
assert
new_s
.
dtype
==
s
.
dtype
assert
new_s
.
dtype
==
s
.
dtype
# Verify that empty tensors are removed from the join
expected
=
pt
.
join
(
0
,
vec
,
vec
[::
-
1
])
assert
equal_computations
([
new_s
],
[
expected
])
# Matrix case
# Matrix case
- empty tensors should be removed
empty_mat
=
np
.
zeros
((
2
,
0
),
dtype
=
config
.
floatX
)
empty_mat
=
np
.
zeros
((
2
,
0
),
dtype
=
config
.
floatX
)
empty_sym_mat
=
matrix
(
"m"
,
shape
=
(
2
,
0
))
empty_sym_mat
=
matrix
(
"m"
,
shape
=
(
2
,
0
))
mat
=
matrix
(
"mat"
,
shape
=
(
2
,
10
))
mat
=
matrix
(
"mat"
,
shape
=
(
2
,
10
))
s
=
join
(
1
,
empty_mat
,
mat
,
empty_sym_mat
,
mat
,
mat
)
s
=
join
(
1
,
empty_mat
,
mat
,
empty_sym_mat
,
mat
[:,
::
-
1
]
)
new_s
=
rewrite_graph
(
s
)
new_s
=
rewrite_graph
(
s
)
assert
equal_computations
([
new_s
],
[
join
(
1
,
mat
,
mat
,
mat
)])
assert
new_s
.
dtype
==
s
.
dtype
assert
new_s
.
dtype
==
s
.
dtype
# Verify that empty tensors are removed from the join
expected
=
join
(
1
,
mat
,
mat
[:,
::
-
1
])
assert
equal_computations
([
new_s
],
[
expected
])
# Join can be completely removed, but casting and specify_shape are propagated
# Join can be completely removed, but casting and specify_shape are propagated
int_mat
=
matrix
(
"int_mat"
,
dtype
=
int
)
int_mat
=
matrix
(
"int_mat"
,
dtype
=
int
)
...
...
tests/tensor/test_basic.py
浏览文件 @
0b731c27
...
@@ -2020,25 +2020,6 @@ class TestJoinAndSplit:
...
@@ -2020,25 +2020,6 @@ class TestJoinAndSplit:
# This line used to crash.
# This line used to crash.
ptb
.
concatenate
([
x
,
-
u
],
axis
=
2
)
ptb
.
concatenate
([
x
,
-
u
],
axis
=
2
)
def test_concatenate_same(self):
    """Concatenating the same shared variable with itself yields the
    NumPy-equivalent result and keeps a join op in the compiled graph."""
    # Test that we can concatenate the same tensor multiple time.
    # In the past it was broken on the GPU.
    rng = np.random.default_rng(seed=utt.fetch_seed())
    # Shared (3, 4) input; self.shared/self.floatX come from the test
    # harness (backend-specific fixtures).
    T_shared = self.shared(rng.random((3, 4)).astype(self.floatX))
    Tout = ptb.concatenate([T_shared, T_shared])
    f = function([], Tout, mode=self.mode)
    out = f()
    if config.mode != "FAST_COMPILE":
        # NOTE(review): asserts a non-empty list, i.e. at least one node of
        # the backend's join op type survives in the compiled graph —
        # presumably to ensure the concatenation was not rewritten away.
        assert [
            True
            for node in f.maker.fgraph.toposort()
            if isinstance(node.op, type(self.join_op))
        ]
    # Numerical result must match NumPy's concatenate of two copies.
    assert np.allclose(
        out, np.concatenate([T_shared.get_value(), T_shared.get_value()])
    )
def
test_mixed_ndim_error
(
self
):
def
test_mixed_ndim_error
(
self
):
rng
=
np
.
random
.
default_rng
(
seed
=
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
seed
=
utt
.
fetch_seed
())
v
=
self
.
shared
(
rng
.
random
(
4
)
.
astype
(
self
.
floatX
))
v
=
self
.
shared
(
rng
.
random
(
4
)
.
astype
(
self
.
floatX
))
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论