Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c1ecbe0e
提交
c1ecbe0e
authored
5月 08, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 30, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Avoid large allocation for taps of length 1 in ScanSaveMem
上级
f6958407
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
79 行增加
和
29 行删除
+79
-29
rewriting.py
pytensor/scan/rewriting.py
+24
-15
basic.py
pytensor/tensor/rewriting/basic.py
+4
-2
test_rewriting.py
tests/scan/test_rewriting.py
+51
-12
没有找到文件。
pytensor/scan/rewriting.py
浏览文件 @
c1ecbe0e
...
@@ -53,6 +53,7 @@ from pytensor.scan.utils import (
...
@@ -53,6 +53,7 @@ from pytensor.scan.utils import (
from
pytensor.tensor.basic
import
(
from
pytensor.tensor.basic
import
(
Alloc
,
Alloc
,
AllocEmpty
,
AllocEmpty
,
atleast_Nd
,
get_scalar_constant_value
,
get_scalar_constant_value
,
)
)
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
...
@@ -1186,8 +1187,8 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
...
@@ -1186,8 +1187,8 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
return
subtensor_merge_replacements
return
subtensor_merge_replacements
def
_is_default_scan_buffer
(
x
:
TensorVariable
)
->
bool
:
def
_is_default_scan_buffer
(
final_buffer
:
TensorVariable
,
taps
:
int
)
->
bool
:
node
=
x
.
owner
node
=
final_buffer
.
owner
if
node
is
None
:
if
node
is
None
:
return
False
return
False
...
@@ -1200,8 +1201,10 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool:
...
@@ -1200,8 +1201,10 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool:
):
):
return
False
return
False
x
,
y
,
*
_
=
node
.
inputs
init_buffer
,
init_value
,
*
_
=
node
.
inputs
if
not
(
x
.
owner
is
not
None
and
isinstance
(
x
.
owner
.
op
,
AllocEmpty
)):
if
not
(
init_buffer
.
owner
is
not
None
and
isinstance
(
init_buffer
.
owner
.
op
,
AllocEmpty
)
):
return
False
return
False
# The value may have been broadcast to fill in the initial taps.
# The value may have been broadcast to fill in the initial taps.
...
@@ -1218,10 +1221,16 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool:
...
@@ -1218,10 +1221,16 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool:
# 1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable
# 1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable
# But due to laziness we use the slightly more conservative check:
# But due to laziness we use the slightly more conservative check:
# 2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable
# 2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable
if
broadcasted_by
(
y
,
x
):
if
taps
>
1
:
return
False
return
not
broadcasted_by
(
init_value
,
init_buffer
)
else
:
return
True
# In this case we know we have alloc_empty(1 + nsteps, ...)[:1].set(init_value)
# The first dimension cannot possibly broadcast in the subtensor assignment,
# so we exclude it from `broadcasted_by`. To exclude it we squeeze it out,
# after adding any other implicit expand_dims. We select into the first entry of
# the buffer, to check for potential broadcasting in other dimensions.
init_value_
=
atleast_Nd
(
init_value
,
n
=
init_buffer
.
ndim
)
return
not
broadcasted_by
(
init_value_
.
squeeze
(
0
),
init_buffer
[
0
])
def
scan_save_mem_rewrite
(
fgraph
,
node
,
backend_supports_output_pre_allocation
:
bool
):
def
scan_save_mem_rewrite
(
fgraph
,
node
,
backend_supports_output_pre_allocation
:
bool
):
...
@@ -1574,15 +1583,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
...
@@ -1574,15 +1583,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
# If the memory for this output has been pre-allocated
# If the memory for this output has been pre-allocated
# before going into the scan op (by an alloc node)
# before going into the scan op (by an alloc node)
if
idx
<
op_info
.
n_mit_sot
+
op_info
.
n_sit_sot
:
if
idx
<
op_info
.
n_mit_sot
+
op_info
.
n_sit_sot
:
taps
=
init_l
[
i
]
nw_input
=
nw_inputs
[
offset
+
idx
]
nw_input
=
nw_inputs
[
offset
+
idx
]
# Recreate default buffers with new size
# Recreate default buffers with new size
if
_is_default_scan_buffer
(
nw_input
):
if
_is_default_scan_buffer
(
nw_input
,
taps
):
extra_size
=
1
if
required_orphan
else
val
-
init_l
[
i
]
extra_size
=
1
if
required_orphan
else
val
-
taps
nw_input
=
expand_empty
(
nw_input
.
owner
.
inputs
[
1
],
extra_size
)
nw_input
=
expand_empty
(
nw_input
.
owner
.
inputs
[
1
],
extra_size
)
# Otherwise, just trim with a slice
# Otherwise, just trim with a slice
else
:
else
:
stop
=
init_l
[
i
]
if
required_orphan
else
val
stop
=
taps
if
required_orphan
else
val
nw_input
=
nw_input
[:
stop
]
nw_input
=
nw_input
[:
stop
]
nw_inputs
[
offset
+
idx
]
=
nw_input
nw_inputs
[
offset
+
idx
]
=
nw_input
...
@@ -1626,14 +1636,13 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
...
@@ -1626,14 +1636,13 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
# val == 0 means that we want to keep all intermediate
# val == 0 means that we want to keep all intermediate
# results for that state, including the initial values.
# results for that state, including the initial values.
if
idx
<
op_info
.
n_mit_sot
+
op_info
.
n_sit_sot
:
if
idx
<
op_info
.
n_mit_sot
+
op_info
.
n_sit_sot
:
taps
=
init_l
[
op_info
.
n_mit_mot
+
idx
]
in_idx
=
offset
+
idx
in_idx
=
offset
+
idx
nw_input
=
nw_inputs
[
in_idx
]
nw_input
=
nw_inputs
[
in_idx
]
if
_is_default_scan_buffer
(
nw_input
):
if
_is_default_scan_buffer
(
nw_input
,
taps
):
nw_input
=
expand_empty
(
nw_input
.
owner
.
inputs
[
1
],
nw_steps
)
nw_input
=
expand_empty
(
nw_input
.
owner
.
inputs
[
1
],
nw_steps
)
else
:
else
:
# Number of steps in the initial state
nw_input
=
nw_input
[:
(
taps
+
nw_steps
)]
init_l_pt
=
pt
.
as_tensor
(
init_l
[
op_info
.
n_mit_mot
+
idx
])
nw_input
=
nw_input
[:
(
init_l_pt
+
nw_steps
)]
nw_inputs
[
in_idx
]
=
nw_input
nw_inputs
[
in_idx
]
=
nw_input
elif
(
elif
(
...
...
pytensor/tensor/rewriting/basic.py
浏览文件 @
c1ecbe0e
...
@@ -96,9 +96,11 @@ def broadcasted_by(x: TensorVariable, y: TensorVariable) -> bool:
...
@@ -96,9 +96,11 @@ def broadcasted_by(x: TensorVariable, y: TensorVariable) -> bool:
"""
"""
bx
=
x
.
type
.
broadcastable
bx
=
x
.
type
.
broadcastable
by
=
y
.
type
.
broadcastable
by
=
y
.
type
.
broadcastable
if
len
(
bx
)
<
len
(
by
):
bx_len
=
len
(
bx
)
by_len
=
len
(
by
)
if
bx_len
<
by_len
:
return
True
return
True
bx
=
bx
[
-
len
(
by
)
:]
bx
=
bx
[
bx_len
-
by_len
:]
return
any
(
bx_dim
and
not
by_dim
for
bx_dim
,
by_dim
in
zip
(
bx
,
by
,
strict
=
True
))
return
any
(
bx_dim
and
not
by_dim
for
bx_dim
,
by_dim
in
zip
(
bx
,
by
,
strict
=
True
))
...
...
tests/scan/test_rewriting.py
浏览文件 @
c1ecbe0e
...
@@ -9,13 +9,14 @@ from pytensor.compile.io import In
...
@@ -9,13 +9,14 @@ from pytensor.compile.io import In
from
pytensor.compile.mode
import
get_default_mode
from
pytensor.compile.mode
import
get_default_mode
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.gradient
import
grad
,
jacobian
from
pytensor.gradient
import
grad
,
jacobian
from
pytensor.graph.basic
import
Constant
,
equal_computations
from
pytensor.graph.basic
import
Constant
,
ancestors
,
equal_computations
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.replace
import
clone_replace
from
pytensor.graph.replace
import
clone_replace
from
pytensor.scan.op
import
Scan
from
pytensor.scan.op
import
Scan
from
pytensor.scan.rewriting
import
ScanInplaceOptimizer
,
ScanMerge
from
pytensor.scan.rewriting
import
ScanInplaceOptimizer
,
ScanMerge
from
pytensor.scan.utils
import
until
from
pytensor.scan.utils
import
until
from
pytensor.tensor
import
stack
from
pytensor.tensor
import
stack
from
pytensor.tensor.basic
import
AllocEmpty
from
pytensor.tensor.blas
import
Dot22
from
pytensor.tensor.blas
import
Dot22
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.math
import
Dot
,
dot
,
sigmoid
,
tanh
from
pytensor.tensor.math
import
Dot
,
dot
,
sigmoid
,
tanh
...
@@ -1207,7 +1208,7 @@ class TestScanInplaceOptimizer:
...
@@ -1207,7 +1208,7 @@ class TestScanInplaceOptimizer:
class
TestSaveMem
:
class
TestSaveMem
:
mode
=
get_default_mode
()
.
including
(
"scan_save_mem"
)
mode
=
get_default_mode
()
.
including
(
"scan_save_mem"
)
.
excluding
(
"scan_pushout"
)
def
test_save_mem
(
self
):
def
test_save_mem
(
self
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
...
@@ -1371,7 +1372,7 @@ class TestSaveMem:
...
@@ -1371,7 +1372,7 @@ class TestSaveMem:
)
)
def
test_save_mem_store_steps
(
self
):
def
test_save_mem_store_steps
(
self
):
def
f_rnn
(
u_t
,
x1_tm1
,
x1_tm3
,
x2_tm1
,
x3tm2
,
x3_tm1
,
x4_tm1
):
def
step
(
u_t
,
x1_tm1
,
x1_tm3
,
x2_tm1
,
x3tm2
,
x3_tm1
,
x4_tm1
):
return
(
return
(
u_t
+
1.0
,
u_t
+
1.0
,
u_t
+
2.0
,
u_t
+
2.0
,
...
@@ -1388,7 +1389,7 @@ class TestSaveMem:
...
@@ -1388,7 +1389,7 @@ class TestSaveMem:
x30
=
vector
(
"x30"
)
x30
=
vector
(
"x30"
)
x40
=
scalar
(
"x40"
)
x40
=
scalar
(
"x40"
)
[
x1
,
x2
,
x3
,
x4
,
x5
,
x6
,
x7
],
updates
=
scan
(
[
x1
,
x2
,
x3
,
x4
,
x5
,
x6
,
x7
],
updates
=
scan
(
f_rnn
,
step
,
u
,
u
,
[
[
None
,
None
,
...
@@ -1404,7 +1405,7 @@ class TestSaveMem:
...
@@ -1404,7 +1405,7 @@ class TestSaveMem:
go_backwards
=
False
,
go_backwards
=
False
,
)
)
f
2
=
function
(
f
=
function
(
[
u
,
x10
,
x20
,
x30
,
x40
],
[
u
,
x10
,
x20
,
x30
,
x40
],
[
x1
[
-
7
],
x2
[
-
3
:
-
1
],
x3
[
-
6
:],
x4
[
-
1
],
x5
[
-
1
]],
[
x1
[
-
7
],
x2
[
-
3
:
-
1
],
x3
[
-
6
:],
x4
[
-
1
],
x5
[
-
1
]],
updates
=
updates
,
updates
=
updates
,
...
@@ -1417,13 +1418,51 @@ class TestSaveMem:
...
@@ -1417,13 +1418,51 @@ class TestSaveMem:
v_u
=
rng
.
uniform
(
-
5.0
,
5.0
,
size
=
(
20
,))
v_u
=
rng
.
uniform
(
-
5.0
,
5.0
,
size
=
(
20
,))
# compute the output in numpy
# compute the output in numpy
tx1
,
tx2
,
tx3
,
tx4
,
tx5
=
f2
(
v_u
,
[
0
,
0
],
0
,
[
0
,
0
],
0
)
tx1
,
tx2
,
tx3
,
tx4
,
tx5
=
f
(
v_u
,
[
0
,
0
],
0
,
[
0
,
0
],
0
)
rtol
=
1e-7
if
config
.
floatX
==
"float64"
else
1e-6
utt
.
assert_allclose
(
tx1
,
v_u
[
-
7
]
+
1.0
)
np
.
testing
.
assert_allclose
(
tx1
,
v_u
[
-
7
]
+
1.0
,
rtol
=
rtol
)
utt
.
assert_allclose
(
tx2
,
v_u
[
-
3
:
-
1
]
+
2.0
)
np
.
testing
.
assert_allclose
(
tx2
,
v_u
[
-
3
:
-
1
]
+
2.0
,
rtol
=
rtol
)
utt
.
assert_allclose
(
tx3
,
v_u
[
-
6
:]
+
3.0
)
np
.
testing
.
assert_allclose
(
tx3
,
v_u
[
-
6
:]
+
3.0
,
rtol
=
rtol
)
utt
.
assert_allclose
(
tx4
,
v_u
[
-
1
]
+
4.0
)
np
.
testing
.
assert_allclose
(
tx4
,
v_u
[
-
1
]
+
4.0
,
rtol
=
rtol
)
utt
.
assert_allclose
(
tx5
,
v_u
[
-
1
]
+
5.0
)
np
.
testing
.
assert_allclose
(
tx5
,
v_u
[
-
1
]
+
5.0
,
rtol
=
rtol
)
# Confirm reduction in buffer sizes
[
scan_node
]
=
[
node
for
node
in
f
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Scan
)
]
# x6 and x7 are dropped because they are not used
[
n_steps
,
seq
,
x4_buffer
,
x5_buffer
,
x1_len
,
x2_len
,
x3_len
]
=
scan_node
.
inputs
[
x4_underlying_alloc
]
=
[
var
for
var
in
ancestors
([
x4_buffer
])
if
var
.
owner
and
isinstance
(
var
.
owner
.
op
,
AllocEmpty
)
]
[
x5_underlying_alloc
]
=
[
var
for
var
in
ancestors
([
x5_buffer
])
if
var
.
owner
and
isinstance
(
var
.
owner
.
op
,
AllocEmpty
)
]
buffer_lengths
=
pytensor
.
function
(
[
u
,
x10
,
x20
,
x30
,
x40
],
[
x1_len
,
x2_len
,
x3_len
,
x4_underlying_alloc
.
shape
[
0
],
x5_underlying_alloc
.
shape
[
0
],
],
accept_inplace
=
True
,
on_unused_input
=
"ignore"
,
allow_input_downcast
=
True
,
)(
v_u
,
[
0
,
0
],
0
,
[
0
,
0
],
0
)
# ScanSaveMem keeps +1 entries to handle taps with preallocated outputs
assert
[
int
(
i
)
for
i
in
buffer_lengths
]
==
[
7
,
# entry -7 of a map variable is kept, we need at least that many
3
,
# entries [-3, -2] of a map variable are kept, we need at least 3
6
,
# last six entries of a map variable are kept
2
+
1
,
# last entry of a double tap variable is kept
1
+
1
,
# last entry of a single tap variable is kept
]
def
test_savemem_does_not_duplicate_number_of_scan_nodes
(
self
):
def
test_savemem_does_not_duplicate_number_of_scan_nodes
(
self
):
var
=
pt
.
ones
(())
var
=
pt
.
ones
(())
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论