Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
4948903d
提交
4948903d
authored
5月 31, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 31, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Harmonize Scan rewrite and tag names
上级
20b6a20c
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
24 行增加
和
20 行删除
+24
-20
scan.py
pytensor/link/numba/dispatch/scan.py
+1
-1
rewriting.py
pytensor/scan/rewriting.py
+20
-16
test_rewriting.py
tests/scan/test_rewriting.py
+3
-3
没有找到文件。
pytensor/link/numba/dispatch/scan.py
浏览文件 @
4948903d
...
...
@@ -184,7 +184,7 @@ def numba_funcify_Scan(op, node, **kwargs):
# rotation for initially truncated storage.
output_storage_post_proc_stmts
:
list
[
str
]
=
[]
# In truncated storage situations (e.g. created by `s
ave_mem_new_scan
`),
# In truncated storage situations (e.g. created by `s
can_save_mem
`),
# the taps and output storage overlap, instead of the standard situation in
# which the output storage is large enough to contain both the initial taps
# values and the output storage. In this truncated case, we use the
...
...
pytensor/scan/rewriting.py
浏览文件 @
4948903d
...
...
@@ -209,7 +209,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
@node_rewriter
([
Scan
])
def
push_out_non_seq_scan
(
fgraph
,
node
):
def
scan_push_out_non_seq
(
fgraph
,
node
):
r"""Push out the variables inside the `Scan` that depend only on non-sequences.
This optimizations pushes, out of `Scan`'s inner function and into the outer
...
...
@@ -417,10 +417,10 @@ def push_out_non_seq_scan(fgraph, node):
@node_rewriter
([
Scan
])
def
push_out_seq_scan
(
fgraph
,
node
):
def
scan_push_out_seq
(
fgraph
,
node
):
r"""Push out the variables inside the `Scan` that depend only on constants and sequences.
This optimization resembles `
push_out_non_seq_scan
` but it tries to push--out of
This optimization resembles `
scan_push_out_non_seq
` but it tries to push--out of
the inner function--the computation that only relies on sequence and
non-sequence inputs. The idea behind this optimization is that, when it is
possible to do so, it is generally more computationally efficient to perform
...
...
@@ -822,10 +822,10 @@ def add_nitsot_outputs(
@node_rewriter
([
Scan
])
def
push_out_add_scan
(
fgraph
,
node
):
def
scan_push_out_add
(
fgraph
,
node
):
r"""Push `Add` operations performed at the end of the inner graph to the outside.
Like `
push_out_seq_scan
`, this optimization aims to replace many operations
Like `
scan_push_out_seq
`, this optimization aims to replace many operations
on small tensors by few operations on large tensors. It can also lead to
increased memory usage.
"""
...
...
@@ -1185,7 +1185,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
@node_rewriter
([
Scan
])
def
s
ave_mem_new_scan
(
fgraph
,
node
):
def
s
can_save_mem
(
fgraph
,
node
):
r"""Graph optimizer that reduces scan memory consumption.
This optimizations attempts to determine if a `Scan` node, during its execution,
...
...
@@ -2282,7 +2282,7 @@ def scan_merge_inouts(fgraph, node):
@node_rewriter
([
Scan
])
def
push_out_dot1_scan
(
fgraph
,
node
):
def
scan_push_out_dot1
(
fgraph
,
node
):
r"""
This is another optimization that attempts to detect certain patterns of
computation in a `Scan` `Op`'s inner function and move this computation to the
...
...
@@ -2483,7 +2483,7 @@ optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6)
# ScanSaveMem should execute only once per node.
optdb
.
register
(
"scan_save_mem"
,
in2out
(
s
ave_mem_new_scan
,
ignore_newtrees
=
True
),
in2out
(
s
can_save_mem
,
ignore_newtrees
=
True
),
"fast_run"
,
"scan"
,
position
=
1.61
,
...
...
@@ -2511,8 +2511,9 @@ scan_seqopt1.register(
scan_seqopt1
.
register
(
"scan_pushout_nonseqs_ops"
,
in2out
(
push_out_non_seq_scan
,
ignore_newtrees
=
True
),
"scan_push_out_non_seq"
,
in2out
(
scan_push_out_non_seq
,
ignore_newtrees
=
True
),
"scan_pushout_nonseqs_ops"
,
# For backcompat: so it can be tagged with old name
"fast_run"
,
"scan"
,
"scan_pushout"
,
...
...
@@ -2521,8 +2522,9 @@ scan_seqopt1.register(
scan_seqopt1
.
register
(
"scan_pushout_seqs_ops"
,
in2out
(
push_out_seq_scan
,
ignore_newtrees
=
True
),
"scan_push_out_seq"
,
in2out
(
scan_push_out_seq
,
ignore_newtrees
=
True
),
"scan_pushout_seqs_ops"
,
# For backcompat: so it can be tagged with old name
"fast_run"
,
"scan"
,
"scan_pushout"
,
...
...
@@ -2531,8 +2533,9 @@ scan_seqopt1.register(
scan_seqopt1
.
register
(
"scan_pushout_dot1"
,
in2out
(
push_out_dot1_scan
,
ignore_newtrees
=
True
),
"scan_push_out_dot1"
,
in2out
(
scan_push_out_dot1
,
ignore_newtrees
=
True
),
"scan_pushout_dot1"
,
# For backcompat: so it can be tagged with old name
"fast_run"
,
"more_mem"
,
"scan"
,
...
...
@@ -2542,9 +2545,10 @@ scan_seqopt1.register(
scan_seqopt1
.
register
(
"scan_pushout_add"
,
"scan_push
_
out_add"
,
# TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
in2out
(
push_out_add_scan
,
ignore_newtrees
=
False
),
in2out
(
scan_push_out_add
,
ignore_newtrees
=
False
),
"scan_pushout_add"
,
# For backcompat: so it can be tagged with old name
"fast_run"
,
"more_mem"
,
"scan"
,
...
...
tests/scan/test_rewriting.py
浏览文件 @
4948903d
...
...
@@ -304,7 +304,7 @@ class TestPushOutDot:
class
TestPushOutNonSeqScan
:
"""
Tests for the `
push_out_non_seq_scan
` optimization in the case where the inner
Tests for the `
scan_push_out_non_seq
` optimization in the case where the inner
function of a `Scan` `Op` has an output which is the result of a `Dot` product
on a non-sequence matrix input to `Scan` and a vector that is the result of
computation in the inner function.
...
...
@@ -595,7 +595,7 @@ class TestPushOutNonSeqScan:
class
TestPushOutAddScan
:
"""
Test case for the `
push_out_add_scan
` optimization in the case where the `Scan`
Test case for the `
scan_push_out_add
` optimization in the case where the `Scan`
is used to compute the sum over the dot products between the corresponding
elements of two list of matrices.
...
...
@@ -1208,7 +1208,7 @@ class TestScanInplaceOptimizer:
class
TestSaveMem
:
mode
=
get_default_mode
()
.
including
(
"scan_save_mem"
,
"s
ave_mem_new_scan
"
)
mode
=
get_default_mode
()
.
including
(
"scan_save_mem"
,
"s
can_save_mem
"
)
def
test_save_mem
(
self
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论