Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
85235623
提交
85235623
authored
3月 07, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
3月 13, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Do more agressive scan memory saves in JIT backends
上级
92420c8f
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
75 行增加
和
32 行删除
+75
-32
mode.py
pytensor/compile/mode.py
+15
-7
configdefaults.py
pytensor/configdefaults.py
+3
-1
rewriting.py
pytensor/scan/rewriting.py
+50
-15
subtensor.py
pytensor/tensor/subtensor.py
+2
-2
test_scan.py
tests/link/numba/test_scan.py
+2
-3
test_rewriting.py
tests/scan/test_rewriting.py
+3
-4
没有找到文件。
pytensor/compile/mode.py
浏览文件 @
85235623
...
@@ -454,6 +454,19 @@ else:
...
@@ -454,6 +454,19 @@ else:
RewriteDatabaseQuery
(
include
=
[
"fast_run"
,
"py_only"
]),
RewriteDatabaseQuery
(
include
=
[
"fast_run"
,
"py_only"
]),
)
)
NUMBA
=
Mode
(
NumbaLinker
(),
RewriteDatabaseQuery
(
include
=
[
"fast_run"
,
"numba"
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
,
"local_careduce_fusion"
,
"scan_save_mem_prealloc"
,
],
),
)
JAX
=
Mode
(
JAX
=
Mode
(
JAXLinker
(),
JAXLinker
(),
RewriteDatabaseQuery
(
RewriteDatabaseQuery
(
...
@@ -463,6 +476,7 @@ JAX = Mode(
...
@@ -463,6 +476,7 @@ JAX = Mode(
"BlasOpt"
,
"BlasOpt"
,
"fusion"
,
"fusion"
,
"inplace"
,
"inplace"
,
"scan_save_mem_prealloc"
,
],
],
),
),
)
)
...
@@ -476,16 +490,10 @@ PYTORCH = Mode(
...
@@ -476,16 +490,10 @@ PYTORCH = Mode(
"fusion"
,
"fusion"
,
"inplace"
,
"inplace"
,
"local_uint_constant_indices"
,
"local_uint_constant_indices"
,
"scan_save_mem_prealloc"
,
],
],
),
),
)
)
NUMBA
=
Mode
(
NumbaLinker
(),
RewriteDatabaseQuery
(
include
=
[
"fast_run"
,
"numba"
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
,
"local_careduce_fusion"
],
),
)
predefined_modes
=
{
predefined_modes
=
{
...
...
pytensor/configdefaults.py
浏览文件 @
85235623
...
@@ -1085,7 +1085,9 @@ def add_scan_configvars():
...
@@ -1085,7 +1085,9 @@ def add_scan_configvars():
"scan__allow_output_prealloc"
,
"scan__allow_output_prealloc"
,
"Allow/disallow memory preallocation for outputs inside of scan "
"Allow/disallow memory preallocation for outputs inside of scan "
"(default: True)"
,
"(default: True)"
,
BoolParam
(
True
),
# Non-mutable because ScanSaveMem rewrite checks it,
# and we can't have the rewrite and the implementation mismatch
BoolParam
(
True
,
mutable
=
False
),
in_c_key
=
False
,
in_c_key
=
False
,
)
)
...
...
pytensor/scan/rewriting.py
浏览文件 @
85235623
...
@@ -70,7 +70,7 @@ from pytensor.tensor.subtensor import (
...
@@ -70,7 +70,7 @@ from pytensor.tensor.subtensor import (
get_slice_elements
,
get_slice_elements
,
set_subtensor
,
set_subtensor
,
)
)
from
pytensor.tensor.variable
import
TensorConstant
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
list_opt_slice
=
[
list_opt_slice
=
[
...
@@ -1182,8 +1182,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
...
@@ -1182,8 +1182,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
return
subtensor_merge_replacements
return
subtensor_merge_replacements
@node_rewriter
([
Scan
])
def
scan_save_mem_rewrite
(
fgraph
,
node
,
backend_supports_output_pre_allocation
:
bool
):
def
scan_save_mem
(
fgraph
,
node
):
r"""Graph optimizer that reduces scan memory consumption.
r"""Graph optimizer that reduces scan memory consumption.
This optimizations attempts to determine if a `Scan` node, during its execution,
This optimizations attempts to determine if a `Scan` node, during its execution,
...
@@ -1214,10 +1213,16 @@ def scan_save_mem(fgraph, node):
...
@@ -1214,10 +1213,16 @@ def scan_save_mem(fgraph, node):
The scan perform implementation takes the output sizes into consideration,
The scan perform implementation takes the output sizes into consideration,
saving the newest results over the oldest ones whenever the buffer is filled.
saving the newest results over the oldest ones whenever the buffer is filled.
"""
if
not
isinstance
(
node
.
op
,
Scan
):
return
False
Paramaters
----------
backend_supports_output_pre_allocation: bool
When the backend supports output pre-allocation Scan must keep buffers
with a length of required_states + 1, because the inner function will
attempt to write the inner function outputs directly into the provided
position in the outer circular buffer. This would invalidate results,
if the input is still needed for some other output computation.
"""
if
hasattr
(
fgraph
,
"shape_feature"
):
if
hasattr
(
fgraph
,
"shape_feature"
):
shape_of
=
fgraph
.
shape_feature
.
shape_of
shape_of
=
fgraph
.
shape_feature
.
shape_of
else
:
else
:
...
@@ -1270,6 +1275,7 @@ def scan_save_mem(fgraph, node):
...
@@ -1270,6 +1275,7 @@ def scan_save_mem(fgraph, node):
# Note: For simplicity while Scans also have global_nsteps set to None.
# Note: For simplicity while Scans also have global_nsteps set to None.
# All step optimizations require knowing the shape of the output, which
# All step optimizations require knowing the shape of the output, which
# cannot be determined from the inputs alone.
# cannot be determined from the inputs alone.
global_nsteps
:
None
|
dict
assert
len
(
node
.
outputs
)
>=
c_outs
assert
len
(
node
.
outputs
)
>=
c_outs
if
len
(
node
.
outputs
)
==
c_outs
and
not
op
.
info
.
as_while
:
if
len
(
node
.
outputs
)
==
c_outs
and
not
op
.
info
.
as_while
:
global_nsteps
=
{
"real"
:
-
1
,
"sym"
:
[]}
global_nsteps
=
{
"real"
:
-
1
,
"sym"
:
[]}
...
@@ -1277,7 +1283,7 @@ def scan_save_mem(fgraph, node):
...
@@ -1277,7 +1283,7 @@ def scan_save_mem(fgraph, node):
global_nsteps
=
None
global_nsteps
=
None
# Keeps track of the original slices that each client represent
# Keeps track of the original slices that each client represent
slices
=
[
None
for
o
in
node
.
outputs
]
slices
:
list
[
None
|
list
]
=
[
None
for
o
in
node
.
outputs
]
# A list for each output indicating how many intermediate values
# A list for each output indicating how many intermediate values
# should be stored. If negative it means none of the intermediate
# should be stored. If negative it means none of the intermediate
...
@@ -1294,7 +1300,7 @@ def scan_save_mem(fgraph, node):
...
@@ -1294,7 +1300,7 @@ def scan_save_mem(fgraph, node):
# or not
# or not
flag_store
=
False
flag_store
=
False
# 2.2 Loop over the clients
# 2.2 Loop over the clients
to figure out how many steps we actually need to do in the Scan
for
i
,
out
in
enumerate
(
node
.
outputs
[:
c_outs
]):
for
i
,
out
in
enumerate
(
node
.
outputs
[:
c_outs
]):
# look at all its clients
# look at all its clients
slices
[
i
]
=
[]
slices
[
i
]
=
[]
...
@@ -1337,7 +1343,7 @@ def scan_save_mem(fgraph, node):
...
@@ -1337,7 +1343,7 @@ def scan_save_mem(fgraph, node):
except
KeyError
:
except
KeyError
:
length
=
out
.
shape
[
0
]
length
=
out
.
shape
[
0
]
cf_slice
=
get_canonical_form_slice
(
this_slice
[
0
],
length
)
cf_slice
=
get_canonical_form_slice
(
this_slice
[
0
],
length
)
slices
[
i
]
+=
[(
cf_slice
,
this_slice
)]
slices
[
i
]
+=
[(
cf_slice
,
this_slice
)]
# type: ignore
if
isinstance
(
this_slice
[
0
],
slice
)
and
this_slice
[
0
]
.
stop
is
None
:
if
isinstance
(
this_slice
[
0
],
slice
)
and
this_slice
[
0
]
.
stop
is
None
:
global_nsteps
=
None
global_nsteps
=
None
...
@@ -1477,7 +1483,10 @@ def scan_save_mem(fgraph, node):
...
@@ -1477,7 +1483,10 @@ def scan_save_mem(fgraph, node):
# for mitsots and sitsots (because mitmots are not
# for mitsots and sitsots (because mitmots are not
# currently supported by the mechanism) and only if
# currently supported by the mechanism) and only if
# the pre-allocation mechanism is activated.
# the pre-allocation mechanism is activated.
prealloc_outs
=
config
.
scan__allow_output_prealloc
prealloc_outs
=
(
backend_supports_output_pre_allocation
and
config
.
scan__allow_output_prealloc
)
first_mitsot_idx
=
op_info
.
n_mit_mot
first_mitsot_idx
=
op_info
.
n_mit_mot
last_sitsot_idx
=
(
last_sitsot_idx
=
(
...
@@ -1486,6 +1495,8 @@ def scan_save_mem(fgraph, node):
...
@@ -1486,6 +1495,8 @@ def scan_save_mem(fgraph, node):
preallocable_output
=
first_mitsot_idx
<=
i
<=
last_sitsot_idx
preallocable_output
=
first_mitsot_idx
<=
i
<=
last_sitsot_idx
if
prealloc_outs
and
preallocable_output
:
if
prealloc_outs
and
preallocable_output
:
# TODO: If there's only one output or other outputs do not depend
# on the same input, we could reduce the buffer size to the minimum
pval
=
select_max
(
nw_steps
-
start
+
init_l
[
i
],
init_l
[
i
]
+
1
)
pval
=
select_max
(
nw_steps
-
start
+
init_l
[
i
],
init_l
[
i
]
+
1
)
else
:
else
:
pval
=
select_max
(
nw_steps
-
start
+
init_l
[
i
],
init_l
[
i
])
pval
=
select_max
(
nw_steps
-
start
+
init_l
[
i
],
init_l
[
i
])
...
@@ -1652,7 +1663,7 @@ def scan_save_mem(fgraph, node):
...
@@ -1652,7 +1663,7 @@ def scan_save_mem(fgraph, node):
name
=
op
.
name
,
name
=
op
.
name
,
allow_gc
=
op
.
allow_gc
,
allow_gc
=
op
.
allow_gc
,
)
)
new_outs
=
new_op
(
*
node_ins
,
return_list
=
True
)
new_outs
=
cast
(
list
[
TensorVariable
],
new_op
(
*
node_ins
,
return_list
=
True
)
)
old_new
=
[]
old_new
=
[]
# 3.7 Get replace pairs for those outputs that do not change
# 3.7 Get replace pairs for those outputs that do not change
...
@@ -1682,7 +1693,7 @@ def scan_save_mem(fgraph, node):
...
@@ -1682,7 +1693,7 @@ def scan_save_mem(fgraph, node):
sl_ins
=
get_slice_elements
(
sl_ins
=
get_slice_elements
(
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
Variable
)
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
Variable
)
)
)
new_o
=
subtens
(
new_outs
[
nw_pos
],
*
sl_ins
)
new_o
=
cast
(
TensorVariable
,
subtens
(
new_outs
[
nw_pos
],
*
sl_ins
)
)
if
new_o
.
ndim
>
0
:
if
new_o
.
ndim
>
0
:
new_o
=
new_o
[::
cnf_slice
[
1
]]
new_o
=
new_o
[::
cnf_slice
[
1
]]
replaced_outs
.
append
(
idx
)
replaced_outs
.
append
(
idx
)
...
@@ -1737,7 +1748,7 @@ def scan_save_mem(fgraph, node):
...
@@ -1737,7 +1748,7 @@ def scan_save_mem(fgraph, node):
sl_ins
=
get_slice_elements
(
sl_ins
=
get_slice_elements
(
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
Variable
)
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
Variable
)
)
)
new_o
=
subtens
(
new_outs
[
nw_pos
],
*
sl_ins
)
new_o
=
cast
(
TensorVariable
,
subtens
(
new_outs
[
nw_pos
],
*
sl_ins
)
)
if
new_o
.
ndim
>
0
:
if
new_o
.
ndim
>
0
:
new_o
=
new_o
[::
cnf_slice
[
1
]]
new_o
=
new_o
[::
cnf_slice
[
1
]]
old_new
+=
[(
old
,
new_o
)]
old_new
+=
[(
old
,
new_o
)]
...
@@ -1768,6 +1779,20 @@ def scan_save_mem(fgraph, node):
...
@@ -1768,6 +1779,20 @@ def scan_save_mem(fgraph, node):
return
False
return
False
@node_rewriter
([
Scan
])
def
scan_save_mem_prealloc
(
fgraph
,
node
):
return
scan_save_mem_rewrite
(
fgraph
,
node
,
backend_supports_output_pre_allocation
=
True
)
@node_rewriter
([
Scan
])
def
scan_save_mem_no_prealloc
(
fgraph
,
node
):
return
scan_save_mem_rewrite
(
fgraph
,
node
,
backend_supports_output_pre_allocation
=
False
)
class
ScanMerge
(
GraphRewriter
):
class
ScanMerge
(
GraphRewriter
):
r"""Graph optimizer that merges different scan ops.
r"""Graph optimizer that merges different scan ops.
...
@@ -2495,10 +2520,20 @@ optdb.register("scan_eqopt1", scan_eqopt1, "fast_run", "scan", position=0.05)
...
@@ -2495,10 +2520,20 @@ optdb.register("scan_eqopt1", scan_eqopt1, "fast_run", "scan", position=0.05)
optdb
.
register
(
"scan_eqopt2"
,
scan_eqopt2
,
"fast_run"
,
"scan"
,
position
=
1.6
)
optdb
.
register
(
"scan_eqopt2"
,
scan_eqopt2
,
"fast_run"
,
"scan"
,
position
=
1.6
)
# ScanSaveMem should execute only once per node.
# ScanSaveMem should execute only once per node.
optdb
.
register
(
optdb
.
register
(
"scan_save_mem"
,
"scan_save_mem
_prealloc
"
,
in2out
(
scan_save_mem
,
ignore_newtrees
=
True
),
in2out
(
scan_save_mem
_prealloc
,
ignore_newtrees
=
True
),
"fast_run"
,
"fast_run"
,
"scan"
,
"scan"
,
"scan_save_mem"
,
position
=
1.61
,
)
optdb
.
register
(
"scan_save_mem_no_prealloc"
,
in2out
(
scan_save_mem_no_prealloc
,
ignore_newtrees
=
True
),
"numba"
,
"jax"
,
"pytorch"
,
use_db_name_as_tag
=
False
,
position
=
1.61
,
position
=
1.61
,
)
)
optdb
.
register
(
optdb
.
register
(
...
...
pytensor/tensor/subtensor.py
浏览文件 @
85235623
import
logging
import
logging
import
sys
import
sys
import
warnings
import
warnings
from
collections.abc
import
Callable
,
Iterable
from
collections.abc
import
Callable
,
Iterable
,
Sequence
from
itertools
import
chain
,
groupby
from
itertools
import
chain
,
groupby
from
textwrap
import
dedent
from
textwrap
import
dedent
from
typing
import
cast
,
overload
from
typing
import
cast
,
overload
...
@@ -645,7 +645,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
...
@@ -645,7 +645,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
def
get_slice_elements
(
def
get_slice_elements
(
idxs
:
list
,
idxs
:
Sequence
,
cond
:
Callable
=
lambda
x
:
isinstance
(
x
,
Variable
),
cond
:
Callable
=
lambda
x
:
isinstance
(
x
,
Variable
),
)
->
list
:
)
->
list
:
"""Extract slice elements conditional on a given predicate function.
"""Extract slice elements conditional on a given predicate function.
...
...
tests/link/numba/test_scan.py
浏览文件 @
85235623
...
@@ -465,7 +465,7 @@ class TestScanSITSOTBuffer:
...
@@ -465,7 +465,7 @@ class TestScanSITSOTBuffer:
)
)
if
buffer_size
==
"unit"
:
if
buffer_size
==
"unit"
:
xs_kept
=
xs
[
-
1
]
# Only last state is used
xs_kept
=
xs
[
-
1
]
# Only last state is used
expected_buffer_size
=
2
expected_buffer_size
=
1
elif
buffer_size
==
"aligned"
:
elif
buffer_size
==
"aligned"
:
xs_kept
=
xs
[
-
2
:]
# The buffer will be aligned at the end of the 9 steps
xs_kept
=
xs
[
-
2
:]
# The buffer will be aligned at the end of the 9 steps
expected_buffer_size
=
2
expected_buffer_size
=
2
...
@@ -555,8 +555,7 @@ class TestScanMITSOTBuffer:
...
@@ -555,8 +555,7 @@ class TestScanMITSOTBuffer:
accept_inplace
=
True
,
accept_inplace
=
True
,
on_unused_input
=
"ignore"
,
on_unused_input
=
"ignore"
,
)
)
assert
tuple
(
mitsot_buffer_shape
)
==
(
3
,)
assert
tuple
(
mitsot_buffer_shape
)
==
(
2
,)
if
benchmark
is
not
None
:
if
benchmark
is
not
None
:
numba_fn
.
trust_input
=
True
numba_fn
.
trust_input
=
True
benchmark
(
numba_fn
,
*
test_vals
)
benchmark
(
numba_fn
,
*
test_vals
)
...
...
tests/scan/test_rewriting.py
浏览文件 @
85235623
...
@@ -742,7 +742,7 @@ class TestPushOutAddScan:
...
@@ -742,7 +742,7 @@ class TestPushOutAddScan:
utt
.
assert_allclose
(
f_opt_output
,
f_no_opt_output
)
utt
.
assert_allclose
(
f_opt_output
,
f_no_opt_output
)
def
test_non_zero_init
(
self
):
def
test_non_zero_init
(
self
):
"""Test the case where the initial value for the
n
itsot output is non-zero."""
"""Test the case where the initial value for the
s
itsot output is non-zero."""
input1
=
tensor3
()
input1
=
tensor3
()
input2
=
tensor3
()
input2
=
tensor3
()
...
@@ -759,8 +759,7 @@ class TestPushOutAddScan:
...
@@ -759,8 +759,7 @@ class TestPushOutAddScan:
init
=
pt
.
as_tensor_variable
(
np
.
random
.
normal
(
size
=
(
3
,
7
)))
init
=
pt
.
as_tensor_variable
(
np
.
random
.
normal
(
size
=
(
3
,
7
)))
# Compile the function twice, once with the optimization and once
# Compile the function twice, once with the optimization and once without
# without
opt_mode
=
mode
.
including
(
"scan"
)
opt_mode
=
mode
.
including
(
"scan"
)
h
,
_
=
pytensor
.
scan
(
h
,
_
=
pytensor
.
scan
(
inner_fct
,
inner_fct
,
...
@@ -792,7 +791,7 @@ class TestPushOutAddScan:
...
@@ -792,7 +791,7 @@ class TestPushOutAddScan:
output_opt
=
f_opt
(
input1_value
,
input2_value
,
input3_value
)
output_opt
=
f_opt
(
input1_value
,
input2_value
,
input3_value
)
output_no_opt
=
f_no_opt
(
input1_value
,
input2_value
,
input3_value
)
output_no_opt
=
f_no_opt
(
input1_value
,
input2_value
,
input3_value
)
utt
.
assert_allclose
(
output_opt
,
output_no_opt
)
np
.
testing
.
assert_allclose
(
output_opt
,
output_no_opt
)
class
TestScanMerge
:
class
TestScanMerge
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论