Commit 2ce0ce13
Authored Sep 02, 2025 by Ricardo Vieira
Committed by Ricardo Vieira on Sep 20, 2025
Note failing scan rewrite
Parent: 9a124cac
Showing 2 changed files with 24 additions and 5 deletions
pytensor/scan/rewriting.py +13 -4
tests/scan/test_rewriting.py +11 -1
pytensor/scan/rewriting.py
@@ -658,10 +658,9 @@ def inner_sitsot_only_last_step_used(
     fgraph: FunctionGraph, var: Variable, scan_args: ScanArgs
 ) -> bool:
     """
-    Given a inner nit-sot output of `Scan`, return ``True`` iff the outer
-    nit-sot output has only one client and that client is a `Subtensor`
-    instance that takes only the last step (last element along the first
-    axis).
+    Given a inner sit-sot output of `Scan`, return ``True`` iff the outer
+    sit-sot output has only one client and that client is a `Subtensor`
+    instance that takes only the last step (last element along the first axis).
     """
     idx = scan_args.inner_out_sit_sot.index(var)
     outer_var = scan_args.outer_out_sit_sot[idx]
@@ -832,6 +831,14 @@ def scan_push_out_add(fgraph, node):
     Like `scan_push_out_seq`, this optimization aims to replace many operations
     on small tensors by few operations on large tensors. It can also lead to
     increased memory usage.
+
+    FIXME: This rewrite doesn't cover user defined graphs,
+    since it doesn't account for the intermediate slice
+    returned by the scan constructor for sit-sot (i.e., something like output[1:]).
+    It only looks for `outputs[-1]` but the user will only ever write `outputs[1:][-1]`
+    The relevant helper function is `inner_sitsot_only_last_step_used` which is only used by this rewrite
+    Note this rewrite is registered before subtensor_merge, but even if it were after subtensor_merge is a mess
+    and doesn't simplify to x[1:][-1] to x[-1] unless x length is statically known
     """
     # Don't perform the optimization on `as_while` `Scan`s. Because these
     # `Scan`s don't run for a predetermined number of steps, handling them is
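The FIXME above notes that `x[1:][-1]` is not simplified to `x[-1]` unless the length of `x` is statically known. A minimal sketch of why the two forms are not interchangeable in general, using plain Python lists as a stand-in for the sit-sot buffer (this is an analogy, not PyTensor code; the function names are hypothetical):

```python
def last_after_drop_first(x):
    """What a user effectively writes for a sit-sot output: x[1:][-1]."""
    return x[1:][-1]


def last_step(x):
    """The pattern the FIXME says the rewrite looks for: x[-1]."""
    return x[-1]


# Buffer holding the initial state followed by two steps:
# both forms select the same element.
buf = [0, 10, 20]
same_when_long = last_after_drop_first(buf) == last_step(buf)

# Buffer holding only the initial state (zero steps):
# x[-1] still returns the initial state, while x[1:][-1] indexes an empty list.
only_init = [0]
plain_last = last_step(only_init)
try:
    last_after_drop_first(only_init)
    sliced_raises = False
except IndexError:
    sliced_raises = True
```

The two expressions agree only when the buffer has at least two entries, which is the kind of length information a merge of the two subtensors would need.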
@@ -857,6 +864,7 @@ def scan_push_out_add(fgraph, node):
                 isinstance(nd.op, Elemwise)
                 and isinstance(nd.op.scalar_op, ps.Add)
                 and nd.out in args.inner_out_sit_sot
+                # FIXME: This function doesn't handle `sitsot_out[1:][-1]` pattern
                 and inner_sitsot_only_last_step_used(fgraph, nd.out, args)
             ):
                 # Ensure that one of the input to the add is the output of
@@ -920,6 +928,7 @@ def scan_push_out_add(fgraph, node):
                     # external Dot instead of the output of scan
                     # Modify the outer graph to add the outer Dot
                     outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx]
+                    # TODO: If we fix the FIXME above, we have to make sure we replace the last subtensor, not the immediate one
                     subtensor_node = fgraph.clients[outer_sitsot][0][0]
                     outer_sitsot_last_step = subtensor_node.outputs[0]
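The gap described in the FIXME and TODO comments can be illustrated with a toy expression tree. This is a sketch under stated assumptions: `Index`, `is_direct_last_step`, and `is_last_step_through_slice` are hypothetical names invented for illustration and are not part of PyTensor; the second matcher is only one possible way the check might be extended.

```python
from dataclasses import dataclass


@dataclass
class Index:
    """Toy stand-in for a Subtensor node representing `base[idx]`."""

    base: object
    idx: object  # an int or a slice


def is_direct_last_step(expr, scan_out):
    """Match only `scan_out[-1]`, mirroring the limitation noted in the FIXME."""
    return isinstance(expr, Index) and expr.base is scan_out and expr.idx == -1


def is_last_step_through_slice(expr, scan_out):
    """Hypothetical extension: also accept `scan_out[1:][-1]`."""
    if is_direct_last_step(expr, scan_out):
        return True
    return (
        isinstance(expr, Index)
        and expr.idx == -1
        and isinstance(expr.base, Index)
        and expr.base.base is scan_out
        and expr.base.idx == slice(1, None)
    )


scan_out = object()  # stands for the raw outer sit-sot buffer
hack_expr = Index(scan_out, -1)  # like `S.owner.inputs[0][-1]` in the test
user_expr = Index(Index(scan_out, slice(1, None)), -1)  # like a user's `s[-1]`

direct_matches_hack = is_direct_last_step(hack_expr, scan_out)
direct_matches_user = is_direct_last_step(user_expr, scan_out)
extended_matches_user = is_last_step_through_slice(user_expr, scan_out)
```

In this toy model the direct matcher accepts only the hand-built expression, and the extended matcher has to look through two index nodes. That is also why the TODO warns that a fixed rewrite would need to replace the outermost `[-1]` node rather than the immediate client of the buffer.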
tests/scan/test_rewriting.py
@@ -600,10 +600,12 @@ class TestPushOutAddScan:
     is used to compute the sum over the dot products between the corresponding
     elements of two list of matrices.
 
-    TODO FIXME XXX: These aren't real tests; they simply confirm that a few
+    FIXME: These aren't real tests; they simply confirm that a few
     graph that could be relevant to the push-out optimizations can be compiled
     and evaluated. None of them confirm that a push-out optimization has been
     performed.
+
+    FIXME: The rewrite is indeed broken, probably fro a long while, see FIXME details in the respective rewrite
     """
 
     def test_sum_dot(self):
@@ -614,7 +616,15 @@ class TestPushOutAddScan:
             sequences=[A.dimshuffle(0, 1, "x"), B.dimshuffle(0, "x", 1)],
             outputs_info=[pt.zeros_like(A)],
         )
+        # FIXME: This `s.owner.inputs[0][-1]` is a hack, users will never do that.
+        # They will do `s[-1]` which the rewrite fails to identify since it explicitly looks for a `scan_out[-1]`
+        # instead of `scan_out[1:][-1]` that the user would define by writing `s[-1]`
+        # It however, tests the only case the rewrite supports now
         f = function([A, B], S.owner.inputs[0][-1])
+        has_scan = any(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes)
+        # Rewrite is only triggered in fast_run mode
+        assert has_scan if (config.mode == "FAST_COMPILE") else (not has_scan)
+
         rng = np.random.default_rng(utt.fetch_seed())
         vA = rng.uniform(size=(5, 5)).astype(config.floatX)
         vB = rng.uniform(size=(5, 5)).astype(config.floatX)
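The new assertion expects the `Scan` node to disappear outside `FAST_COMPILE`, which is possible because the accumulated sum of per-step products equals a single large product. A minimal pure-Python sketch of that identity, assuming the elided step function adds the product of a column slice of `A` and a row slice of `B` to the accumulator (the helper names below are hypothetical, and small integer inputs replace the random 5x5 matrices):

```python
def outer(u, v):
    """Outer product of two vectors given as lists."""
    return [[ui * vj for vj in v] for ui in u]


def mat_add(X, Y):
    """Elementwise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def transpose(X):
    return [list(col) for col in zip(*X)]


def matmul(X, Y):
    """Plain matrix product of nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]


A = [[1, 2], [3, 4], [5, 6]]  # three steps, each a length-2 vector
B = [[7, 8], [9, 10], [11, 12]]

# Step-by-step accumulation, as the loop would compute it.
acc = [[0, 0], [0, 0]]
for a_i, b_i in zip(A, B):
    acc = mat_add(acc, outer(a_i, b_i))

# One large product, as computed once the add is pushed out of the loop.
pushed_out = matmul(transpose(A), B)
```

Both `acc` and `pushed_out` come out as `[[89, 98], [116, 128]]`, so when only the last accumulated step is requested the loop can be replaced by one product over the full sequences.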