Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
74fb5433
提交
74fb5433
authored
6月 13, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
6月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Prevent unnecessary Scan inplace rewrites
上级
a3dc0a72
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
89 行增加
和
49 行删除
+89
-49
opt.py
aesara/scan/opt.py
+43
-47
test_opt.py
tests/scan/test_opt.py
+46
-2
没有找到文件。
aesara/scan/opt.py
浏览文件 @
74fb5433
...
@@ -4,7 +4,7 @@ import copy
...
@@ -4,7 +4,7 @@ import copy
import
dataclasses
import
dataclasses
from
itertools
import
chain
from
itertools
import
chain
from
sys
import
maxsize
from
sys
import
maxsize
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
cast
import
numpy
as
np
import
numpy
as
np
...
@@ -928,32 +928,32 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -928,32 +928,32 @@ class ScanInplaceOptimizer(GlobalOptimizer):
"""
"""
def
__init__
(
self
,
typeInfer
=
None
):
alloc_ops
=
(
Alloc
,
AllocEmpty
)
super
()
.
__init__
()
"""
self
.
typeInfer
=
typeInfer
Classes that represent operation that allocate new memory and that the
optimization should duplicate so it can operate inplace on them.
"""
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
fgraph
.
attach_feature
(
ReplaceValidate
())
fgraph
.
attach_feature
(
ReplaceValidate
())
fgraph
.
attach_feature
(
DestroyHandler
())
fgraph
.
attach_feature
(
DestroyHandler
())
def
attempt_scan_inplace
(
self
,
fgraph
,
node
,
output_indices
,
alloc_ops
):
def
attempt_scan_inplace
(
self
,
fgraph
:
FunctionGraph
,
node
:
Apply
,
output_indices
:
List
[
int
]
)
->
Optional
[
Apply
]:
"""Attempt to replace a `Scan` node by one which computes the specified outputs inplace.
"""Attempt to replace a `Scan` node by one which computes the specified outputs inplace.
Parameters
Parameters
----------
----------
fgraph
: FunctionGraph
fgraph
Function graph in which to attempt the replacement
Function graph in which to attempt the replacement
node
: Apply node
node
Scan node to replace by an inplace version
Scan node to replace by an inplace version
output_indices
: list of integers
output_indices
Indices of the outputs to attempt to compute inplace
Indices of the outputs to attempt to compute inplace
alloc_ops : list of Op classes
Classes that represent operation that allocate new memory and
that the optimization should duplicate so it can operate inplace
on them.
"""
"""
op
=
node
.
op
op
:
Scan
=
cast
(
Scan
,
node
.
op
)
# inputs corresponding to sequences and n_steps
# inputs corresponding to sequences and n_steps
ls_begin
=
node
.
inputs
[:
1
+
op
.
info
.
n_seqs
]
ls_begin
=
node
.
inputs
[:
1
+
op
.
info
.
n_seqs
]
...
@@ -964,14 +964,14 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -964,14 +964,14 @@ class ScanInplaceOptimizer(GlobalOptimizer):
ls_end
+=
op
.
outer_nitsot
(
node
.
inputs
)
ls_end
+=
op
.
outer_nitsot
(
node
.
inputs
)
ls_end
+=
op
.
outer_non_seqs
(
node
.
inputs
)
ls_end
+=
op
.
outer_non_seqs
(
node
.
inputs
)
# In `ls`, duplicate any input which has more th
e
n one client and is
# In `ls`, duplicate any input which has more th
a
n one client and is
# the output of an eligible allocation op
# the output of an eligible allocation op
for
i
in
range
(
len
(
ls
)):
for
i
in
range
(
len
(
ls
)):
inp
=
ls
[
i
]
inp
=
ls
[
i
]
if
(
if
(
len
(
fgraph
.
clients
[
inp
])
>
1
len
(
fgraph
.
clients
[
inp
])
>
1
and
inp
.
owner
and
inp
.
owner
and
isinstance
(
inp
.
owner
.
op
,
alloc_ops
)
and
isinstance
(
inp
.
owner
.
op
,
self
.
alloc_ops
)
):
):
new_lsi
=
inp
.
owner
.
op
.
make_node
(
*
inp
.
owner
.
inputs
)
new_lsi
=
inp
.
owner
.
op
.
make_node
(
*
inp
.
owner
.
inputs
)
...
@@ -991,23 +991,8 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -991,23 +991,8 @@ class ScanInplaceOptimizer(GlobalOptimizer):
ls
[
idx
]
=
deep_copy_op
(
ls
[
idx
])
ls
[
idx
]
=
deep_copy_op
(
ls
[
idx
])
inputs
=
ls_begin
+
ls
+
ls_end
inputs
=
ls_begin
+
ls
+
ls_end
if
self
.
typeInfer
is
None
:
typeConstructor
=
None
else
:
typeConstructor
=
self
.
typeInfer
(
node
)
new_op
=
Scan
(
new_op
=
op
.
clone
()
op
.
inner_inputs
,
op
.
inner_outputs
,
op
.
info
,
mode
=
op
.
mode
,
typeConstructor
=
typeConstructor
,
profile
=
op
.
profile
,
truncate_gradient
=
op
.
truncate_gradient
,
# TODO: This seems questionable
name
=
op
.
name
,
allow_gc
=
op
.
allow_gc
,
)
destroy_map
=
op
.
destroy_map
.
copy
()
destroy_map
=
op
.
destroy_map
.
copy
()
for
out_idx
in
output_indices
:
for
out_idx
in
output_indices
:
...
@@ -1016,9 +1001,16 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -1016,9 +1001,16 @@ class ScanInplaceOptimizer(GlobalOptimizer):
new_op
.
destroy_map
=
destroy_map
new_op
.
destroy_map
=
destroy_map
# Do not call make_node for test_value
# Do not call make_node for test_value
new_outs
=
new_op
(
*
inputs
,
return_list
=
True
)
new_outs
:
List
[
Variable
]
=
new_op
(
*
inputs
,
return_list
=
True
)
try
:
try
:
fgraph
.
replace_all_validate_remove
(
# TODO FIXME: We need to stop using this approach (i.e. attempt
# in-place replacements and wait for downstream failures to revert
# the changes). It prevents us from making smart, clear
# rewrites and it adds a lot of unnecessary overhead that
# involves dealing with inconsistent graphs.
# This whole rewrite should be a simple local rewrite, but, because
# of this awful approach, it can't be.
fgraph
.
replace_all_validate_remove
(
# type: ignore
list
(
zip
(
node
.
outputs
,
new_outs
)),
list
(
zip
(
node
.
outputs
,
new_outs
)),
remove
=
[
node
],
remove
=
[
node
],
reason
=
"scan_make_inplace"
,
reason
=
"scan_make_inplace"
,
...
@@ -1026,20 +1018,19 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -1026,20 +1018,19 @@ class ScanInplaceOptimizer(GlobalOptimizer):
return
new_outs
[
0
]
.
owner
return
new_outs
[
0
]
.
owner
except
InconsistencyError
:
except
InconsistencyError
:
# Failed moving output to be computed inplace
# Failed moving output to be computed inplace
return
nod
e
return
Non
e
def
apply
(
self
,
fgraph
):
def
apply
(
self
,
fgraph
):
alloc_ops
=
(
Alloc
,
AllocEmpty
)
for
scan_idx
,
original_node
in
enumerate
(
reversed
(
fgraph
.
toposort
())):
nodes
=
fgraph
.
toposort
()[::
-
1
]
scan_nodes
=
[
x
for
x
in
nodes
if
(
isinstance
(
x
.
op
,
Scan
))]
if
not
isinstance
(
original_node
.
op
,
Scan
):
for
scan_idx
in
range
(
len
(
scan_nodes
)):
continue
# First attempt to make the Scan compute inplace every recurrent
# First attempt to make the Scan compute inplace every recurrent
# output that seems like it could be computed inplace. If that
# output that seems like it could be computed inplace. If that
# fails, go through these outputs individually, trying each of
# fails, go through these outputs individually, trying each of
# them.
# them.
original_node
=
scan_nodes
[
scan_idx
]
op
=
original_node
.
op
op
=
original_node
.
op
n_outs
=
op
.
info
.
n_mit_mot
+
op
.
info
.
n_mit_sot
+
op
.
info
.
n_sit_sot
n_outs
=
op
.
info
.
n_mit_mot
+
op
.
info
.
n_mit_sot
+
op
.
info
.
n_sit_sot
...
@@ -1053,7 +1044,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -1053,7 +1044,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
# If the input is from an eligible allocation node, attempt to
# If the input is from an eligible allocation node, attempt to
# be inplace on it, even if other nodes are modifying it
# be inplace on it, even if other nodes are modifying it
# inplace.
# inplace.
if
inp
.
owner
and
isinstance
(
inp
.
owner
.
op
,
alloc_ops
):
if
inp
.
owner
and
isinstance
(
inp
.
owner
.
op
,
self
.
alloc_ops
):
out_indices
.
append
(
out_idx
)
out_indices
.
append
(
out_idx
)
continue
continue
...
@@ -1079,16 +1070,21 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -1079,16 +1070,21 @@ class ScanInplaceOptimizer(GlobalOptimizer):
if
not
input_used_inplace
:
if
not
input_used_inplace
:
out_indices
.
append
(
out_idx
)
out_indices
.
append
(
out_idx
)
node
=
self
.
attempt_scan_inplace
(
if
len
(
out_indices
)
==
0
:
fgraph
,
scan_nodes
[
scan_idx
],
out_indices
,
alloc_ops
continue
)
if
node
is
original_node
:
new_node
=
self
.
attempt_scan_inplace
(
fgraph
,
original_node
,
out_indices
)
if
new_node
is
None
:
# Making the scan compute all plausible recurrent outputs
# Making the scan compute all plausible recurrent outputs
# inplace has failed. Attempt all plausible recurrent output
# inplace has failed. Attempt all plausible recurrent output
s
# individually.
# individually.
new_node
=
original_node
for
pos
in
out_indices
:
for
pos
in
out_indices
:
node
=
self
.
attempt_scan_inplace
(
fgraph
,
node
,
[
pos
],
alloc_ops
)
new_node
=
(
self
.
attempt_scan_inplace
(
fgraph
,
new_node
,
[
pos
])
or
new_node
)
def
select_min
(
x
,
y
):
def
select_min
(
x
,
y
):
...
@@ -2367,7 +2363,7 @@ optdb.register(
...
@@ -2367,7 +2363,7 @@ optdb.register(
)
)
optdb
.
register
(
optdb
.
register
(
"scan_make_inplace"
,
"scan_make_inplace"
,
ScanInplaceOptimizer
(
typeInfer
=
None
),
ScanInplaceOptimizer
(),
"fast_run"
,
"fast_run"
,
"inplace"
,
"inplace"
,
"scan"
,
"scan"
,
...
...
tests/scan/test_opt.py
浏览文件 @
74fb5433
...
@@ -9,9 +9,10 @@ from aesara.compile.io import In
...
@@ -9,9 +9,10 @@ from aesara.compile.io import In
from
aesara.compile.mode
import
get_default_mode
from
aesara.compile.mode
import
get_default_mode
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.gradient
import
grad
,
jacobian
from
aesara.gradient
import
grad
,
jacobian
from
aesara.graph.basic
import
clone_replace
from
aesara.graph.basic
import
clone_replace
,
equal_computations
from
aesara.graph.fg
import
FunctionGraph
from
aesara.scan.op
import
Scan
from
aesara.scan.op
import
Scan
from
aesara.scan.opt
import
ScanMerge
from
aesara.scan.opt
import
Scan
InplaceOptimizer
,
Scan
Merge
from
aesara.scan.utils
import
until
from
aesara.scan.utils
import
until
from
aesara.tensor.blas
import
Dot22
from
aesara.tensor.blas
import
Dot22
from
aesara.tensor.elemwise
import
Elemwise
from
aesara.tensor.elemwise
import
Elemwise
...
@@ -912,6 +913,49 @@ class TestScanMerge:
...
@@ -912,6 +913,49 @@ class TestScanMerge:
class
TestScanInplaceOptimizer
:
class
TestScanInplaceOptimizer
:
mode
=
get_default_mode
()
.
including
(
"scan_make_inplace"
,
"inplace"
)
mode
=
get_default_mode
()
.
including
(
"scan_make_inplace"
,
"inplace"
)
def
test_no_inplace
(
self
):
"""Make sure the rewrite doesn't make unnecessary replacements."""
x
=
at
.
vector
(
"x"
)
scan_out
,
_
=
aesara
.
scan
(
lambda
x
:
(
x
+
1
)
/
2
+
1
,
sequences
=
[
x
],
)
fgraph
=
FunctionGraph
(
outputs
=
[
scan_out
],
clone
=
True
,
copy_inputs
=
False
,
copy_orphans
=
False
)
_
=
ScanInplaceOptimizer
()
.
apply
(
fgraph
)
fgraph_op
=
fgraph
.
outputs
[
0
]
.
owner
.
inputs
[
0
]
.
owner
.
op
assert
not
fgraph_op
.
destroy_map
assert
equal_computations
([
scan_out
],
fgraph
.
outputs
)
def
test_inplace_basic
(
self
):
scan_out
,
_
=
aesara
.
scan
(
lambda
x
:
x
+
1
,
outputs_info
=
[
at
.
zeros
(
1
)],
n_steps
=
3
,
)
fgraph
=
FunctionGraph
(
outputs
=
[
scan_out
],
clone
=
True
,
copy_inputs
=
False
,
copy_orphans
=
False
)
assert
equal_computations
([
scan_out
],
fgraph
.
outputs
)
_
=
ScanInplaceOptimizer
()
.
apply
(
fgraph
)
# The graphs shouldn't change; only the `Op.destroy_map`s
assert
equal_computations
([
scan_out
],
fgraph
.
outputs
)
fgraph_op
=
fgraph
.
outputs
[
0
]
.
owner
.
inputs
[
0
]
.
owner
.
op
assert
fgraph_op
.
destroy_map
==
{
0
:
[
1
]}
assert
not
scan_out
.
owner
.
inputs
[
0
]
.
owner
.
op
.
destroy_map
@utt.assertFailure_fast
@utt.assertFailure_fast
def
test_simple_rnn
(
self
):
def
test_simple_rnn
(
self
):
"""Simple RNN; compute inplace version 1."""
"""Simple RNN; compute inplace version 1."""
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论