Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ebc0de09
提交
ebc0de09
authored
11月 28, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
12月 05, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Validate compatible linker in Scan make_thunk
上级
7523caa4
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
129 行增加
和
45 行删除
+129
-45
op.py
pytensor/scan/op.py
+38
-44
test_basic.py
tests/scan/test_basic.py
+91
-1
没有找到文件。
pytensor/scan/op.py
浏览文件 @
ebc0de09
...
...
@@ -76,6 +76,7 @@ from pytensor.graph.traversal import graph_inputs
from
pytensor.graph.type
import
HasShape
from
pytensor.graph.utils
import
InconsistencyError
,
MissingInputError
from
pytensor.link.c.basic
import
CLinker
from
pytensor.link.vm
import
VMLinker
from
pytensor.printing
import
op_debug_information
from
pytensor.scan.utils
import
ScanProfileStats
,
Validator
,
forced_replace
,
safe_new
from
pytensor.tensor.basic
import
as_tensor_variable
...
...
@@ -884,16 +885,24 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self
.
nit_sot_arg_offset
=
(
self
.
untraced_sit_sot_arg_offset
+
info
.
n_untraced_sit_sot_outs
)
#
XXX
: This doesn't include `info.n_nit_sot`s, so it's really a count
#
Note
: This doesn't include `info.n_nit_sot`s, so it's really a count
# of the number of outputs generated by taps with inputs
self
.
n_outs
=
info
.
n_mit_mot
+
info
.
n_mit_sot
+
info
.
n_sit_sot
self
.
n_tap_outs
=
info
.
n_mit_mot
+
info
.
n_mit_sot
# TODO: These can be moved to thunk/function compilation
(
_
,
self
.
mitmots_preallocated
,
)
=
self
.
_mitmot_preallocations
()
# Python and Cython perform methods provide the array location where a mitmot output should be
# stored to the VM as a symbolic update. This helper variable is used in the perform method for validation
mitmots_preallocated
=
[
False
]
*
info
.
n_mit_mot_outs
if
config
.
scan__allow_output_prealloc
:
for
mitmot_idx
in
range
(
info
.
n_mit_mot
):
for
inp_tap
in
info
.
mit_mot_in_slices
[
mitmot_idx
]:
if
inp_tap
in
info
.
mit_mot_out_slices
[
mitmot_idx
]:
# Figure out the index of the corresponding output
output_idx
=
sum
(
len
(
m
)
for
m
in
info
.
mit_mot_out_slices
[:
mitmot_idx
]
)
+
info
.
mit_mot_out_slices
[
mitmot_idx
]
.
index
(
inp_tap
)
mitmots_preallocated
[
output_idx
]
=
True
self
.
mitmots_preallocated
=
tuple
(
mitmots_preallocated
)
self
.
n_outer_inputs
=
info
.
n_outer_inputs
self
.
n_outer_outputs
=
info
.
n_outer_outputs
...
...
@@ -908,39 +917,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
self
.
_hash_inner_graph
=
hash
(
self
.
_cmodule_key
)
def _mitmot_preallocations(self):
    """Work out which MIT-MOT output taps can be preallocated.

    When ``config.scan__allow_output_prealloc`` is enabled, an output tap
    is selected if the same tap value also appears among the input taps of
    its MIT-MOT (``info.mit_mot_in_slices[k]`` versus
    ``info.mit_mot_out_slices[k]``). When the flag is disabled no output
    tap is selected.

    Returns
    -------
    preallocated_mitmot_outs : list of int
        Sorted flat indices of the selected output taps, counted across
        the output taps of all MIT-MOTs in order.
    mitmots_preallocated : list of bool
        One entry per MIT-MOT output tap (``info.n_mit_mot_outs`` entries);
        ``True`` where the flat index is in ``preallocated_mitmot_outs``.
    """
    # Bind `info` unconditionally: it is needed to size the boolean mask
    # below even when output preallocation is disabled.
    info = self.info

    preallocated_mitmot_outs = []
    if config.scan__allow_output_prealloc:
        # Flat index of the first output tap of the MIT-MOT being visited
        out_offset = 0
        for mitmot_idx in range(info.n_mit_mot):
            out_taps = info.mit_mot_out_slices[mitmot_idx]
            for inp_tap in info.mit_mot_in_slices[mitmot_idx]:
                if inp_tap in out_taps:
                    # Position of the matching output tap in the flat list
                    preallocated_mitmot_outs.append(
                        out_offset + out_taps.index(inp_tap)
                    )
            out_offset += len(out_taps)
        preallocated_mitmot_outs.sort()

    # Boolean mask over every MIT-MOT output tap
    selected = set(preallocated_mitmot_outs)
    mitmots_preallocated = [i in selected for i in range(info.n_mit_mot_outs)]

    return preallocated_mitmot_outs, mitmots_preallocated
def
__setstate__
(
self
,
d
):
self
.
__dict__
.
update
(
d
)
# Ensure that the graph associated with the inner function is valid.
...
...
@@ -1483,11 +1459,26 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
mode_instance
=
get_mode
(
self
.
mode
)
.
clone
(
link_kwargs
=
dict
(
allow_gc
=
self
.
allow_gc
),
message
=
f
"{self.name or 'Scan'} sub profile"
,
)
mode
=
self
.
mode
if
mode
in
(
None
,
"FAST_RUN"
):
mode_instance
=
Mode
(
"cvm"
,
"fast_run"
)
elif
mode
==
"FAST_COMPILE"
:
mode_instance
=
Mode
(
VMLinker
(
use_cloop
=
False
,
c_thunks
=
False
),
"fast_compile"
)
else
:
mode_instance
=
get_mode
(
mode
)
.
clone
(
link_kwargs
=
dict
(
allow_gc
=
self
.
allow_gc
),
message
=
f
"{self.name or 'Scan'} sub profile"
,
)
# Scan python and cython perform relies on the VM being able to set updates for preallocated MIT-MOT,
# which only the VMs produced by VMLinker do
if
any
(
self
.
mitmots_preallocated
)
and
not
isinstance
(
mode_instance
.
linker
,
VMLinker
):
raise
NotImplementedError
(
f
"Python/Cython implementation of Scan with preallocated MIT-MOT outputs requires a VMLinker, got {mode_instance.linker}"
)
self
.
_fn
=
pfunc
(
wrapped_inputs
,
wrapped_outputs
,
...
...
@@ -2007,6 +1998,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
new_var
=
inner_input_storage
[
inner_inp_idx
]
.
storage
[
0
]
if
old_var
is
new_var
:
old_data
=
old_mitmot_input_data
[
mitmot_inp_idx
]
# This check is only valid if the VM performs updates
# Otherwise the output value may remain the same as the input,
# but doesn't mean that it has been setup correctly
same_data
=
new_var
.
data
==
old_data
else
:
same_data
=
False
...
...
tests/scan/test_basic.py
浏览文件 @
ebc0de09
...
...
@@ -34,10 +34,12 @@ from pytensor.graph.replace import vectorize_graph
from
pytensor.graph.rewriting.basic
import
MergeOptimizer
from
pytensor.graph.traversal
import
ancestors
from
pytensor.graph.utils
import
MissingInputError
from
pytensor.link.vm
import
VMLinker
from
pytensor.raise_op
import
assert_op
from
pytensor.scan.basic
import
scan
from
pytensor.scan.op
import
Scan
from
pytensor.scan.op
import
Scan
,
ScanInfo
from
pytensor.scan.utils
import
until
from
pytensor.tensor
import
as_tensor
from
pytensor.tensor.math
import
all
as
pt_all
from
pytensor.tensor.math
import
dot
,
exp
,
mean
,
sigmoid
,
tanh
from
pytensor.tensor.math
import
sum
as
pt_sum
...
...
@@ -4308,3 +4310,91 @@ def test_return_updates_api_change():
with
pytest
.
raises
(
ValueError
,
match
=
err_msg
):
scan
(
lambda
:
{
x
:
x
+
1
},
outputs_info
=
[],
n_steps
=
5
,
return_updates
=
False
)
@pytest.mark.parametrize(
    "scan_mode",
    [
        None,
        "FAST_RUN",
        "FAST_COMPILE",
        Mode("cvm", optimizer=None),
        Mode("vm", optimizer=None),
        Mode("c", optimizer=None),
        Mode("py", optimizer=None),
    ],
)
def test_scan_mode_compatibility(scan_mode):
    """Regression test: Scan used with a non-updating VM must fail loudly.

    Builds a Scan with one sequence and two MIT-MOTs whose output taps are
    preallocated, then compiles it with each parametrized inner mode.
    """
    info = ScanInfo(
        n_seqs=1,
        mit_mot_in_slices=((0, 1), (0, 1)),
        mit_mot_out_slices=((1,), (1,)),
        mit_sot_in_slices=(),
        sit_sot_in_slices=(),
        n_nit_sot=0,
        n_untraced_sit_sot_outs=0,
        n_non_seqs=0,
        as_while=False,
    )

    # Inner graph: one boolean sequence entry and two taps per MIT-MOT
    bool_seq = pt.scalar(dtype="bool")
    mitmot_A0, mitmot_A1, mitmot_B0, mitmot_B1 = [
        pt.matrix(shape=(2, 2)) for _ in range(4)
    ]
    inner_inputs = [bool_seq, mitmot_A0, mitmot_A1, mitmot_B0, mitmot_B1]
    inner_outputs = [
        pt.add(bool_seq + mitmot_A0, mitmot_A1),
        pt.add(bool_seq * mitmot_B0, mitmot_B1),
    ]
    scan_op = Scan(inner_inputs, inner_outputs, info=info, mode=scan_mode)

    # Outer graph inputs, derived from the concrete test values
    n_steps = 5
    tap_axis = (slice(None), None, None)
    numerical_inputs = [
        np.array(n_steps, dtype="int64"),
        np.array([1, 1, 0, 1, 0], dtype="bool"),
        np.zeros(n_steps + 1)[tap_axis] * np.eye(2),
        np.arange(n_steps + 1)[tap_axis] * np.eye(2),
    ]
    tensor_inputs = [
        as_tensor(value, dtype=value.dtype).type() for value in numerical_inputs
    ]
    tensor_outputs = [out.sum() for out in scan_op(*tensor_inputs)]

    no_opt_mode = Mode(linker="py", optimizer=None)

    # Abstract modes should never fail; only a specific, incompatible linker
    # requested by the user should. `get_mode` is only consulted for the latter.
    is_abstract_mode = scan_mode in (None, "FAST_RUN", "FAST_COMPILE")
    expect_failure = not is_abstract_mode and not isinstance(
        get_mode(scan_mode).linker, VMLinker
    )

    if expect_failure:
        # NotImplementedError should only be triggered when we try to compile the function
        with pytest.raises(
            NotImplementedError,
            match="Python/Cython implementation of Scan with preallocated MIT-MOT outputs requires a VMLinker",
        ):
            function(tensor_inputs, tensor_outputs, mode=no_opt_mode)
        return

    fn = function(tensor_inputs, tensor_outputs, mode=no_opt_mode)

    # Check we have the expected Scan in the compiled function
    scan_ops = [
        node.op for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
    ]
    [fn_scan_op] = scan_ops
    assert fn_scan_op.info == info
    assert fn_scan_op.mitmots_preallocated == (True, True)

    # Expected value computed by running correct Scan once
    np.testing.assert_allclose(fn(*numerical_inputs), [44, 38])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论